# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
+import copy
+import os

import paddle

from paddlenlp.peft import LoRAConfig, LoRAModel
-from paddlenlp.transformers import AutoModelForCausalLM
+
+try:
+    from paddle.nn.quant import weight_dequantize, weight_quantize
+except ImportError:
+    weight_dequantize = None
+    weight_quantize = None
+try:
+    from paddlenlp.quantization.qlora import qlora_weight_quantize_dequantize
+except ImportError:
+    qlora_weight_quantize_dequantize = None
+
+from paddlenlp.quantization.quantization_config import QuantizationConfig
+from paddlenlp.transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+from paddlenlp.transformers.utils import device_guard
+from paddlenlp.utils.env import CONFIG_NAME


def parse_arguments():
    parser = argparse.ArgumentParser()
-    parser.add_argument("--model_name_or_path", default=None, required=True, help="The directory of pretrained model.")
+    parser.add_argument("--model_name_or_path", default=None, help="The directory of pretrained model.")
    parser.add_argument(
        "--lora_path", default=None, required=True, help="The directory of LoRA parameters. Default to None"
    )
-    parser.add_argument("--merge_model_path", default=None, help="The directory of merged parameters. Default to None")
+    parser.add_argument(
+        "--merge_lora_model_path",
+        default=None,
+        required=True,
+        help="The directory to save the merged parameters.",
+    )
    parser.add_argument("--device", type=str, default="gpu", help="Device")
+    parser.add_argument(
+        "--low_gpu_mem", type=bool, default=False, help="Whether to use low gpu memory. Default to False"
+    )
    return parser.parse_args()


+def weight_process(name, quant_config, lora_config, state_dict):
+    # Dequantize the base weight if it was stored quantized, fold the LoRA update
+    # into it, and write the merged weight back into the state dict on CPU.
+    weight = state_dict.pop(name + ".weight").cuda()
+    if quant_config.weight_quantize_algo is None:
+        pass
+    elif quant_config.weight_quantize_algo in ["nf4", "fp4"]:
+        weight = qlora_weight_quantize_dequantize(
+            weight,
+            quant_algo=quant_config.weight_quantize_algo,
+            double_quant=quant_config.weight_double_quant,
+            block_size=quant_config.weight_blocksize,
+            double_quant_block_size=quant_config.weight_double_quant_block_size,
+        )
+    elif quant_config.weight_quantize_algo in ["weight_only_int8"]:
+        out, scale = weight_quantize(weight, algo=quant_config.weight_quantize_algo)
+        weight = weight_dequantize(out, scale)
+    else:
+        raise ValueError(f"quant_config.weight_quantize_algo {quant_config.weight_quantize_algo} is not supported.")
+    lora_A = state_dict.pop(name + ".lora_A").cuda()
+    lora_B = state_dict.pop(name + ".lora_B").cuda()
+    scaling = lora_config.lora_alpha / lora_config.r
+    state_dict[name + ".weight"] = (weight + lora_A @ lora_B * scaling).cpu()
+
+
def merge():
    args = parse_arguments()
    paddle.set_device(args.device)
+
    lora_config = LoRAConfig.from_pretrained(args.lora_path)
-    dtype = lora_config.dtype
-    lora_config.merge_weights = True
+    # Resolve the base model: prefer the path recorded in the LoRA config,
+    # otherwise fall back to --model_name_or_path.
+    if lora_config.base_model_name_or_path is None:
+        if args.model_name_or_path is not None:
+            lora_config.base_model_name_or_path = args.model_name_or_path
+        else:
+            raise ValueError("We can not find a valid model_name_or_path.")

-    model = AutoModelForCausalLM.from_pretrained(
-        args.model_name_or_path,
-        dtype=dtype,
-    )
-    model = LoRAModel.from_pretrained(model=model, lora_path=args.lora_path, lora_config=lora_config)
-    model.eval()
-    if args.merge_model_path is None:
-        args.merge_model_path = args.lora_path
-
-    model_state_dict = model.model.state_dict()
-    for key in list(model_state_dict):
-        if "lora" in key:
-            del model_state_dict[key]
-    model.model.save_pretrained(args.merge_model_path, state_dict=model_state_dict)
+    # Load the model config, preferring a config.json stored alongside the LoRA weights.
+    if os.path.isfile(os.path.join(args.lora_path, CONFIG_NAME)):
+        config = AutoConfig.from_pretrained(args.lora_path)
+    elif args.model_name_or_path is not None:
+        config = AutoConfig.from_pretrained(args.model_name_or_path)
+    else:
+        raise ValueError(
+            f"We can not find config.json in lora_path: {args.lora_path} or find a valid model_name_or_path."
+        )
+    config.dtype = lora_config.dtype
+    if (
+        lora_config.dtype == "bfloat16" or config.quantization_config.weight_quantize_algo in ["nf4", "fp4"]
+    ) and args.device == "cpu":
+        raise ValueError("We can not apply bfloat16 or nf4/fp4 lora merge on cpu.")
+
+    if args.low_gpu_mem and args.device == "gpu":
+        # Low-memory path: load weights on CPU under device_guard, then dequantize and
+        # merge layer by layer on GPU via weight_process.
+        quant_config = copy.deepcopy(config.quantization_config)
+        config.quantization_config = QuantizationConfig()
+        lora_config.merge_weights = False
+        with device_guard():
+            model = AutoModelForCausalLM.from_pretrained(
+                lora_config.base_model_name_or_path,
+                config=config,
+                low_cpu_mem_usage=True,
+            )
+            model = LoRAModel.from_pretrained(model=model, lora_path=args.lora_path, lora_config=lora_config)
+        model.eval()
+        model_state_dict = model.model.state_dict()
+        lora_name_list = []
+        for key in model_state_dict.keys():
+            if "lora_A" in key:
+                lora_name_list.append(key[:-7])
+        for name in lora_name_list:
+            weight_process(name, quant_config, lora_config, model_state_dict)
+    else:
+        # Default path: let LoRAModel merge the adapters, then drop LoRA and quantization
+        # tensors from the state dict before saving.
+        model = AutoModelForCausalLM.from_pretrained(
+            lora_config.base_model_name_or_path,
+            config=config,
+            low_cpu_mem_usage=True,
+        )
+        lora_config.merge_weights = True
+        model = LoRAModel.from_pretrained(model=model, lora_path=args.lora_path, lora_config=lora_config)
+        model.eval()
+        model_state_dict = model.model.state_dict()
+        for key in list(model_state_dict):
+            if "lora" in key:
+                del model_state_dict[key]
+            if "quant" in key:
+                del model_state_dict[key]
+        model.model.config.quantization_config = QuantizationConfig()
+    model.model.save_pretrained(args.merge_lora_model_path, state_dict=model_state_dict)
+
+    tokenizer = AutoTokenizer.from_pretrained(lora_config.base_model_name_or_path)
+    tokenizer.save_pretrained(args.merge_lora_model_path)


if __name__ == "__main__":
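
A typical invocation of the updated script might look like the following; the script file name and the paths are placeholders, not taken from this diff, and --model_name_or_path can be omitted when the LoRA config already records its base model:

    python merge_lora_params.py \
        --model_name_or_path ./base_model \
        --lora_path ./checkpoints/lora \
        --merge_lora_model_path ./checkpoints/merged \
        --device gpu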
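
For reference, a minimal standalone sketch (toy shapes, plain paddle, not part of the script) of the identity weight_process relies on: folding lora_A @ lora_B scaled by lora_alpha / r into the base weight leaves the layer output unchanged.

    import paddle

    in_dim, out_dim, r, lora_alpha = 8, 4, 2, 16
    weight = paddle.randn([in_dim, out_dim])
    lora_A = paddle.randn([in_dim, r])
    lora_B = paddle.randn([r, out_dim])
    scaling = lora_alpha / r

    x = paddle.randn([3, in_dim])
    # Output with the adapter applied on the side, as during LoRA training/inference.
    y_lora = x @ weight + (x @ lora_A @ lora_B) * scaling
    # Output after folding the adapter into the base weight, as weight_process does.
    y_merged = x @ (weight + lora_A @ lora_B * scaling)
    assert paddle.allclose(y_lora, y_merged, atol=1e-5)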