
This repository contains the official source code for SALT: Parameter-Efficient Fine-Tuning via Singular Value Adaptation with Low-Rank Transformation [BMVC'2025].


BioMedIA-MBZUAI/SALT


SALT: Parameter-Efficient Fine-Tuning via Singular Value Adaptation with Low-Rank Transformation

SALT is a novel Parameter-Efficient Fine-Tuning (PEFT) method designed to adapt large-scale foundation models, especially the Segment Anything Model (SAM), to domain-specific tasks such as medical image segmentation. With SALT, you can achieve state-of-the-art segmentation performance while training only a small percentage of the model's parameters.

Methodology


Table of Contents

  1. Key Features
  2. Installation
  3. Datasets
  4. Repository Structure
  5. Usage
  6. Command Examples
  7. Performance Highlights
  8. Citations
  9. License

Key Features

  • 🌟 Hybrid PEFT
    Merges SVD-based parameter updates for top singular values with LoRA for the residual subspace.
  • 📈 Parameter Efficiency
    Reaches state-of-the-art performance while training as little as 3.9% of the total parameters.
  • 🔬 Robust Medical Adaptation
    Outperforms other PEFT methods (LoRA, S-SAM) by 2–5% in Dice scores across challenging medical datasets.
  • 🔧 Easy Integration
    Provided as a drop-in module for existing SAM pipelines, requiring minimal code changes.
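
The hybrid update described under "Hybrid PEFT" can be written out roughly as follows. This is a sketch based only on that description, not necessarily the exact formulation used in the paper or the code: the symbols $\alpha$ and $\beta$ (a per-singular-value scale and shift), $X$ and $Y$ (low-rank LoRA factors of inner rank $k$), and the split point $r$ are notation assumed here for illustration.

```latex
% Thin SVD of a frozen weight matrix W, with n = min(d_out, d_in) singular values
W = U \Sigma V^\top, \qquad
\Sigma = \begin{bmatrix} \Sigma_r & 0 \\ 0 & \Sigma_{\mathrm{res}} \end{bmatrix}, \qquad
\Sigma_r = \operatorname{diag}(\sigma_1, \dots, \sigma_r)

% Top-r singular values: trainable scale and shift
\tilde{\Sigma}_r = \operatorname{diag}\!\left(\alpha \odot \sigma_{1:r} + \beta\right),
\qquad \alpha, \beta \in \mathbb{R}^{r}

% Residual subspace: trainable low-rank (LoRA-style) update
\tilde{\Sigma}_{\mathrm{res}} = \Sigma_{\mathrm{res}} + X Y,
\qquad X \in \mathbb{R}^{(n-r) \times k},\; Y \in \mathbb{R}^{k \times (n-r)}

% Adapted weight
\tilde{W} = U \begin{bmatrix} \tilde{\Sigma}_r & 0 \\ 0 & \tilde{\Sigma}_{\mathrm{res}} \end{bmatrix} V^\top
```

Under this reading, only $\alpha$, $\beta$, $X$, and $Y$ are trained, while $U$, $V$, and the original singular values stay frozen, which is what keeps the trainable parameter count small.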

Installation

  1. Clone this repository

    git clone https://github.com/BioMedIA-MBZUAI/SALT.git
    cd SALT
  2. Set up a Conda environment 🐍

    conda env create -f SALT_env.yml
    conda activate SALT

    Note: Verify or modify the packages in SALT_env.yml as required (e.g., the PyTorch and CUDA versions).


Datasets

Below are the five medical imaging datasets highlighted in our experiments. Each dataset can be found on Hugging Face:

  • 🟢: ROSE (Retinal OCT Angiography)
  • 🔵: ARCADE (Coronary Artery Segmentation)
  • 🟠: DRIVE (Retinal Vessel Segmentation)
  • 🟡: DIAS (Dynamic Digital Subtraction Angiography)
  • 🔴: Xray-Angio (Occluded Vessel Segmentation)

Please organize each dataset as follows:

dataset_name/
├── images/               # Folder containing image files
├── masks/                # Folder containing segmentation masks
└── data_split.csv        # CSV specifying [imgs, split] columns for train/val/test

Tip: Make sure the file paths in data_split.csv correspond exactly to those in the images/ and masks/ folders.
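
The tip above can be checked automatically before training. The following is a minimal sketch, not part of the repository: `find_missing_files` is a hypothetical helper, and it assumes that the `imgs` column holds file names relative to images/ and that each mask shares its image's file name. Adjust the lookup if your masks are named differently.

```python
import csv
from pathlib import Path


def find_missing_files(dataset_root):
    """Return (images_missing, masks_missing) for entries listed in data_split.csv.

    Assumption: the CSV has an `imgs` column of file names, and each mask in
    masks/ has the same file name as its image in images/.
    """
    root = Path(dataset_root)
    images_missing, masks_missing = [], []
    with open(root / "data_split.csv", newline="") as f:
        for row in csv.DictReader(f):
            name = row["imgs"].strip()
            if not (root / "images" / name).is_file():
                images_missing.append(name)
            if not (root / "masks" / name).is_file():
                masks_missing.append(name)
    return images_missing, masks_missing
```

Calling `find_missing_files("/path/to/ROSE")` returns two lists; both should be empty when data_split.csv and the two folders agree.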


Repository Structure

Here's a simplified overview of the repository and the purpose of each major file:

AllinonSAM/
├── main.py                          # Entry point for training/testing SALT
├── model.py                         # Model definitions (Prompt_Adapted_SAM, Prompt_Adapted_SAM2)
├── train.py                         # Core training loop and helper functions
├── test.py                          # Evaluation and metrics computation
├── datasets/
│   └── data.py                      # Custom Dataset (GeneralDataset) and data loader logic
├── utils.py                         # Dice, Focal, and other loss functions
├── data_transforms/
│   └── data_transforms.py           # Contains the Data_Transform class
├── prompt_adapted_SAM2/
│   └── ...                          # Prompt Adapted SAM 2 implementation with LoRA, SVD, and SALT layers
├── prompt_adapted_segment_anything/
│   └── ...                          # Prompt Adapted SAM implementation with LoRA, SVD, and SALT layers
├── SALT_env.yml                     # Conda environment specification
├── config_data.yml                  # Example dataset configuration
├── SALT_config.yml                  # Example model and training configuration
└── ...                              # Other files for model analysis, job submission, etc.

  • main.py
    Parses command-line arguments, loads datasets, sets up models, and invokes training or testing.

  • model.py
    Houses the Prompt_Adapted_SAM and Prompt_Adapted_SAM2 classes, plus SALT/LoRA logic.

  • train.py
    Contains the main training function, optimization loops, checkpoint saving, etc.

  • test.py
    Handles inference and computing evaluation metrics (e.g., Dice, HD95).

  • datasets/data.py
    Implements the GeneralDataset for image-mask pair loading and augmentation.

  • utils.py
    Provides a variety of segmentation loss functions (Dice, BCE, Focal, WeightedCE).
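
For reference, the Dice metric mentioned under test.py measures the overlap between a predicted mask and the ground truth. Below is a minimal, dependency-free sketch of the standard definition for binary masks. It is illustrative only: the repository's own implementation may differ, for example in smoothing or per-class averaging, and `dice_score` is a name chosen here.

```python
def dice_score(pred, target, eps=1e-7):
    """Dice = 2|P ∩ T| / (|P| + |T|) for two flat binary masks (sequences of 0/1)."""
    if len(pred) != len(target):
        raise ValueError("masks must have the same number of pixels")
    intersection = sum(p * t for p, t in zip(pred, target))
    total = sum(pred) + sum(target)
    # eps keeps the score defined (and equal to 1.0) when both masks are empty
    return (2.0 * intersection + eps) / (total + eps)


# One overlapping pixel, two predicted, one true: Dice is about 2/3.
print(round(dice_score([1, 1, 0, 0], [1, 0, 0, 0]), 4))
```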


Usage

1. Configuration Files

You’ll need two primary YAML files:

  1. Data Config (config_data.yaml)
    Defines your dataset root path, transforms, image size, label definitions, etc. Example:

    data_transforms:
      a_min: 0
      a_max: 255
      img_size: 512
      ...
    data:
      root_path: '/path/to/ROSE'
      data_split_csv: '/path/to/ROSE/data_split.csv'
      label_list: [0, 1]
      label_names: ["Background", "Vein"]
      ...
  2. Model Config (SALT_config.yaml)
    Specifies model settings, optimizer, LoRA/SALT ranks, etc. Example:

    sam:
      img_size: 512
      num_classes: 2
      sam_type: "base"
    
    training:
      optimizer: 'adamw'
      lr: 1e-4
      batch_size: 8
      num_epochs: 200
      ...
    ft:
      type: 'svd'    # Could be 'svd', 'lora', or 'salt'
      svd_rank_linear: 0
      svd_rank_conv2d: 0
      r_lora: 4
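
Because settings are split across two files, a quick consistency check can catch mismatches early. The sketch below is hypothetical and not part of the repository: `check_configs` is a name chosen here, and the rules that `sam.img_size` should equal `data_transforms.img_size` and that `sam.num_classes` should equal the length of `label_list` are assumptions inferred from the example values above. The dicts mirror what PyYAML's `yaml.safe_load` would return for the two examples, with the elided `...` entries left out.

```python
ALLOWED_FT_TYPES = {"svd", "lora", "salt"}  # from the comment in the example config


def check_configs(data_cfg, model_cfg):
    """Return a list of human-readable problems; an empty list means all checks passed."""
    problems = []
    ft_type = model_cfg["ft"]["type"]
    if ft_type not in ALLOWED_FT_TYPES:
        problems.append(f"ft.type {ft_type!r} is not one of {sorted(ALLOWED_FT_TYPES)}")
    # Assumed rule: the model and the data pipeline use the same image size.
    if model_cfg["sam"]["img_size"] != data_cfg["data_transforms"]["img_size"]:
        problems.append("sam.img_size and data_transforms.img_size differ")
    # Assumed rule: one output class per entry in label_list.
    if model_cfg["sam"]["num_classes"] != len(data_cfg["data"]["label_list"]):
        problems.append("sam.num_classes does not match the length of data.label_list")
    if len(data_cfg["data"]["label_list"]) != len(data_cfg["data"]["label_names"]):
        problems.append("data.label_list and data.label_names have different lengths")
    return problems


data_cfg = {
    "data_transforms": {"a_min": 0, "a_max": 255, "img_size": 512},
    "data": {"label_list": [0, 1], "label_names": ["Background", "Vein"]},
}
model_cfg = {
    "sam": {"img_size": 512, "num_classes": 2, "sam_type": "base"},
    "ft": {"type": "svd", "svd_rank_linear": 0, "svd_rank_conv2d": 0, "r_lora": 4},
}
print(check_configs(data_cfg, model_cfg))  # → []
```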

2. Training

Use the script main.py to train. Key flags include:

  • --data_config: path to YAML for data
  • --model_config: path to YAML for model and training hyperparameters
  • --pretrained_path: path to your SAM checkpoint (if applicable)
  • --save_path: where to store trained weights
  • --training_strategy: one of [salt, svdbiastuning, biastuning, lora]
  • --device: device to run on (e.g., cuda:0 or cpu)

For example:

python main.py \
  --data_config configs/config_data.yaml \
  --model_config configs/SALT_config.yaml \
  --pretrained_path /path/to/sam_checkpoint.pth \
  --save_path checkpoints/salt_model.pth \
  --training_strategy salt \
  --device cuda:0

Note: You may need to adjust --device based on your system (e.g., "cuda:1" or "cpu").
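
To make the flags above concrete, here is how such a command line could be parsed with Python's standard argparse module. This is an illustrative sketch, not the parser in main.py: which flags are required and the default values shown are assumptions, and main.py may accept additional options.

```python
import argparse


def build_parser():
    parser = argparse.ArgumentParser(description="Illustrative parser for the flags listed above")
    parser.add_argument("--data_config", required=True, help="path to the data YAML")
    parser.add_argument("--model_config", required=True, help="path to the model/training YAML")
    parser.add_argument("--pretrained_path", default=None, help="path to a SAM checkpoint")
    parser.add_argument("--save_path", required=True, help="where to store trained weights")
    parser.add_argument(
        "--training_strategy",
        choices=["salt", "svdbiastuning", "biastuning", "lora"],
        default="salt",  # assumed default, not confirmed by the repository
    )
    parser.add_argument("--device", default="cuda:0", help='e.g. "cuda:1" or "cpu"')  # assumed default
    return parser


# Parse the same arguments as the example command above.
args = build_parser().parse_args([
    "--data_config", "configs/config_data.yaml",
    "--model_config", "configs/SALT_config.yaml",
    "--pretrained_path", "/path/to/sam_checkpoint.pth",
    "--save_path", "checkpoints/salt_model.pth",
    "--training_strategy", "salt",
    "--device", "cuda:0",
])
print(args.training_strategy, args.device)  # → salt cuda:0
```

Restricting `--training_strategy` with `choices` makes a misspelled strategy fail immediately rather than partway through model setup.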


Command Examples

  1. Training with SALT
    python main.py \
      --data_config configs/config_data.yaml \
      --model_config configs/SALT_config.yaml \
      --pretrained_path /path/to/sam_checkpoint.pth \
      --save_path checkpoints/salt_model.pth \
      --training_strategy salt \
      --device cuda:0
  2. Training with LoRA 🎉
    python main.py \
      --data_config configs/config_data.yaml \
      --model_config configs/SALT_config.yaml \
      --pretrained_path /path/to/sam_checkpoint.pth \
      --save_path checkpoints/lora_model.pth \
      --training_strategy lora \
      --device cuda:0

Performance Highlights

Here’s a brief summary of SALT’s performance versus other PEFT methods:

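| Method          | Avg. Dice | Avg. HD95 | Trainable Params |
|-----------------|-----------|-----------|------------------|
| LoRA (rank=256) | 0.70      | 25.94     | 14.08%           |
| S-SAM           | 0.71      | 30.12     | 0.40%            |
| SALT (Ours)     | 0.74      | 23.87     | 3.90%            |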

🔥: SALT provides the best Dice and HD95 among these PEFT methods, striking a strong balance between accuracy and parameter efficiency.
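
A "Trainable Params" percentage like the one in the table can be estimated for any adapted model by comparing parameters that require gradients against all parameters. The helper below is a sketch rather than the repository's code: `trainable_percentage` is a name chosen here, and it accepts any iterable of objects exposing `numel()` and `requires_grad`, which is what `model.parameters()` yields for a PyTorch module. How the reported figures were counted is not stated in this README, so results may not match exactly.

```python
def trainable_percentage(parameters):
    """Percentage of parameter elements whose requires_grad flag is True.

    `parameters` is any iterable of objects with `numel()` and `requires_grad`,
    for example `model.parameters()` on a torch.nn.Module.
    """
    total = 0
    trainable = 0
    for p in parameters:
        n = p.numel()
        total += n
        if p.requires_grad:
            trainable += n
    return 100.0 * trainable / total if total else 0.0
```

With a PyTorch model, this would be called as `trainable_percentage(model.parameters())` after the PEFT layers have been attached and the backbone frozen.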


Citations

If you use this work, please cite:

@misc{elsayed2025saltsingularvalueadaptation,
      title={SALT: Singular Value Adaptation with Low-Rank Transformation}, 
      author={Abdelrahman Elsayed and Sarim Hashmi and Mohammed Elseiagy and Hu Wang and Mohammad Yaqub and Ibrahim Almakky},
      year={2025},
      eprint={2503.16055},
      archivePrefix={arXiv},
      primaryClass={eess.IV},
      url={https://arxiv.org/abs/2503.16055}, 
}

License

This project is released under the MIT License.
You are free to use, modify, and distribute this software; see the LICENSE file for details.
