    LlamaModel,
    ModelArgs,
)
-from executorch.examples.qualcomm.utils import make_quantizer
+from executorch.examples.qualcomm.oss_scripts.llama.range_setting_pt2e import (
+    reverse_quantize_module_swap,
+    WrappedLlamaModel,
+    compute_scales,
+    set_scales,
+    make_custom_quantizer,
+)

from lm_eval.evaluator import simple_evaluate

from pytorch_tokenizers import get_tokenizer

-from torchao.quantization.pt2e import MinMaxObserver
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

from torchao.quantization.pt2e.quantizer import QuantizationSpec
@@ -87,7 +91,6 @@ def forward(
        )
        return self.model.forward(tokens, self.atten_mask)

-
def add_mse_weight_observer(quant_dtype, quantizer):
    weight_dtype = (
        torch.int4
@@ -118,21 +121,14 @@ def add_mse_weight_observer(quant_dtype, quantizer):
def gen_eval_wrapper(model_name, args):
    tokenizer = get_tokenizer(args.tokenizer_path)
    with open(args.params) as f:
-        kv_config = ModelArgs(**json.load(f))
+        prefill_config = ModelArgs(**json.load(f))
    # TODO: support batch inputs if necessary
-    kv_config.max_batch_size = 1
-    kv_config.max_seq_len = args.max_seq_length
-    kv_config.use_kv_cache = True
-
-    prefill_config = copy.copy(kv_config)
+    prefill_config.max_batch_size = 1
    prefill_config.max_seq_len = args.max_seq_length
-    prefill_config.use_kv_cache = (
-        False if args.max_seq_length == args.prefill_ar_len else True
-    )
-    config = prefill_config
+    prefill_config.use_kv_cache = False
    use_i64_token = args.embedding_quantize is not None
    model = LlamaModel(
-        config,
+        prefill_config,
        ar_len=args.prefill_ar_len,
        output_new_cache_only=True,
        output_cache=False,
@@ -173,20 +169,32 @@ def permute(w, heads):
    if "model" in state_dict:
        state_dict = state_dict["model"]

+    tokens, atten_mask = model.get_example_inputs(use_kv_cache=False)
+    tokens = tokens.to(device=args.device)
+    atten_mask = atten_mask.to(device=args.device)
+    atten_mask = atten_mask.to(dtype=torch.float)
+    inputs = (tokens, atten_mask)
+
+    model = model.to(dtype=torch.float)
+    model = model.to(device=args.device)
+
+    scales_state_dict = dict()
+    if args.range_setting == "mse_with_act_loss":
+        wrapped_model = WrappedLlamaModel(
+            model, atten_mask, args.use_kv_cache, args.max_seq_length, args.device
+        )
+        scales_state_dict = compute_scales(wrapped_model, tokens, 1600)  # want to use different tokens for calibration!
+        torch.save(scales_state_dict, "scales_state_dict.pth")
+        logging.info("Saved scales to scales_state_dict.pth!")
+        reverse_quantize_module_swap(wrapped_model)
+
    for layer in model.layers:
        if getattr(layer.attention, "prepare_sha", None):
            layer.attention.prepare_sha()
        if getattr(layer.feed_forward, "prepare_feedfoward_conv", None):
            layer.feed_forward.prepare_feedfoward_conv()

-    model.to(dtype=torch.float)
-    model.to(device=args.device)
-
-    tokens, atten_mask = model.get_example_inputs(use_kv_cache=False)
-    tokens = tokens.to(device=args.device)
-    atten_mask = atten_mask.to(device=args.device)
-    atten_mask = atten_mask.to(dtype=torch.float)
-    inputs = (tokens, atten_mask)
+    model = model.to(dtype=torch.float)

    if args.embedding_quantize:
        model = get_quant_embedding_transform(
@@ -195,35 +203,28 @@ def permute(w, heads):

    model = convert_linear_to_conv2d(model)

-    if args.ptq:
+    if args.ptq is not None:
        quant_dtype = getattr(QuantDtype, f"use_{args.ptq}")

        custom_annotations = (annotate_matmul_16a8w,)
        if args.llama_model == "stories110m":
            custom_annotations = custom_annotations + (
                annotate_linear_16a8w_in_affine_layer,
            )
-        quantizer = make_quantizer(
-            quant_dtype=quant_dtype,
-            per_channel_conv=True,
-            per_channel_linear=True,
-            act_observer=MinMaxObserver,
-        )
-        quantizer.add_custom_quant_annotations(custom_annotations)

-        if args.range_setting == "mse_weight":
-            add_mse_weight_observer(quant_dtype, quantizer)
+        quantizer = make_custom_quantizer(quant_dtype, args.range_setting, custom_annotations, args.quant_linear_only)

        with torch.no_grad():
+            logging.info("Starting export...")
            model = torch.export.export(model, inputs, strict=True).module()
            if quant_dtype == QuantDtype.use_16a4w_block:
                conv_nodes = [n for n in model.graph.nodes if "conv" in n.name]
                block_size_map = {n.name: (1, 64, 1, 1) for n in conv_nodes}
                quantizer.set_block_size_map(block_size_map)
-
+            logging.info("Finished export, adding observers (prepare_pt2e)...")
            model = prepare_pt2e(model, quantizer)

-        logging.info("Quantizing the model...")
+        logging.info("Observers added, starting calibration...")

        calibrate(
            inputs,
@@ -236,7 +237,24 @@ def permute(w, heads):
            use_i64_token=use_i64_token,
        )

+        if args.range_setting == "mse_with_act_loss":
+            # scales_state_dict = torch.load("scales_state_dict.pth")
+            set_scales(model, scales_state_dict)
+
+        logging.info("Quantizing the model...")
        model = convert_pt2e(model)
+        logging.info("Quantization complete! Here is some sample generated text:")
+
+        calibrate(
+            inputs,
+            "Could you tell me about Facebook?",
+            model,
+            tokenizer=tokenizer,
+            ar_len=args.prefill_ar_len,
+            max_seq_len=args.max_seq_len,
+            kv_updater=None,
+            use_i64_token=use_i64_token,
+        )

    model = WrappedLlamaModel(
        model, atten_mask, args.use_kv_cache, args.max_seq_length, args.device
@@ -248,7 +266,7 @@ def permute(w, heads):
        max_seq_length=args.calibration_seq_length,
        use_kv_cache=args.use_kv_cache,
        generate_full_logits=args.generate_full_logits,
-        enable_dynamic_shape=args.enable_dynamic_shape,
+        enable_dynamic_shape=False,
    )

@@ -271,7 +289,7 @@ def eval_llama(
        model=eval_wrapper,
        tasks=args.tasks,
        num_fewshot=args.num_fewshot,
-        limit=args.limit,
+        limit=args.fraction,
    )

    for task, res in eval_results["results"].items():
@@ -291,13 +309,18 @@ def main() -> None:
    )
    parser.add_argument(
        "--range_setting",
-        help="Choose which range setting method (e.g. mse_weight). If not specified, will do minmax for weights and activations",
+        help="Choose which range setting method for weight quantization (e.g. mse_weight_only or mse_with_act_loss). If not specified, defaults to minmax",
        type=str,
    )
    parser.add_argument(
-        "--limit",
-        help="the number of examples per task (only use this for testing), If <1, limit is a percentage of the total number of examples",
-        type=str,
+        "--fraction",
+        help="the fraction of examples per task (only use this for testing)",
+        type=float,
+    )
+    parser.add_argument(
+        "--quant_linear_only",
+        help="if you select this option we quantize linear layers only. If ptq arg not specified then defaults to 16a4w",
+        action='store_true',
    )

    args = parser.parse_args()
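Taken together, the hunks above reorder the calibration/quantization flow in gen_eval_wrapper. The condensed sketch below is assembled from the added lines for the --range_setting mse_with_act_loss path; it assumes the locals already set up earlier in gen_eval_wrapper (args, model, tokens, atten_mask, inputs, tokenizer, quant_dtype, custom_annotations, use_i64_token), and calibration_prompt is a placeholder for the calibration text, which sits outside the shown hunks.

# Condensed sketch of the new flow (assembled from this diff, not a drop-in script).
scales_state_dict = dict()
if args.range_setting == "mse_with_act_loss":
    # On the float model, search for weight scales that minimize activation error,
    # then undo the module swap that compute_scales() relies on.
    wrapped_model = WrappedLlamaModel(
        model, atten_mask, args.use_kv_cache, args.max_seq_length, args.device
    )
    scales_state_dict = compute_scales(wrapped_model, tokens, 1600)
    reverse_quantize_module_swap(wrapped_model)

# One helper now replaces the old make_quantizer + add_mse_weight_observer wiring.
quantizer = make_custom_quantizer(
    quant_dtype, args.range_setting, custom_annotations, args.quant_linear_only
)

with torch.no_grad():
    # Export, then insert observers (prepare_pt2e).
    model = torch.export.export(model, inputs, strict=True).module()
    model = prepare_pt2e(model, quantizer)

# Run calibration data through the observers.
calibrate(
    inputs,
    calibration_prompt,  # placeholder: the actual prompt is outside the shown hunks
    model,
    tokenizer=tokenizer,
    ar_len=args.prefill_ar_len,
    max_seq_len=args.max_seq_len,
    kv_updater=None,
    use_i64_token=use_i64_token,
)

# Override the observer-derived weight ranges with the precomputed MSE scales,
# then materialize the quantized graph.
if args.range_setting == "mse_with_act_loss":
    set_scales(model, scales_state_dict)
model = convert_pt2e(model)

The second calibrate() call added after convert_pt2e appears to serve only as a sanity check that prints sample generated text; since observers are gone after convert_pt2e, it does not change the quantization parameters.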
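make_custom_quantizer itself is defined in range_setting_pt2e.py and its body is not part of this diff. A purely hypothetical reconstruction, pieced together from the make_quantizer / add_mse_weight_observer code removed above, might look roughly like the following; the actual helper may differ, and how quant_linear_only is honored is not visible in the shown hunks.

# Hypothetical sketch only -- the real make_custom_quantizer lives in
# range_setting_pt2e.py and is not shown in this diff.
from executorch.examples.qualcomm.utils import make_quantizer
from torchao.quantization.pt2e import MinMaxObserver


def make_custom_quantizer(quant_dtype, range_setting, custom_annotations, linear_only=False):
    # Same per-channel MinMax setup the removed code used.
    quantizer = make_quantizer(
        quant_dtype=quant_dtype,
        per_channel_conv=True,
        per_channel_linear=True,
        act_observer=MinMaxObserver,
    )
    quantizer.add_custom_quant_annotations(custom_annotations)
    if range_setting == "mse_weight_only":
        # add_mse_weight_observer is the helper kept in eval_llama_qnn.py above;
        # assumed to be reused here.
        add_mse_weight_observer(quant_dtype, quantizer)
    # NOTE: handling of linear_only is an assumption; the diff does not show
    # how quantizing only linear layers is implemented.
    return quantizer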