generated from fastai/nbdev_template
-
Notifications
You must be signed in to change notification settings - Fork 2.4k
Improve PEFT integration #4723
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Improve PEFT integration #4723
Changes from all commits
Commits
Show all changes
13 commits
Select commit
Hold shift + click to select a range
e421a04
Disallow PeftModel + peft_config in trainers
qgallouedec 1d603ff
remove tests
qgallouedec f4019fa
remove old comments
qgallouedec 6a71878
Merge branch 'main' into disallow-peft-model-with-config
qgallouedec 7406c79
Better peft integration
qgallouedec e7de51b
type hint
qgallouedec f16ba77
tiny models
qgallouedec 6d80078
Remove force option from push_to_hub calls in generate_tiny_models.py
qgallouedec f38410d
Merge branch 'main' into better-peft-integration
qgallouedec d1c3542
Merge branch 'main' into better-peft-integration
qgallouedec 4714d0e
Don't push tokenizer with adapter
qgallouedec 6cbab27
Merge branch 'main' into better-peft-integration
qgallouedec f648da9
Merge branch 'main' into better-peft-integration
qgallouedec File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -87,6 +87,7 @@ | |||||
| split_tensor_dict, | ||||||
| start_event_loop_in_daemon, | ||||||
| unsplit_pixel_values_by_grid, | ||||||
| use_adapter, | ||||||
| ) | ||||||
|
|
||||||
|
|
||||||
|
|
@@ -142,7 +143,7 @@ class RLOOTrainer(BaseTrainer): | |||||
| ``` | ||||||
| Args: | ||||||
| model (`str | PreTrainedModel`): | ||||||
| model (`str` or [`~transformers.PreTrainedModel`] or [`~peft.PeftModel`]): | ||||||
| Model to be trained. Can be either: | ||||||
| - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a | ||||||
|
|
@@ -151,6 +152,7 @@ class RLOOTrainer(BaseTrainer): | |||||
| using `<ModelArchitecture>.from_pretrained` (where `<ModelArchitecture>` is derived from the model | ||||||
| config) with the keyword arguments in `args.model_init_kwargs`. | ||||||
| - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported. | ||||||
| - A [`~peft.PeftModel`] object. Only causal language models are supported. | ||||||
| reward_funcs (`RewardFunc | list[RewardFunc]`): | ||||||
| Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward | ||||||
| functions with the prompts and completions and sum the rewards. Can be either: | ||||||
|
|
@@ -235,7 +237,7 @@ class RLOOTrainer(BaseTrainer): | |||||
|
|
||||||
| def __init__( | ||||||
| self, | ||||||
| model: str | PreTrainedModel, | ||||||
| model: "str | PreTrainedModel | PeftModel", | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| reward_funcs: RewardFunc | list[RewardFunc], | ||||||
| args: RLOOConfig | None = None, | ||||||
| train_dataset: Dataset | IterableDataset | None = None, | ||||||
|
|
@@ -295,20 +297,29 @@ def __init__( | |||||
| self.pad_token_id = tokenizer.pad_token_id | ||||||
| self.eos_token_id = tokenizer.eos_token_id | ||||||
|
|
||||||
| if is_peft_available() and isinstance(model, PeftModel) and peft_config is not None: | ||||||
| if is_peft_available() and is_peft_model(model) and peft_config is not None: | ||||||
| raise ValueError( | ||||||
| "You passed a `PeftModel` instance together with a `peft_config` to the trainer. Please first merge " | ||||||
| "and unload the existing adapter, save the resulting base model, and then pass that base model along " | ||||||
| "with the new `peft_config` to the trainer." | ||||||
| ) | ||||||
| if is_peft_available() and is_peft_model(model): | ||||||
| # If the model is a PEFT model with a pretrained adapter, we need to create a "ref" adapter that is a copy | ||||||
| # of the "default" adapter, so that we can use it as the reference model during the training. | ||||||
| model.add_adapter("ref", model.peft_config["default"]) | ||||||
| for name, param in model.named_parameters(): | ||||||
| if ".default." in name: | ||||||
| ref_name = name.replace(".default.", ".ref.") | ||||||
| ref_param = model.get_parameter(ref_name) | ||||||
| ref_param.data.copy_(param.data) | ||||||
|
|
||||||
| # Create PEFT model | ||||||
| if peft_config is not None: | ||||||
| model = get_peft_model(model, peft_config) | ||||||
|
|
||||||
| # When using gradient checkpointing with PEFT, we need to enable input gradients. transformers.Trainer normally | ||||||
| # handles this, but a bug currently prevents it; see https://github.com/huggingface/transformers/issues/42489 | ||||||
| if is_peft_available() and isinstance(model, PeftModel) and args.gradient_checkpointing: | ||||||
| if is_peft_available() and is_peft_model(model) and args.gradient_checkpointing: | ||||||
| model.enable_input_require_grads() | ||||||
|
|
||||||
| # When using QLoRA, the PEFT adapter weights are converted to bf16 to follow the recommendations from the | ||||||
|
|
@@ -1423,7 +1434,11 @@ def _generate_and_score_completions( | |||||
| **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes | ||||||
| ) | ||||||
| else: | ||||||
| with self.accelerator.unwrap_model(self.model).disable_adapter(): | ||||||
| # When training a PEFT adapter, how we obtain the reference depends on the setup: | ||||||
| # - New adapter: disabling adapters yields the base model. | ||||||
| # - Re-training an existing adapter: an initial copy is loaded under the name "ref". | ||||||
| model = self.accelerator.unwrap_model(self.model) | ||||||
| with use_adapter(model, adapter_name="ref" if "ref" in model.peft_config else None): | ||||||
| ref_per_token_logps, _ = self._get_per_token_logps_and_entropies( | ||||||
| self.model, | ||||||
| prompt_completion_ids, | ||||||
|
|
||||||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -25,6 +25,7 @@ | |||||
| import transformers | ||||||
| from accelerate import PartialState | ||||||
| from accelerate.logging import get_logger | ||||||
| from accelerate.utils import is_peft_model | ||||||
| from datasets import Dataset, IterableDataset | ||||||
| from packaging.version import Version | ||||||
| from transformers import ( | ||||||
|
|
@@ -503,7 +504,7 @@ class SFTTrainer(BaseTrainer): | |||||
| ``` | ||||||
| Args: | ||||||
| model (`str | PreTrainedModel`): | ||||||
| model (`str` or [`~transformers.PreTrainedModel`] or [`~peft.PeftModel`]): | ||||||
| Model to be trained. Can be either: | ||||||
| - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a | ||||||
|
|
@@ -512,6 +513,7 @@ class SFTTrainer(BaseTrainer): | |||||
| using `<ModelArchitecture>.from_pretrained` (where `<ModelArchitecture>` is derived from the model | ||||||
| config) with the keyword arguments in `args.model_init_kwargs`. | ||||||
| - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported. | ||||||
| - A [`~peft.PeftModel`] object. Only causal language models are supported. | ||||||
| If you're training a model with an MoE architecture and want to include the load balancing/auxiliary loss | ||||||
| as a part of the final loss, remember to set the `output_router_logits` config of the model to `True`. | ||||||
| args ([`SFTConfig`], *optional*): | ||||||
|
|
@@ -580,7 +582,7 @@ class SFTTrainer(BaseTrainer): | |||||
|
|
||||||
| def __init__( | ||||||
| self, | ||||||
| model: str | PreTrainedModel, | ||||||
| model: "str | PreTrainedModel | PeftModel", | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| args: SFTConfig | TrainingArguments | None = None, | ||||||
| data_collator: DataCollator | None = None, | ||||||
| train_dataset: Dataset | IterableDataset | None = None, | ||||||
|
|
@@ -699,7 +701,7 @@ def __init__( | |||||
| else: | ||||||
| peft_config.modules_to_save.append("lm_head") | ||||||
|
|
||||||
| if is_peft_available() and isinstance(model, PeftModel) and peft_config is not None: | ||||||
| if is_peft_available() and is_peft_model(model) and peft_config is not None: | ||||||
| raise ValueError( | ||||||
| "You passed a `PeftModel` instance together with a `peft_config` to the trainer. Please first merge " | ||||||
| "and unload the existing adapter, save the resulting base model, and then pass that base model along " | ||||||
|
|
@@ -712,7 +714,7 @@ def __init__( | |||||
|
|
||||||
| # When using gradient checkpointing with PEFT, we need to enable input gradients. transformers.Trainer normally | ||||||
| # handles this, but a bug currently prevents it; see https://github.com/huggingface/transformers/issues/42489 | ||||||
| if is_peft_available() and isinstance(model, PeftModel) and args.gradient_checkpointing: | ||||||
| if is_peft_available() and is_peft_model(model) and args.gradient_checkpointing: | ||||||
| model.enable_input_require_grads() | ||||||
|
|
||||||
| # When using QLoRA, the PEFT adapter weights are converted to bf16 to follow the recommendations from the | ||||||
|
|
@@ -728,7 +730,7 @@ def __init__( | |||||
| # In Prompt Tuning a small set of trainable virtual tokens (continuous prompt embeddings) is prepended to the | ||||||
| # input. We store the number of these tokens so we can account for them correctly when calculating accuracy. | ||||||
| self.num_virtual_tokens = 0 | ||||||
| if is_peft_available() and isinstance(model, PeftModel): | ||||||
| if is_peft_available() and is_peft_model(model): | ||||||
| if model.active_adapter in model.peft_config: | ||||||
| peft_model_config = model.peft_config[model.active_adapter] | ||||||
| self.num_virtual_tokens = getattr(peft_model_config, "num_virtual_tokens", 0) | ||||||
|
|
||||||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.