     LlamaModel,
     ModelArgs,
 )
-
-from executorch.examples.qualcomm.utils import make_quantizer
+from executorch.examples.qualcomm.oss_scripts.llama.range_setting_pt2e import (
+    reverse_quantize_module_swap,
+    WrappedLlamaModel,
+    compute_scales,
+    set_scales,
+    make_custom_quantizer,
+)
 
 from lm_eval.evaluator import simple_evaluate
 
 from pytorch_tokenizers import get_tokenizer
 
-from torchao.quantization.pt2e import MinMaxObserver
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
 from torchao.quantization.pt2e.quantizer import QuantizationSpec
 
@@ -87,7 +91,6 @@ def forward(
         )
         return self.model.forward(tokens, self.atten_mask)
 
-
 def add_mse_weight_observer(quant_dtype, quantizer):
     weight_dtype = (
         torch.int4
@@ -118,21 +121,15 @@ def add_mse_weight_observer(quant_dtype, quantizer):
 def gen_eval_wrapper(model_name, args):
     tokenizer = get_tokenizer(args.tokenizer_path)
     with open(args.params) as f:
-        kv_config = ModelArgs(**json.load(f))
+        prefill_config = ModelArgs(**json.load(f))
     # TODO: support batch inputs if necessary
-    kv_config.max_batch_size = 1
-    kv_config.max_seq_len = args.max_seq_length
-    kv_config.use_kv_cache = True
-
-    prefill_config = copy.copy(kv_config)
+    prefill_config.max_batch_size = 1
     prefill_config.max_seq_len = args.max_seq_length
-    prefill_config.use_kv_cache = (
-        False if args.max_seq_length == args.prefill_ar_len else True
-    )
-    config = prefill_config
+    prefill_config.use_kv_cache = False
+    print(prefill_config.hidden_dim)
     use_i64_token = args.embedding_quantize is not None
     model = LlamaModel(
-        config,
+        prefill_config,
         ar_len=args.prefill_ar_len,
         output_new_cache_only=True,
         output_cache=False,
@@ -173,20 +170,30 @@ def permute(w, heads):
     if "model" in state_dict:
         state_dict = state_dict["model"]
 
+    tokens, atten_mask = model.get_example_inputs(use_kv_cache=False)
+    tokens = tokens.to(device=args.device)
+    atten_mask = atten_mask.to(device=args.device)
+    atten_mask = atten_mask.to(dtype=torch.float)
+    inputs = (tokens, atten_mask)
+
+    model = model.to(dtype=torch.float)
+    model = model.to(device=args.device)
+
+    scales_state_dict = dict()
+    if args.range_setting == "mse_with_act_loss":
+        wrapped_model = WrappedLlamaModel(
+            model, atten_mask, args.use_kv_cache, args.max_seq_length, args.device
+        )
+        scales_state_dict = compute_scales(wrapped_model, tokens, 1600)  # want to use different tokens for calibration!
+        reverse_quantize_module_swap(wrapped_model)
+
     for layer in model.layers:
         if getattr(layer.attention, "prepare_sha", None):
             layer.attention.prepare_sha()
         if getattr(layer.feed_forward, "prepare_feedfoward_conv", None):
             layer.feed_forward.prepare_feedfoward_conv()
 
-    model.to(dtype=torch.float)
-    model.to(device=args.device)
-
-    tokens, atten_mask = model.get_example_inputs(use_kv_cache=False)
-    tokens = tokens.to(device=args.device)
-    atten_mask = atten_mask.to(device=args.device)
-    atten_mask = atten_mask.to(dtype=torch.float)
-    inputs = (tokens, atten_mask)
+    model = model.to(dtype=torch.float)
 
     if args.embedding_quantize:
         model = get_quant_embedding_transform(
@@ -195,35 +202,28 @@ def permute(w, heads):
 
     model = convert_linear_to_conv2d(model)
 
-    if args.ptq:
+    if args.ptq is not None:
         quant_dtype = getattr(QuantDtype, f"use_{args.ptq}")
 
         custom_annotations = (annotate_matmul_16a8w,)
         if args.llama_model == "stories110m":
             custom_annotations = custom_annotations + (
                 annotate_linear_16a8w_in_affine_layer,
             )
-        quantizer = make_quantizer(
-            quant_dtype=quant_dtype,
-            per_channel_conv=True,
-            per_channel_linear=True,
-            act_observer=MinMaxObserver,
-        )
-        quantizer.add_custom_quant_annotations(custom_annotations)
 
-        if args.range_setting == "mse_weight":
-            add_mse_weight_observer(quant_dtype, quantizer)
+        quantizer = make_custom_quantizer(quant_dtype, args.range_setting, custom_annotations, args.quant_linear_only)
 
         with torch.no_grad():
+            logging.info("Starting export...")
             model = torch.export.export(model, inputs, strict=True).module()
             if quant_dtype == QuantDtype.use_16a4w_block:
                 conv_nodes = [n for n in model.graph.nodes if "conv" in n.name]
                 block_size_map = {n.name: (1, 64, 1, 1) for n in conv_nodes}
                 quantizer.set_block_size_map(block_size_map)
-
+            logging.info("Finished export, adding observers (prepare_pt2e)...")
             model = prepare_pt2e(model, quantizer)
 
-            logging.info("Quantizing the model ...")
+            logging.info("Observers added, starting calibration ...")
 
             calibrate(
                 inputs,
@@ -236,7 +236,24 @@ def permute(w, heads):
                 use_i64_token=use_i64_token,
             )
 
+            if args.range_setting == "mse_with_act_loss":
+                # scales_state_dict = torch.load("scales_state_dict.pth")
+                set_scales(model, scales_state_dict)
+
+            logging.info("Quantizing the model...")
             model = convert_pt2e(model)
+            logging.info("Quantization complete! Here is some sample generated text:")
+
+            calibrate(
+                inputs,
+                "Could you tell me about Facebook?",
+                model,
+                tokenizer=tokenizer,
+                ar_len=args.prefill_ar_len,
+                max_seq_len=args.max_seq_len,
+                kv_updater=None,
+                use_i64_token=use_i64_token,
+            )
 
     model = WrappedLlamaModel(
         model, atten_mask, args.use_kv_cache, args.max_seq_length, args.device
@@ -248,7 +265,7 @@ def permute(w, heads):
         max_seq_length=args.calibration_seq_length,
         use_kv_cache=args.use_kv_cache,
         generate_full_logits=args.generate_full_logits,
-        enable_dynamic_shape=args.enable_dynamic_shape,
+        enable_dynamic_shape=False,
     )
 
 
@@ -271,7 +288,7 @@ def eval_llama(
         model=eval_wrapper,
         tasks=args.tasks,
         num_fewshot=args.num_fewshot,
-        limit=args.limit,
+        limit=args.fraction,
     )
 
     for task, res in eval_results["results"].items():
@@ -291,13 +308,18 @@ def main() -> None:
     )
     parser.add_argument(
         "--range_setting",
-        help="Choose which range setting method (e.g. mse_weight). If not specified, will do minmax for weights and activations",
+        help="Choose which range setting method for weight quantization (e.g. mse_weight_only or mse_with_act_loss). If not specified, defaults to minmax",
         type=str,
     )
     parser.add_argument(
-        "--limit",
-        help="the number of examples per task (only use this for testing), If <1, limit is a percentage of the total number of examples",
-        type=str,
+        "--fraction",
+        help="The fraction of examples per task (use this only for testing)",
+        type=float,
+    )
+    parser.add_argument(
+        "--quant_linear_only",
+        help="If enabled, quantize linear layers only. If the ptq arg is not specified, defaults to 16a4w",
+        action="store_true",
     )
 
     args = parser.parse_args()
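
For reference, a minimal sketch of the PT2E flow this patch wires up in `gen_eval_wrapper` under the `mse_with_act_loss` range setting. The names `make_custom_quantizer`, `compute_scales`, `set_scales`, `prepare_pt2e`, and `convert_pt2e` are taken from the diff above; the helper `quantize_with_act_loss` and the `calibrate_fn` callback are illustrative stand-ins, not part of the patch.

```python
# Illustrative only: mirrors the flow in gen_eval_wrapper() above, assuming the
# signatures shown in the diff. Not a drop-in replacement for the script.
import torch

from executorch.examples.qualcomm.oss_scripts.llama.range_setting_pt2e import set_scales
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e


def quantize_with_act_loss(model, quantizer, inputs, scales_state_dict, calibrate_fn):
    """Export -> observe -> calibrate -> override scales -> convert."""
    with torch.no_grad():
        # Export the float model to an FX graph module.
        model = torch.export.export(model, inputs, strict=True).module()
        # Insert observers according to the custom quantizer.
        model = prepare_pt2e(model, quantizer)
        # Run calibration data through the observed model (calibrate_fn is a
        # stand-in for the calibrate(...) call used in the script).
        calibrate_fn(model)
        # Overwrite observer-derived scales with the MSE-with-activation-loss
        # scales computed on the float model before export.
        set_scales(model, scales_state_dict)
        # Replace observers with quantize/dequantize ops.
        return convert_pt2e(model)
```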