
Commit a457091

rohansjoshi authored and facebook-github-bot committed
Implemented range setting in QNN llama flow
Summary:
`llama.py` now has the `--range_setting` flag, with the options `mse_weight_only` and `mse_with_act_loss`. There is also an eval script for computing perplexity, `eval_llama_qnn.py` (for faster eval, try sequence length 1024). That script additionally has a `--quant_linear_only` flag to quantize only linear/conv nodes, for running faster experiments.

Commands:

```
python examples/qualcomm/oss_scripts/llama/llama.py --checkpoint {MODEL_DIR}/consolidated.00.pth --params {MODEL_DIR}/params.json --tokenizer_path {MODEL_DIR}/tokenizer.model --max_seq_length 128 --ptq 16a4w --range_setting mse_with_act_loss
```

```
python examples/qualcomm/oss_scripts/llama/eval_llama_qnn.py --checkpoint {MODEL_DIR}/consolidated.00.pth --params {MODEL_DIR}/params.json --tokenizer_path {MODEL_DIR}/tokenizer.model --max_seq_length 128 --ptq 16a4w --range_setting mse_with_act_loss
```

Rollback Plan:

Differential Revision: D78127727
1 parent 986b447 commit a457091
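For context, here is a minimal sketch of what MSE-based weight range setting means, in plain PyTorch. This is illustrative only; the actual logic lives in the new `range_setting_pt2e.py`, which is not shown in this diff. Instead of taking the raw min/max of a weight tensor, candidate clipping ranges are searched and the one with the smallest quantize-dequantize error is kept.

```python
# Illustrative sketch of MSE weight range setting, NOT the code in range_setting_pt2e.py.
# Assumes symmetric int4 quantization of a single weight tensor.
import torch


def mse_weight_scale(w: torch.Tensor, n_bits: int = 4, num_steps: int = 100) -> torch.Tensor:
    """Return the per-tensor scale whose quantize-dequantize MSE is smallest."""
    qmax = 2 ** (n_bits - 1) - 1           # 7 for int4
    max_abs = w.abs().max()
    best_scale, best_err = max_abs / qmax, float("inf")
    for step in range(1, num_steps + 1):
        clip = max_abs * step / num_steps  # candidate clipping range
        scale = clip / qmax
        q = torch.clamp(torch.round(w / scale), -qmax - 1, qmax)
        err = ((q * scale - w) ** 2).mean().item()
        if err < best_err:
            best_scale, best_err = scale, err
    return best_scale


w = torch.randn(256, 256)
print("minmax scale:", (w.abs().max() / 7).item())
print("mse scale:   ", mse_weight_scale(w).item())
```

Judging by the flag names and the flow in this diff, `mse_weight_only` applies this idea to the weight error alone, while `mse_with_act_loss` scores candidate ranges by the error they induce in layer outputs, which is why the new flow first runs `compute_scales` on the float model and later injects the result with `set_scales` before `convert_pt2e`.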

5 files changed (+405, -51 lines)

examples/qualcomm/oss_scripts/llama/TARGETS

Lines changed: 12 additions & 0 deletions
@@ -34,6 +34,16 @@ python_library(
     ],
 )

+python_library(
+    name = "range_setting_pt2e",
+    srcs = [
+        "range_setting_pt2e.py",
+    ],
+    deps = [
+        "//caffe2:torch",
+    ],
+)
+
 python_binary(
     name = "llama",
     main_function = "executorch.examples.qualcomm.oss_scripts.llama.llama.main",
@@ -42,6 +52,7 @@ python_binary(
     ],
     deps = [
         ":llama_lib",
+        "//executorch/examples/qualcomm/oss_scripts/llama:range_setting_pt2e",
     ],
 )

@@ -55,6 +66,7 @@ python_binary(
     deps = [
         ":llama_lib",
         "//executorch/examples/models/llama:eval_library",
+        "//executorch/examples/qualcomm/oss_scripts/llama:range_setting_pt2e",
         "fbsource//third-party/pypi/lm-eval:lm-eval",
     ],
 )

examples/qualcomm/oss_scripts/llama/eval_llama_qnn.py

Lines changed: 63 additions & 41 deletions
@@ -46,14 +46,18 @@
     LlamaModel,
     ModelArgs,
 )
-
-from executorch.examples.qualcomm.utils import make_quantizer
+from executorch.examples.qualcomm.oss_scripts.llama.range_setting_pt2e import (
+    reverse_quantize_module_swap,
+    WrappedLlamaModel,
+    compute_scales,
+    set_scales,
+    make_custom_quantizer,
+)

 from lm_eval.evaluator import simple_evaluate

 from pytorch_tokenizers import get_tokenizer

-from torchao.quantization.pt2e import MinMaxObserver
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
 from torchao.quantization.pt2e.quantizer import QuantizationSpec

@@ -87,7 +91,6 @@ def forward(
         )
         return self.model.forward(tokens, self.atten_mask)

-
 def add_mse_weight_observer(quant_dtype, quantizer):
     weight_dtype = (
         torch.int4
@@ -118,21 +121,15 @@ def add_mse_weight_observer(quant_dtype, quantizer):
 def gen_eval_wrapper(model_name, args):
     tokenizer = get_tokenizer(args.tokenizer_path)
     with open(args.params) as f:
-        kv_config = ModelArgs(**json.load(f))
+        prefill_config = ModelArgs(**json.load(f))
     # TODO: support batch inputs if necessary
-    kv_config.max_batch_size = 1
-    kv_config.max_seq_len = args.max_seq_length
-    kv_config.use_kv_cache = True
-
-    prefill_config = copy.copy(kv_config)
+    prefill_config.max_batch_size = 1
     prefill_config.max_seq_len = args.max_seq_length
-    prefill_config.use_kv_cache = (
-        False if args.max_seq_length == args.prefill_ar_len else True
-    )
-    config = prefill_config
+    prefill_config.use_kv_cache = False
+    print(prefill_config.hidden_dim)
     use_i64_token = args.embedding_quantize is not None
     model = LlamaModel(
-        config,
+        prefill_config,
         ar_len=args.prefill_ar_len,
         output_new_cache_only=True,
         output_cache=False,
@@ -173,20 +170,30 @@ def permute(w, heads):
     if "model" in state_dict:
         state_dict = state_dict["model"]

+    tokens, atten_mask = model.get_example_inputs(use_kv_cache=False)
+    tokens = tokens.to(device=args.device)
+    atten_mask = atten_mask.to(device=args.device)
+    atten_mask = atten_mask.to(dtype=torch.float)
+    inputs = (tokens, atten_mask)
+
+    model = model.to(dtype=torch.float)
+    model = model.to(device=args.device)
+
+    scales_state_dict = dict()
+    if args.range_setting == "mse_with_act_loss":
+        wrapped_model = WrappedLlamaModel(
+            model, atten_mask, args.use_kv_cache, args.max_seq_length, args.device
+        )
+        scales_state_dict = compute_scales(wrapped_model, tokens, 1600)  # want to use different tokens for calibration!
+        reverse_quantize_module_swap(wrapped_model)
+
     for layer in model.layers:
         if getattr(layer.attention, "prepare_sha", None):
             layer.attention.prepare_sha()
         if getattr(layer.feed_forward, "prepare_feedfoward_conv", None):
             layer.feed_forward.prepare_feedfoward_conv()

-    model.to(dtype=torch.float)
-    model.to(device=args.device)
-
-    tokens, atten_mask = model.get_example_inputs(use_kv_cache=False)
-    tokens = tokens.to(device=args.device)
-    atten_mask = atten_mask.to(device=args.device)
-    atten_mask = atten_mask.to(dtype=torch.float)
-    inputs = (tokens, atten_mask)
+    model = model.to(dtype=torch.float)

     if args.embedding_quantize:
         model = get_quant_embedding_transform(
@@ -195,35 +202,28 @@ def permute(w, heads):

     model = convert_linear_to_conv2d(model)

-    if args.ptq:
+    if args.ptq is not None:
         quant_dtype = getattr(QuantDtype, f"use_{args.ptq}")

         custom_annotations = (annotate_matmul_16a8w,)
         if args.llama_model == "stories110m":
             custom_annotations = custom_annotations + (
                 annotate_linear_16a8w_in_affine_layer,
             )
-        quantizer = make_quantizer(
-            quant_dtype=quant_dtype,
-            per_channel_conv=True,
-            per_channel_linear=True,
-            act_observer=MinMaxObserver,
-        )
-        quantizer.add_custom_quant_annotations(custom_annotations)

-        if args.range_setting == "mse_weight":
-            add_mse_weight_observer(quant_dtype, quantizer)
+        quantizer = make_custom_quantizer(quant_dtype, args.range_setting, custom_annotations, args.quant_linear_only)

         with torch.no_grad():
+            logging.info("Starting export...")
             model = torch.export.export(model, inputs, strict=True).module()
             if quant_dtype == QuantDtype.use_16a4w_block:
                 conv_nodes = [n for n in model.graph.nodes if "conv" in n.name]
                 block_size_map = {n.name: (1, 64, 1, 1) for n in conv_nodes}
                 quantizer.set_block_size_map(block_size_map)
-
+            logging.info("Finished export, adding observers (prepare_pt2e)...")
             model = prepare_pt2e(model, quantizer)

-        logging.info("Quantizing the model...")
+        logging.info("Observers added, starting calibration...")

         calibrate(
             inputs,
@@ -236,7 +236,24 @@ def permute(w, heads):
             use_i64_token=use_i64_token,
         )

+        if args.range_setting == "mse_with_act_loss":
+            # scales_state_dict = torch.load("scales_state_dict.pth")
+            set_scales(model, scales_state_dict)
+
+        logging.info("Quantizing the model...")
         model = convert_pt2e(model)
+        logging.info("Quantization complete! Here is some sample generated text:")
+
+        calibrate(
+            inputs,
+            "Could you tell me about Facebook?",
+            model,
+            tokenizer=tokenizer,
+            ar_len=args.prefill_ar_len,
+            max_seq_len=args.max_seq_len,
+            kv_updater=None,
+            use_i64_token=use_i64_token,
+        )

         model = WrappedLlamaModel(
             model, atten_mask, args.use_kv_cache, args.max_seq_length, args.device
@@ -248,7 +265,7 @@ def permute(w, heads):
         max_seq_length=args.calibration_seq_length,
         use_kv_cache=args.use_kv_cache,
         generate_full_logits=args.generate_full_logits,
-        enable_dynamic_shape=args.enable_dynamic_shape,
+        enable_dynamic_shape=False,
     )


@@ -271,7 +288,7 @@ def eval_llama(
         model=eval_wrapper,
         tasks=args.tasks,
         num_fewshot=args.num_fewshot,
-        limit=args.limit,
+        limit=args.fraction,
    )

     for task, res in eval_results["results"].items():
@@ -291,13 +308,18 @@ def main() -> None:
     )
     parser.add_argument(
         "--range_setting",
-        help="Choose which range setting method (e.g. mse_weight). If not specified, will do minmax for weights and activations",
+        help="Choose which range setting method for weight quantization (e.g. mse_weight_only or mse_with_act_loss). If not specified, defaults to minmax",
         type=str,
     )
     parser.add_argument(
-        "--limit",
-        help="the number of examples per task (only use this for testing), If <1, limit is a percentage of the total number of examples",
-        type=str,
+        "--fraction",
+        help="the fraction of examples per task (only use this for testing)",
+        type=float,
+    )
+    parser.add_argument(
+        "--quant_linear_only",
+        help="if you select this option we quantize linear layers only. If ptq arg not specified then defaults to 16a4w",
+        action='store_true',
     )

     args = parser.parse_args()
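Taken together, the eval script's quantization path now looks roughly like the sketch below. This is condensed from the hunks above, not verbatim; helper signatures are only as precise as their call sites in this diff, and the `calibrate` arguments are abbreviated.

```python
# Condensed sketch of gen_eval_wrapper's PT2E path after this change (not verbatim).
import torch
from executorch.examples.qualcomm.oss_scripts.llama.range_setting_pt2e import set_scales
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e


def quantize_for_eval(model, inputs, quantizer, scales_state_dict, args):
    with torch.no_grad():
        # 1. Export the eager model to an FX graph module.
        model = torch.export.export(model, inputs, strict=True).module()
        # 2. Attach the observers chosen by make_custom_quantizer(...).
        model = prepare_pt2e(model, quantizer)
    # 3. Run calibration prompts through the observed graph (calibrate(...) in the script).
    # 4. For --range_setting mse_with_act_loss, overwrite the observed weight scales
    #    with the ones computed earlier on the float model.
    if args.range_setting == "mse_with_act_loss":
        set_scales(model, scales_state_dict)
    # 5. Fold observers into quantize/dequantize ops.
    return convert_pt2e(model)
```

The quantized graph is then wrapped and handed to lm-eval's `simple_evaluate`, where the renamed `--fraction` flag is passed as `limit`; lm-eval treats a float below 1.0 as a fraction of each task's examples, which is what the old `--limit` help text described.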

examples/qualcomm/oss_scripts/llama/llama.py

Lines changed: 33 additions & 8 deletions
@@ -63,6 +63,14 @@
     LlamaModel,
     ModelArgs,
 )
+from executorch.examples.qualcomm.oss_scripts.llama.range_setting_pt2e import (
+    reverse_quantize_module_swap,
+    WrappedLlamaModel,
+    compute_scales,
+    set_scales,
+    make_custom_quantizer
+)
+
 from executorch.examples.qualcomm.utils import (
     make_output_dir,
     make_quantizer,
@@ -380,15 +388,9 @@ def _tag_ios(self, node, fixed_point_type):

         return quant_io_type

-    def quantize(self, quant_dtype, args, tokenizer, custom_annotations=()):
+    def quantize(self, quant_dtype, args, tokenizer, custom_annotations=(), scales_state_dict=None):
         self.quant_dtype = quant_dtype
-        quantizer = make_quantizer(
-            quant_dtype=quant_dtype,
-            per_channel_conv=True,
-            per_channel_linear=True,
-            act_observer=MinMaxObserver,
-        )
-        quantizer.add_custom_quant_annotations(custom_annotations)
+        quantizer = make_custom_quantizer(quant_dtype, args.range_setting, custom_annotations)

         self.has_quant_io = True
         fx_graph_module = None
@@ -419,6 +421,10 @@ def quantize(self, quant_dtype, args, tokenizer, custom_annotations=()):
             use_i64_token=args.embedding_quantize is not None,
         )

+        if scales_state_dict:
+            # scales_state_dict = torch.load("scales_state_dict.pth")
+            set_scales(fx_graph_module, scales_state_dict)
+
         self.llama_graph_module = convert_pt2e(fx_graph_module)

     def lowering_modules(
@@ -597,6 +603,18 @@ def permute(w, heads):
     end_load_ts = time.time()
     logging.info(f"Time for loading checkpoint: {end_load_ts - start_ts}")

+    scales_state_dict = dict()
+    if args.range_setting == "mse_with_act_loss":
+        model = llama_instance_list[0]
+        model.to(torch.float)
+        tokens, atten_mask = model.get_example_inputs(use_kv_cache=False)
+        atten_mask.to(torch.float)
+        wrapped_model = WrappedLlamaModel(
+            model, atten_mask, model.use_kv_cache, args.max_seq_len, args.device
+        )
+        scales_state_dict = compute_scales(wrapped_model, tokens, 1600)  # want to use different tokens for calibration!
+        reverse_quantize_module_swap(wrapped_model)
+
     for llama_instance in llama_instance_list:
         for layer in llama_instance.layers:
             if getattr(layer.attention, "prepare_sha", None):
@@ -658,6 +676,7 @@ def permute(w, heads):
             args=args,
             tokenizer=tokenizer,
             custom_annotations=custom_annotations,
+            scales_state_dict=scales_state_dict,
         )
         # If hybrid and lookahead mode, we store kv output quant_attrs and apply to prefill output quant_attrs later
         if i == 0 and args.model_mode in ["hybrid", "lookahead"]:
@@ -1062,6 +1081,12 @@ def _build_parser():
         type=int,
     )

+    parser.add_argument(
+        "--range_setting",
+        help="Choose which range setting method for weight quantization (e.g. mse_weight_only or mse_with_act_loss). If not specified, defaults to minmax",
+        type=str,
+    )
+
     parser.add_argument("-v", "--verbose", action="store_true")

     return parser
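On the export side, the same scales are computed once on the float model and then passed into each instance's `quantize()`. A condensed sketch of that call flow follows; signatures follow the call sites in this diff, and the meaning of the `1600` argument is defined in `range_setting_pt2e.py`, which is not shown here.

```python
# Condensed sketch of the mse_with_act_loss path in llama.py (not verbatim).
import torch
from executorch.examples.qualcomm.oss_scripts.llama.range_setting_pt2e import (
    WrappedLlamaModel,
    compute_scales,
    reverse_quantize_module_swap,
)


def compute_range_setting_scales(llama_instance_list, args):
    scales_state_dict = {}
    if args.range_setting == "mse_with_act_loss":
        model = llama_instance_list[0]
        model.to(torch.float)
        tokens, atten_mask = model.get_example_inputs(use_kv_cache=False)
        wrapped = WrappedLlamaModel(
            model, atten_mask, model.use_kv_cache, args.max_seq_len, args.device
        )
        # Scales searched on the float model to minimize an activation-space loss.
        scales_state_dict = compute_scales(wrapped, tokens, 1600)
        # Undo the temporary module swap used during the search.
        reverse_quantize_module_swap(wrapped)
    return scales_state_dict


# Each instance later applies them just before convert_pt2e:
#   llama_instance.quantize(quant_dtype, args=args, tokenizer=tokenizer,
#                           custom_annotations=custom_annotations,
#                           scales_state_dict=scales_state_dict)
```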

examples/qualcomm/oss_scripts/llama/model/static_llama.py

Lines changed: 4 additions & 2 deletions
@@ -25,8 +25,10 @@ def apply_rotary_emb_single(
     x_r, x_i = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
     # broadcast for batch_prefill mode input x
     if x.dim() == 4:
-        freqs_cos = freqs_cos[None, None, :, :]
-        freqs_sin = freqs_sin[None, None, :, :]
+        # freqs_cos = freqs_cos[None, None, :, :]
+        # freqs_sin = freqs_sin[None, None, :, :]
+        freqs_cos = freqs_cos[None, :, None, :]
+        freqs_sin = freqs_sin[None, :, None, :]
     x_out_r = x_r * freqs_cos - x_i * freqs_sin
     x_out_i = x_r * freqs_sin + x_i * freqs_cos
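The `static_llama.py` change only moves the broadcast axis of the rotary tables. Below is a small shape check, under the assumption (implied by the new indexing) that 4-D activations in batch-prefill mode are laid out as (batch, seq, heads, head_dim).

```python
# Shape check for the rotary-embedding broadcast fix.
# Assumption: x is (batch, seq, heads, head_dim); freqs_* are (seq, head_dim // 2).
import torch

bsz, seq, heads, head_dim = 2, 8, 4, 64
x = torch.randn(bsz, seq, heads, head_dim)
x_r = x[..., : head_dim // 2]
freqs_cos = torch.randn(seq, head_dim // 2)   # one row of frequencies per position

new = freqs_cos[None, :, None, :]   # (1, seq, 1, hd/2): positions align with the seq axis
print((x_r * new).shape)            # torch.Size([2, 8, 4, 32])

# The old indexing, freqs_cos[None, None, :, :], gives (1, 1, seq, hd/2): it only
# broadcasts when seq == heads, and even then it applies per-position frequencies
# along the heads axis instead of the sequence axis.
```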
