
Commit a6968b7

lugimzzz and root authored

[llm] support lora merge (#7733)

* support lora merge
* add finetune
* support bf16 and lora merge
* update
* add lora merge
* version adapt
* update
* fix

Co-authored-by: root <[email protected]>

1 parent 1982091 commit a6968b7

File tree: 14 files changed, +377 additions, -81 deletions

llm/docs/finetune.md (6 additions, 3 deletions)

````diff
@@ -215,13 +215,16 @@ python merge_tp_and_pp_params.py \
 To make later **compression** and **static-graph inference** easier, we provide a LoRA parameter merge script that merges the LoRA parameters into the backbone model and saves the corresponding weights.
 ```
 python merge_lora_params.py \
-    --model_name_or_path meta-llama/Llama-2-7b-chat \
-    --lora_path ./checkpoints/llama_lora_ckpts
+    --lora_path ./checkpoints/llama_lora_ckpts \
+    --merge_lora_model_path ./checkpoints/llama_lora_merge \
+    --device "gpu" \
+    --low_gpu_mem True
 ```
+
 <summary>&emsp; Script arguments</summary><div>
 
-- `model_name_or_path`: Required. Name of a pretrained model or path to a local model, used to warm-start the model and tokenizer. Defaults to None.
 - `lora_path`: Path to the LoRA parameters and config, used to initialize the LoRA parameters. Defaults to None.
 - `merge_lora_model_path`: Required. Directory where the merged parameters are saved. Defaults to None.
 - `device`: Runtime device. Defaults to gpu.
+- `low_gpu_mem`: Reduces the GPU memory needed while merging. Defaults to False. Recommended when GPU memory is insufficient during the merge.
 </div>
````
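Conceptually, the merge folds each low-rank update back into its base weight, after which the adapter tensors can be discarded. A minimal sketch of the per-weight computation, assuming unquantized fp16/bf16 tensors; the function and variable names here are illustrative, not the script's API:

```python
import paddle

def merge_lora_weight(weight, lora_A, lora_B, lora_alpha, r):
    # LoRA merge rule: W' = W + (A @ B) * (alpha / r)
    scaling = lora_alpha / r
    return weight + lora_A @ lora_B * scaling

# Once merged, the checkpoint behaves like an ordinary fine-tuned model,
# so compression and static-graph inference need no LoRA-specific handling.
```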

llm/finetune_generation.py (1 addition, 1 deletion)

```diff
@@ -162,7 +162,6 @@ def main():
     else:
         # NOTE(gongenlei): new add autotuner_benchmark
         model = AutoModelForCausalLM.from_config(model_config, dtype=dtype)
-
     if training_args.do_train and model_args.neftune:
         # Inspired by https://github.com/neelsjain/NEFTune
         if hasattr(model, "get_input_embeddings"):
@@ -418,6 +417,7 @@ def neft_post_hook(module, input, output):
             tensor_parallel_degree=training_args.tensor_parallel_degree,
             dtype=dtype,
             do_qat=quant_args.do_qat,
+            base_model_name_or_path=model_args.model_name_or_path,
         )
         model = LoRAModel(model, lora_config)
     else:
```

llm/llama/qlora_argument.json (new file, 33 additions)

```json
{
    "model_name_or_path": "facebook/llama-7b",
    "dataset_name_or_path": "./data",
    "output_dir": "./checkpoints/llama_lora_ckpts",
    "per_device_train_batch_size": 4,
    "gradient_accumulation_steps": 4,
    "per_device_eval_batch_size": 8,
    "eval_accumulation_steps": 16,
    "num_train_epochs": 3,
    "learning_rate": 3e-04,
    "warmup_steps": 30,
    "logging_steps": 1,
    "evaluation_strategy": "epoch",
    "save_strategy": "epoch",
    "src_length": 1024,
    "max_length": 2048,
    "fp16": true,
    "fp16_opt_level": "O2",
    "do_train": true,
    "do_eval": true,
    "disable_tqdm": true,
    "load_best_model_at_end": true,
    "eval_with_do_generation": false,
    "metric_for_best_model": "accuracy",
    "recompute": true,
    "save_total_limit": 1,
    "tensor_parallel_degree": 1,
    "pipeline_parallel_degree": 1,
    "lora": true,
    "zero_padding": false,
    "use_flash_attention": false,
    "weight_quantize_algo": "nf4"
}
```

llm/llama/wint8_lora_argument.json (new file, 33 additions)

```json
{
    "model_name_or_path": "facebook/llama-7b",
    "dataset_name_or_path": "./data",
    "output_dir": "./checkpoints/llama_lora_ckpts",
    "per_device_train_batch_size": 4,
    "gradient_accumulation_steps": 4,
    "per_device_eval_batch_size": 8,
    "eval_accumulation_steps": 16,
    "num_train_epochs": 3,
    "learning_rate": 3e-04,
    "warmup_steps": 30,
    "logging_steps": 1,
    "evaluation_strategy": "epoch",
    "save_strategy": "epoch",
    "src_length": 1024,
    "max_length": 2048,
    "fp16": true,
    "fp16_opt_level": "O2",
    "do_train": true,
    "do_eval": true,
    "disable_tqdm": true,
    "load_best_model_at_end": true,
    "eval_with_do_generation": false,
    "metric_for_best_model": "accuracy",
    "recompute": true,
    "save_total_limit": 1,
    "tensor_parallel_degree": 1,
    "pipeline_parallel_degree": 1,
    "lora": true,
    "zero_padding": false,
    "use_flash_attention": false,
    "weight_quantize_algo": "weight_only_int8"
}
```
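The two argument files differ only in `weight_quantize_algo`: `nf4` selects QLoRA-style 4-bit quantization, while `weight_only_int8` selects INT8 weight-only quantization. Assuming these JSONs are consumed the same way as the repo's other argument files, fine-tuning would be launched by passing the file to the training script touched above:

```
python finetune_generation.py llama/qlora_argument.json
```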

llm/merge_lora_params.py (109 additions, 19 deletions)

```diff
@@ -12,45 +12,135 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import argparse
+import copy
+import os
 
 import paddle
 
 from paddlenlp.peft import LoRAConfig, LoRAModel
-from paddlenlp.transformers import AutoModelForCausalLM
+
+try:
+    from paddle.nn.quant import weight_dequantize, weight_quantize
+except:
+    weight_dequantize = None
+    weight_quantize = None
+try:
+    from paddlenlp.quantization.qlora import qlora_weight_quantize_dequantize
+except:
+    qlora_weight_quantize_dequantize = None
+
+from paddlenlp.quantization.quantization_config import QuantizationConfig
+from paddlenlp.transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+from paddlenlp.transformers.utils import device_guard
+from paddlenlp.utils.env import CONFIG_NAME
 
 
 def parse_arguments():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--model_name_or_path", default=None, required=True, help="The directory of pretrained model.")
+    parser.add_argument("--model_name_or_path", default=None, help="The directory of pretrained model.")
     parser.add_argument(
         "--lora_path", default=None, required=True, help="The directory of LoRA parameters. Default to None"
     )
-    parser.add_argument("--merge_model_path", default=None, help="The directory of merged parameters. Default to None")
+    parser.add_argument(
+        "--merge_lora_model_path",
+        default=None,
+        required=True,
+        help="The directory of merged parameters. Default to None",
+    )
     parser.add_argument("--device", type=str, default="gpu", help="Device")
+    parser.add_argument(
+        "--low_gpu_mem", type=bool, default=False, help="Whether to use low gpu memory. Default to False"
+    )
     return parser.parse_args()
 
 
+def weight_process(name, quant_config, lora_config, state_dict):
+    weight = state_dict.pop(name + ".weight").cuda()
+    if quant_config.weight_quantize_algo is None:
+        pass
+    elif quant_config.weight_quantize_algo in ["nf4", "fp4"]:
+        weight = qlora_weight_quantize_dequantize(
+            weight,
+            quant_algo=quant_config.weight_quantize_algo,
+            double_quant=quant_config.weight_double_quant,
+            block_size=quant_config.weight_blocksize,
+            double_quant_block_size=quant_config.weight_double_quant_block_size,
+        )
+    elif quant_config.weight_quantize_algo in ["weight_only_int8"]:
+        out, scale = weight_quantize(weight, algo=quant_config.weight_quantize_algo)
+        weight = weight_dequantize(out, scale)
+    else:
+        raise ValueError(f"quant_config.weight_quantize_algo {quant_config.weight_quantize_algo} is not supported.")
+    lora_A = state_dict.pop(name + ".lora_A").cuda()
+    lora_B = state_dict.pop(name + ".lora_B").cuda()
+    scaling = lora_config.lora_alpha / lora_config.r
+    state_dict[name + ".weight"] = (weight + lora_A @ lora_B * scaling).cpu()
+
+
 def merge():
     args = parse_arguments()
     paddle.set_device(args.device)
+
     lora_config = LoRAConfig.from_pretrained(args.lora_path)
-    dtype = lora_config.dtype
-    lora_config.merge_weights = True
+    if lora_config.base_model_name_or_path is None:
+        if args.model_name_or_path is None:
+            raise ValueError("We can not find a valid model_name_or_path.")
+        else:
+            lora_config.base_model_name_or_path = args.model_name_or_path
 
-    model = AutoModelForCausalLM.from_pretrained(
-        args.model_name_or_path,
-        dtype=dtype,
-    )
-    model = LoRAModel.from_pretrained(model=model, lora_path=args.lora_path, lora_config=lora_config)
-    model.eval()
-    if args.merge_model_path is None:
-        args.merge_model_path = args.lora_path
-
-    model_state_dict = model.model.state_dict()
-    for key in list(model_state_dict):
-        if "lora" in key:
-            del model_state_dict[key]
-    model.model.save_pretrained(args.merge_model_path, state_dict=model_state_dict)
+    if os.path.isfile(os.path.join(args.lora_path, CONFIG_NAME)):
+        config = AutoConfig.from_pretrained(args.lora_path)
+    elif args.model_name_or_path is not None:
+        config = AutoConfig.from_pretrained(args.model_name_or_path)
+    else:
+        raise ValueError(
+            f"We can not find config.json in lora_path: {args.lora_path} or find a valid model_name_or_path."
+        )
+    config.dtype = lora_config.dtype
+    if (
+        lora_config.dtype == "bfloat16" or config.quantization_config.weight_quantize_algo in ["nf4", "fp4"]
+    ) and args.device == "cpu":
+        raise ValueError("We can not apply bfloat16 or nf4/fp4 lora merge on cpu.")
+
+    if args.low_gpu_mem and args.device == "gpu":
+        quant_config = copy.deepcopy(config.quantization_config)
+        config.quantization_config = QuantizationConfig()
+        lora_config.merge_weights = False
+        with device_guard():
+            model = AutoModelForCausalLM.from_pretrained(
+                lora_config.base_model_name_or_path,
+                config=config,
+                low_cpu_mem_usage=True,
+            )
+            model = LoRAModel.from_pretrained(model=model, lora_path=args.lora_path, lora_config=lora_config)
+        model.eval()
+        model_state_dict = model.model.state_dict()
+        lora_name_list = []
+        for key in model_state_dict.keys():
+            if "lora_A" in key:
+                lora_name_list.append(key[:-7])
+        for name in lora_name_list:
+            weight_process(name, quant_config, lora_config, model_state_dict)
+    else:
+        model = AutoModelForCausalLM.from_pretrained(
+            lora_config.base_model_name_or_path,
+            config=config,
+            low_cpu_mem_usage=True,
+        )
+        lora_config.merge_weights = True
+        model = LoRAModel.from_pretrained(model=model, lora_path=args.lora_path, lora_config=lora_config)
+        model.eval()
+        model_state_dict = model.model.state_dict()
+        for key in list(model_state_dict):
+            if "lora" in key:
+                del model_state_dict[key]
+            if "quant" in key:
+                del model_state_dict[key]
+    model.model.config.quantization_config = QuantizationConfig()
+    model.model.save_pretrained(args.merge_lora_model_path, state_dict=model_state_dict)
+
+    tokenizer = AutoTokenizer.from_pretrained(lora_config.base_model_name_or_path)
+    tokenizer.save_pretrained(args.merge_lora_model_path)
 
 
 if __name__ == "__main__":
```
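The `low_gpu_mem` path merges layer by layer: the base model is materialized on CPU under `device_guard()`, and `weight_process()` moves one weight at a time to GPU, reproduces the quantization round trip, folds in the LoRA update, and parks the result back on CPU. A minimal sketch of that round-trip idea for `weight_only_int8`, assuming the `paddle.nn.quant` helpers imported above; the tensor names are illustrative:

```python
import paddle
from paddle.nn.quant import weight_dequantize, weight_quantize

def merge_int8_layer(weight, lora_A, lora_B, lora_alpha, r):
    # Quantize then dequantize so the merged base weight matches what the
    # weight-only-int8 model actually computed with during training.
    quantized, scale = weight_quantize(weight.cuda(), algo="weight_only_int8")
    dequantized = weight_dequantize(quantized, scale)
    scaling = lora_alpha / r
    # Fold the LoRA update in on GPU, then move the result back to CPU so
    # only one layer occupies GPU memory at a time.
    return (dequantized + lora_A.cuda() @ lora_B.cuda() * scaling).cpu()
```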

paddlenlp/peft/lora/lora_config.py (3 additions, 0 deletions)

```diff
@@ -72,6 +72,9 @@ class LoRAConfig:
         },
     )
     do_qat: bool = field(default=False, metadata={"help": "Whether the lora model would do quant-aware training"})
+    base_model_name_or_path: Optional[str] = field(
+        default=None, metadata={"help": "The name of the base model to use."}
+    )
 
     @property
     def __dict__(self):
```
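With `base_model_name_or_path` stored in the adapter's config, `merge_lora_params.py` can recover the backbone without an explicit `--model_name_or_path`. A hedged sketch of a config carrying the new field (the values are illustrative):

```python
from paddlenlp.peft import LoRAConfig

# r and lora_alpha determine the merge scaling (alpha / r); the new field
# records which backbone the adapter was trained on.
config = LoRAConfig(
    r=8,
    lora_alpha=16,
    base_model_name_or_path="facebook/llama-7b",
)
```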

paddlenlp/peft/lora/lora_model.py (17 additions, 7 deletions)

```diff
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import copy
 import math
 import os
 import re
@@ -226,14 +227,16 @@ def _convert_tensor_parallel(self, lora_state_dict):
         return lora_state_dict
 
     def save_pretrained(self, save_directory: str, merge_tensor_parallel: bool = False, **kwargs):
+        save_model_config = kwargs.get("save_model_config", True)
+
         if self.is_pipelinemodel:
             self.model._single_to_pp_mapping = None
-        if self.quantized and merge_tensor_parallel and self.model.config.tensor_parallel_degree > 1:
+        if self.quantized and merge_tensor_parallel and self.lora_config.tensor_parallel_degree > 1:
             merge_tensor_parallel = False
             logger.warning(
                 "Quantized strategy does not support merge_tensor_parallel. Set merge_tensor_parallel to False."
             )
-        if self.is_pipelinemodel and merge_tensor_parallel and self.model.config.tensor_parallel_degree > 1:
+        if self.is_pipelinemodel and merge_tensor_parallel and self.lora_config.tensor_parallel_degree > 1:
             merge_tensor_parallel = False
             logger.warning(
                 "Pipeline parallism does not support merge_tensor_parallel. Set merge_tensor_parallel to False."
@@ -247,18 +250,20 @@ def save_pretrained(self, save_directory: str, merge_tensor_parallel: bool = Fal
         ), f"Saving directory ({save_directory}) should be a directory, not a file"
         os.makedirs(save_directory, exist_ok=True)
 
-        if merge_tensor_parallel and self.model.config.tensor_parallel_degree > 1:
+        lora_config_to_save = LoRAConfig(**self.lora_config.to_dict())
+
+        if merge_tensor_parallel and lora_config_to_save.tensor_parallel_degree > 1:
             trainable_state_dict = self.get_trainable_state_dict()
             trainable_state_dict = self._merge_trainable_tensor_parallel(trainable_state_dict)
             if not is_main_process:
                 logger.info("Saving with merge_tensor_parallel, tensor_parallel_rank > 0 don't need save")
                 return
             if variant is not None and "tp" in variant:
                 variant = "_".join([x for x in variant.split("_") if "tp" not in x])
-            self.lora_config.tensor_parallel_degree = -1
+            lora_config_to_save.tensor_parallel_degree = -1
         else:
             trainable_state_dict = self.get_trainable_state_dict()
-            if self.model.config.tensor_parallel_degree > 1:
+            if lora_config_to_save.tensor_parallel_degree > 1:
                 if variant is None:
                     variant = weight_name_suffix()
 
@@ -269,8 +274,12 @@ def save_pretrained(self, save_directory: str, merge_tensor_parallel: bool = Fal
 
         # save lora config
         if is_main_process:
-            self.lora_config.save_pretrained(save_directory)
-            self.lora_config.tensor_parallel_degree = self.model.config.tensor_parallel_degree
+            lora_config_to_save.save_pretrained(save_directory)
+            if save_model_config:
+                model_config_to_save = copy.deepcopy(self.model.config)
+                if merge_tensor_parallel:
+                    model_config_to_save.tensor_parallel_degree = -1
+                model_config_to_save.save_pretrained(save_directory)
 
     def _find_and_replace_module(self, model, module_name, lora_config, enable_lora):
         parent_module = model
@@ -366,6 +375,7 @@ def _find_and_replace_module(self, model, module_name, lora_config, enable_lora)
                     r=lora_config.r,
                     lora_alpha=lora_config.lora_alpha,
                     lora_dropout=lora_config.lora_dropout,
+                    merge_weights=lora_config.merge_weights,
                 )
                 self.quantized = True
             elif ColumnParallelQuantizationLinear is not None and isinstance(module, ColumnParallelQuantizationLinear):
```
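`save_pretrained` now snapshots the adapter config into `lora_config_to_save`, so the live `lora_config` is no longer mutated during saving, and by default it also writes the base model's config next to the adapter. A hedged usage sketch of the new `save_model_config` kwarg (the model names and target-module patterns are illustrative):

```python
from paddlenlp.peft import LoRAConfig, LoRAModel
from paddlenlp.transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("facebook/llama-7b")
lora_config = LoRAConfig(
    target_modules=[".*q_proj.*", ".*v_proj.*"],  # illustrative regex list
    r=8,
    lora_alpha=16,
)
model = LoRAModel(base, lora_config)

# save_model_config defaults to True and also writes the base model's
# config.json alongside the adapter; pass False to save only the LoRA files.
model.save_pretrained("./checkpoints/llama_lora_ckpts", save_model_config=False)
```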
