Merged
16 changes: 15 additions & 1 deletion scripts/generate_tiny_models.py
@@ -18,6 +18,7 @@

import torch
from huggingface_hub import HfApi, ModelCard
from peft import LoraConfig, get_peft_model
from torch import nn
from transformers import (
    AutoConfig,
@@ -112,8 +113,9 @@ def push_to_hub(model, tokenizer, generation_config, prefix=None, suffix=None, f
        print(f"Model {repo_id} already exists, skipping")
    else:
        model.push_to_hub(repo_id)
        tokenizer.push_to_hub(repo_id)
        model_card.push_to_hub(repo_id)
        if tokenizer is not None:
            tokenizer.push_to_hub(repo_id)
        if generation_config is not None:
            generation_config.push_to_hub(repo_id)

@@ -380,3 +382,15 @@ def init_weights_tiny_model(model):
config = AutoConfig.from_pretrained(model_id, text_config=text_config, vision_config=vision_config, **kwargs)
model = model_class(config).to(dtype=dtype)
push_to_hub(model, processor, generation_config, "tiny")

# PEFT models
model = Qwen3ForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen3ForCausalLM", dtype="auto")
model = get_peft_model(model, LoraConfig())
generation_config = GenerationConfig.from_pretrained("trl-internal-testing/tiny-Qwen3ForCausalLM")
push_to_hub(model, None, None, "tiny")

# Same model, but different weights
model = Qwen3ForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen3ForCausalLM", dtype="auto")
model = get_peft_model(model, LoraConfig())
generation_config = GenerationConfig.from_pretrained("trl-internal-testing/tiny-Qwen3ForCausalLM")
push_to_hub(model, None, None, "tiny", "2")
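
Note on the "different weights" comment above: both pushes start from the same base checkpoint, but PEFT initializes each adapter's `lora_A` matrices randomly (with `lora_B` starting at zero), so the two pushed repos end up with distinct adapter weights. A minimal sketch outside the script, for illustration only:

```python
import torch
from peft import LoraConfig, get_peft_model
from transformers import Qwen3ForCausalLM

base = "trl-internal-testing/tiny-Qwen3ForCausalLM"
m1 = get_peft_model(Qwen3ForCausalLM.from_pretrained(base), LoraConfig())
m2 = get_peft_model(Qwen3ForCausalLM.from_pretrained(base), LoraConfig())

# The lora_A matrices are randomly initialized, so the two adapters differ.
a1 = {n: p for n, p in m1.named_parameters() if "lora_A" in n}
a2 = {n: p for n, p in m2.named_parameters() if "lora_A" in n}
assert any(not torch.equal(a1[n], a2[n]) for n in a1)
```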
56 changes: 55 additions & 1 deletion tests/test_utils.py
@@ -35,13 +35,67 @@
    split_pixel_values_by_grid,
    split_tensor_dict,
    unsplit_pixel_values_by_grid,
    use_adapter,
)

from .testing_utils import TrlTestCase, require_peft, require_rich


if is_peft_available():
    from peft import LoraConfig
    from peft import AutoPeftModelForCausalLM, LoraConfig


@require_peft
class TestUseAdapter(TrlTestCase):
    def test_disables_on_none(self):
        model = AutoPeftModelForCausalLM.from_pretrained(
            "trl-internal-testing/tiny-PeftModel", adapter_name="my_adapter"
        )
        input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]])
        with model.disable_adapter():
            expected = model(input_ids).logits

        with use_adapter(model, None):
            output = model(input_ids).logits

        assert torch.equal(output, expected)

    def test_restores_previous_adapter(self):
        model = AutoPeftModelForCausalLM.from_pretrained(
            "trl-internal-testing/tiny-PeftModel", adapter_name="my_adapter"
        )
        input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]])
        expected = model(input_ids).logits
        with use_adapter(model, "my_adapter"):
            pass
        output = model(input_ids).logits
        assert torch.equal(output, expected)

        with use_adapter(model, None):
            pass
        output = model(input_ids).logits
        assert torch.equal(output, expected)

    def test_with_multiple_adapters(self):
        model = AutoPeftModelForCausalLM.from_pretrained(
            "trl-internal-testing/tiny-PeftModel", adapter_name="my_adapter_1"
        )
        model.load_adapter("trl-internal-testing/tiny-PeftModel-2", "my_adapter_2")
        input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]])

        model.set_adapter("my_adapter_1")  # should be a no-op, but let's keep it for clarity
        expected_1 = model(input_ids).logits
        model.set_adapter("my_adapter_2")
        expected_2 = model(input_ids).logits

        with use_adapter(model, "my_adapter_1"):
            output_1 = model(input_ids).logits

        with use_adapter(model, "my_adapter_2"):
            output_2 = model(input_ids).logits

        assert torch.equal(output_1, expected_1)
        assert torch.equal(output_2, expected_2)


class TestPad(TrlTestCase):
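The `use_adapter` helper exercised by these tests is imported from TRL's trainer utilities, but its implementation is not part of this diff. A minimal sketch of what such a context manager could look like, assuming PEFT's `set_adapter`/`disable_adapter` API (illustrative, not the actual TRL implementation):

```python
from contextlib import contextmanager


@contextmanager
def use_adapter(model, adapter_name):
    """Temporarily activate `adapter_name` on a PeftModel; `None` disables all adapters."""
    if adapter_name is None:
        # Run the forward pass with every adapter disabled (i.e. the base model).
        with model.disable_adapter():
            yield model
    else:
        previous = model.active_adapter  # remember what was active before entering
        model.set_adapter(adapter_name)
        try:
            yield model
        finally:
            model.set_adapter(previous)  # restore the previously active adapter
```

On exit the model behaves exactly as it did before entering the block, which is the behavior `test_restores_previous_adapter` checks above.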
26 changes: 21 additions & 5 deletions trl/trainer/grpo_trainer.py
@@ -92,6 +92,7 @@
    split_tensor_dict,
    start_event_loop_in_daemon,
    unsplit_pixel_values_by_grid,
    use_adapter,
)


@@ -155,7 +156,7 @@ class GRPOTrainer(BaseTrainer):
```
Args:
model (`str | PreTrainedModel`):
model (`str` or [`~transformers.PreTrainedModel`] or [`~peft.PeftModel`]):
Model to be trained. Can be either:
- A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a
@@ -164,6 +165,7 @@ class GRPOTrainer(BaseTrainer):
using `<ModelArchitecture>.from_pretrained` (where `<ModelArchitecture>` is derived from the model
config) with the keyword arguments in `args.model_init_kwargs`.
- A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
- A [`~peft.PeftModel`] object. Only causal language models are supported.
reward_funcs (`RewardFunc | list[RewardFunc]`):
Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward
functions with the prompts and completions and sum the rewards. Can be either:
@@ -258,7 +260,7 @@ class GRPOTrainer(BaseTrainer):

    def __init__(
        self,
        model: str | PreTrainedModel,
        model: "str | PreTrainedModel | PeftModel",
Review comment from @edbeeching (Collaborator), Jan 5, 2026. Suggested change (quote only the `PeftModel` forward reference rather than the entire union):
-        model: "str | PreTrainedModel | PeftModel",
+        model: str | PreTrainedModel | "PeftModel",
        reward_funcs: RewardFunc | list[RewardFunc],
        args: GRPOConfig | None = None,
        train_dataset: Dataset | IterableDataset | None = None,
@@ -320,20 +322,30 @@ def __init__(
        self.pad_token_id = tokenizer.pad_token_id
        self.eos_token_id = tokenizer.eos_token_id

        if is_peft_available() and isinstance(model, PeftModel) and peft_config is not None:
        if is_peft_available() and is_peft_model(model) and peft_config is not None:
            raise ValueError(
                "You passed a `PeftModel` instance together with a `peft_config` to the trainer. Please first merge "
                "and unload the existing adapter, save the resulting base model, and then pass that base model along "
                "with the new `peft_config` to the trainer."
            )

        if is_peft_available() and is_peft_model(model) and self.args.beta != 0.0:
            # If the model is a PEFT model with a pretrained adapter, we need to create a "ref" adapter that is a copy
            # of the "default" adapter, so that we can use it as the reference model during GRPO training.
            model.add_adapter("ref", model.peft_config["default"])
            for name, param in model.named_parameters():
                if ".default." in name:
                    ref_name = name.replace(".default.", ".ref.")
                    ref_param = model.get_parameter(ref_name)
                    ref_param.data.copy_(param.data)

        # Create PEFT model
        if peft_config is not None:
            model = get_peft_model(model, peft_config)

        # When using gradient checkpointing with PEFT, we need to enable input gradients. transformers.Trainer normally
        # handles this, but a bug currently prevents it; see https://github.com/huggingface/transformers/issues/42489
        if is_peft_available() and isinstance(model, PeftModel) and args.gradient_checkpointing:
        if is_peft_available() and is_peft_model(model) and args.gradient_checkpointing:
            model.enable_input_require_grads()

        # When using QLoRA, the PEFT adapter weights are converted to bf16 to follow the recommendations from the
@@ -1964,7 +1976,11 @@ def _generate_and_score_completions(
                    **forward_kwargs,  # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes
                )
            else:
                with self.accelerator.unwrap_model(self.model).disable_adapter():
                # When training a PEFT adapter, how we obtain the reference depends on the setup:
                # - New adapter: disabling adapters yields the base model.
                # - Re-training an existing adapter: an initial copy is loaded under the name "ref".
                model = self.accelerator.unwrap_model(self.model)
                with use_adapter(model, adapter_name="ref" if "ref" in model.peft_config else None):
                    ref_per_token_logps, _ = self._get_per_token_logps_and_entropies(
                        self.model,
                        prompt_completion_ids,
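With this change, `GRPOTrainer` accepts a model that already carries a trained LoRA adapter and keeps training it, snapshotting the pretrained adapter as the "ref" reference policy when `beta != 0.0`. A hedged usage sketch (the model id, dataset, and reward function are placeholders, not taken from this PR):

```python
from datasets import load_dataset
from peft import AutoPeftModelForCausalLM
from trl import GRPOConfig, GRPOTrainer

# A model that already carries a trained LoRA adapter (loaded as "default").
model = AutoPeftModelForCausalLM.from_pretrained("my-org/my-lora-model", is_trainable=True)

def reward_len(completions, **kwargs):
    # Toy reward: prefer shorter completions.
    return [-float(len(c)) for c in completions]

trainer = GRPOTrainer(
    model=model,
    reward_funcs=reward_len,
    # With beta != 0.0 the trainer copies the pretrained adapter to a "ref" adapter
    # and uses it (rather than the bare base model) as the KL reference.
    args=GRPOConfig(output_dir="grpo-continue-lora", beta=0.04),
    train_dataset=load_dataset("trl-lib/tldr", split="train"),
)
trainer.train()
```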
10 changes: 6 additions & 4 deletions trl/trainer/reward_trainer.py
@@ -27,6 +27,7 @@
import torch.nn as nn
from accelerate import PartialState
from accelerate.logging import get_logger
from accelerate.utils import is_peft_model
from datasets import Dataset, IterableDataset
from transformers import (
    AutoModelForSequenceClassification,
@@ -195,7 +196,7 @@ class RewardTrainer(BaseTrainer):
```
Args:
model (`str | PreTrainedModel`):
model (`str` or [`~transformers.PreTrainedModel`] or [`~peft.PeftModel`]):
Model to be trained. Can be either:
- A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a
@@ -204,6 +205,7 @@ class RewardTrainer(BaseTrainer):
using `AutoModelForSequenceClassification.from_pretrained` with the keyword arguments in
`args.model_init_kwargs`.
- A sequence classification [`~transformers.PreTrainedModel`] object.
- A sequence classification [`~peft.PeftModel`] object.
args ([`RewardConfig`], *optional*):
Configuration for this trainer. If `None`, a default configuration is used.
data_collator ([`~transformers.DataCollator`], *optional*):
@@ -266,7 +268,7 @@ class RewardTrainer(BaseTrainer):

    def __init__(
        self,
        model: str | PreTrainedModel,
        model: "str | PreTrainedModel | PeftModel",
        args: RewardConfig | None = None,
        data_collator: DataCollator | None = None,
        train_dataset: Dataset | IterableDataset | None = None,
@@ -352,7 +354,7 @@ def __init__(
            else:
                peft_config.modules_to_save.append("lm_head")

        if is_peft_available() and isinstance(model, PeftModel) and peft_config is not None:
        if is_peft_available() and is_peft_model(model) and peft_config is not None:
            raise ValueError(
                "You passed a `PeftModel` instance together with a `peft_config` to the trainer. Please first merge "
                "and unload the existing adapter, save the resulting base model, and then pass that base model along "
@@ -365,7 +367,7 @@ def __init__(

        # When using gradient checkpointing with PEFT, we need to enable input gradients. transformers.Trainer normally
        # handles this, but a bug currently prevents it; see https://github.com/huggingface/transformers/issues/42489
        if is_peft_available() and isinstance(model, PeftModel) and args.gradient_checkpointing:
        if is_peft_available() and is_peft_model(model) and args.gradient_checkpointing:
            model.enable_input_require_grads()

        # When using QLoRA, the PEFT adapter weights are converted to bf16 to follow the recommendations from the
25 changes: 20 additions & 5 deletions trl/trainer/rloo_trainer.py
@@ -87,6 +87,7 @@
    split_tensor_dict,
    start_event_loop_in_daemon,
    unsplit_pixel_values_by_grid,
    use_adapter,
)


@@ -142,7 +143,7 @@ class RLOOTrainer(BaseTrainer):
```
Args:
model (`str | PreTrainedModel`):
model (`str` or [`~transformers.PreTrainedModel`] or [`~peft.PeftModel`]):
Model to be trained. Can be either:
- A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a
@@ -151,6 +152,7 @@ class RLOOTrainer(BaseTrainer):
using `<ModelArchitecture>.from_pretrained` (where `<ModelArchitecture>` is derived from the model
config) with the keyword arguments in `args.model_init_kwargs`.
- A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
- A [`~peft.PeftModel`] object. Only causal language models are supported.
reward_funcs (`RewardFunc | list[RewardFunc]`):
Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward
functions with the prompts and completions and sum the rewards. Can be either:
@@ -235,7 +237,7 @@ class RLOOTrainer(BaseTrainer):

    def __init__(
        self,
        model: str | PreTrainedModel,
        model: "str | PreTrainedModel | PeftModel",
Review comment (Collaborator). Suggested change:
-        model: "str | PreTrainedModel | PeftModel",
+        model: str | PreTrainedModel | "PeftModel",
        reward_funcs: RewardFunc | list[RewardFunc],
        args: RLOOConfig | None = None,
        train_dataset: Dataset | IterableDataset | None = None,
@@ -295,20 +297,29 @@ def __init__(
        self.pad_token_id = tokenizer.pad_token_id
        self.eos_token_id = tokenizer.eos_token_id

        if is_peft_available() and isinstance(model, PeftModel) and peft_config is not None:
        if is_peft_available() and is_peft_model(model) and peft_config is not None:
            raise ValueError(
                "You passed a `PeftModel` instance together with a `peft_config` to the trainer. Please first merge "
                "and unload the existing adapter, save the resulting base model, and then pass that base model along "
                "with the new `peft_config` to the trainer."
            )
        if is_peft_available() and is_peft_model(model):
            # If the model is a PEFT model with a pretrained adapter, we need to create a "ref" adapter that is a copy
            # of the "default" adapter, so that we can use it as the reference model during the training.
            model.add_adapter("ref", model.peft_config["default"])
            for name, param in model.named_parameters():
                if ".default." in name:
                    ref_name = name.replace(".default.", ".ref.")
                    ref_param = model.get_parameter(ref_name)
                    ref_param.data.copy_(param.data)

        # Create PEFT model
        if peft_config is not None:
            model = get_peft_model(model, peft_config)

        # When using gradient checkpointing with PEFT, we need to enable input gradients. transformers.Trainer normally
        # handles this, but a bug currently prevents it; see https://github.com/huggingface/transformers/issues/42489
        if is_peft_available() and isinstance(model, PeftModel) and args.gradient_checkpointing:
        if is_peft_available() and is_peft_model(model) and args.gradient_checkpointing:
            model.enable_input_require_grads()

        # When using QLoRA, the PEFT adapter weights are converted to bf16 to follow the recommendations from the
@@ -1423,7 +1434,11 @@ def _generate_and_score_completions(
                    **forward_kwargs,  # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes
                )
            else:
                with self.accelerator.unwrap_model(self.model).disable_adapter():
                # When training a PEFT adapter, how we obtain the reference depends on the setup:
                # - New adapter: disabling adapters yields the base model.
                # - Re-training an existing adapter: an initial copy is loaded under the name "ref".
                model = self.accelerator.unwrap_model(self.model)
                with use_adapter(model, adapter_name="ref" if "ref" in model.peft_config else None):
                    ref_per_token_logps, _ = self._get_per_token_logps_and_entropies(
                        self.model,
                        prompt_completion_ids,
12 changes: 7 additions & 5 deletions trl/trainer/sft_trainer.py
@@ -25,6 +25,7 @@
import transformers
from accelerate import PartialState
from accelerate.logging import get_logger
from accelerate.utils import is_peft_model
from datasets import Dataset, IterableDataset
from packaging.version import Version
from transformers import (
@@ -503,7 +504,7 @@ class SFTTrainer(BaseTrainer):
```
Args:
model (`str | PreTrainedModel`):
model (`str` or [`~transformers.PreTrainedModel`] or [`~peft.PeftModel`]):
Model to be trained. Can be either:
- A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a
@@ -512,6 +513,7 @@ class SFTTrainer(BaseTrainer):
using `<ModelArchitecture>.from_pretrained` (where `<ModelArchitecture>` is derived from the model
config) with the keyword arguments in `args.model_init_kwargs`.
- A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
- A [`~peft.PeftModel`] object. Only causal language models are supported.
If you're training a model with an MoE architecture and want to include the load balancing/auxiliary loss
as a part of the final loss, remember to set the `output_router_logits` config of the model to `True`.
args ([`SFTConfig`], *optional*):
@@ -580,7 +582,7 @@ class SFTTrainer(BaseTrainer):

    def __init__(
        self,
        model: str | PreTrainedModel,
        model: "str | PreTrainedModel | PeftModel",
Review comment (Collaborator). Suggested change:
-        model: "str | PreTrainedModel | PeftModel",
+        model: str | PreTrainedModel | "PeftModel",
        args: SFTConfig | TrainingArguments | None = None,
        data_collator: DataCollator | None = None,
        train_dataset: Dataset | IterableDataset | None = None,
@@ -699,7 +701,7 @@ def __init__(
            else:
                peft_config.modules_to_save.append("lm_head")

        if is_peft_available() and isinstance(model, PeftModel) and peft_config is not None:
        if is_peft_available() and is_peft_model(model) and peft_config is not None:
            raise ValueError(
                "You passed a `PeftModel` instance together with a `peft_config` to the trainer. Please first merge "
                "and unload the existing adapter, save the resulting base model, and then pass that base model along "
@@ -712,7 +714,7 @@ def __init__(

        # When using gradient checkpointing with PEFT, we need to enable input gradients. transformers.Trainer normally
        # handles this, but a bug currently prevents it; see https://github.com/huggingface/transformers/issues/42489
        if is_peft_available() and isinstance(model, PeftModel) and args.gradient_checkpointing:
        if is_peft_available() and is_peft_model(model) and args.gradient_checkpointing:
            model.enable_input_require_grads()

        # When using QLoRA, the PEFT adapter weights are converted to bf16 to follow the recommendations from the
@@ -728,7 +730,7 @@ def __init__(
        # In Prompt Tuning a small set of trainable virtual tokens (continuous prompt embeddings) is prepended to the
        # input. We store the number of these tokens so we can account for them correctly when calculating accuracy.
        self.num_virtual_tokens = 0
        if is_peft_available() and isinstance(model, PeftModel):
        if is_peft_available() and is_peft_model(model):
            if model.active_adapter in model.peft_config:
                peft_model_config = model.peft_config[model.active_adapter]
                self.num_virtual_tokens = getattr(peft_model_config, "num_virtual_tokens", 0)
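The `num_virtual_tokens` bookkeeping above matters because Prompt Tuning prepends that many learned embeddings to the input, so the model's logits carry extra leading positions that have no counterpart in the labels. A rough sketch of the kind of alignment this enables when computing token accuracy (illustrative only, not the trainer's exact code):

```python
import torch

def masked_token_accuracy(logits, labels, num_virtual_tokens=0, ignore_index=-100):
    # Drop the logit positions that correspond to the prepended virtual prompt tokens
    # so that logits and labels line up again.
    if num_virtual_tokens:
        logits = logits[:, num_virtual_tokens:, :]
    # Standard causal shift: the token at position t is predicted from position t - 1.
    preds = logits[:, :-1, :].argmax(dim=-1)
    targets = labels[:, 1:]
    mask = targets != ignore_index
    return (preds[mask] == targets[mask]).float().mean()
```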