
Commit d55c96d

rohansjoshi authored and facebook-github-bot committed
Implemented range setting in QNN llama flow (#12377)
Summary: `llama.py` now has a `--range_setting` flag with two options: `mse_weight_only` and `mse_with_act_loss`. There is also an eval script for computing perplexity, `eval_llama_qnn.py` (for faster eval, try a sequence length of 1024). The eval script also has a `--quant_linear_only` flag to quantize only linear/conv nodes, for faster experiments.

Commands:

```
python examples/qualcomm/oss_scripts/llama/llama.py --checkpoint {MODEL_DIR}/consolidated.00.pth --params {MODEL_DIR}/params.json --tokenizer_path {MODEL_DIR}/tokenizer.model --max_seq_length 128 --ptq 16a4w --range_setting mse_with_act_loss
```

```
python examples/qualcomm/oss_scripts/llama/eval_llama_qnn.py --checkpoint {MODEL_DIR}/consolidated.00.pth --params {MODEL_DIR}/params.json --tokenizer_path {MODEL_DIR}/tokenizer.model --max_seq_length 128 --ptq 16a4w --range_setting mse_with_act_loss
```

Rollback Plan:

Differential Revision: D78127727
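For context: instead of taking a weight tensor's raw min/max, MSE-based range setting searches over candidate clipping ranges and keeps the one that minimizes quantization error; the `mse_with_act_loss` option additionally scores candidates by their effect on activations. The sketch below is purely illustrative of the weight-only idea and is not the code in `range_setting_pt2e.py` (which this page does not show); the function name and the 100-step grid are assumptions.

```python
# Illustrative sketch of MSE-based weight range setting (hypothetical helper,
# not the range_setting_pt2e.py implementation).
import torch


def mse_weight_scale(w: torch.Tensor, n_bits: int = 4, num_steps: int = 100) -> torch.Tensor:
    """Grid-search a symmetric clipping range that minimizes weight quantization MSE."""
    qmax = 2 ** (n_bits - 1) - 1
    max_abs = w.abs().max()
    best_scale, best_err = max_abs / qmax, float("inf")
    for i in range(1, num_steps + 1):
        clip = max_abs * i / num_steps              # candidate clipping threshold
        scale = clip / qmax
        w_q = torch.clamp(torch.round(w / scale), -qmax - 1, qmax) * scale
        err = torch.mean((w - w_q) ** 2).item()     # weight-only MSE; mse_with_act_loss
        if err < best_err:                          # would instead score candidates by
            best_scale, best_err = scale, err       # their effect on activations
    return best_scale


# Example: pick a 4-bit scale for a random weight matrix.
print(mse_weight_scale(torch.randn(256, 256)))
```
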
1 parent dd4488d commit d55c96d

5 files changed (+424 lines, -57 lines)

examples/qualcomm/oss_scripts/llama/TARGETS

Lines changed: 12 additions & 0 deletions
```diff
@@ -34,6 +34,16 @@ python_library(
     ],
 )
 
+python_library(
+    name = "range_setting_pt2e",
+    srcs = [
+        "range_setting_pt2e.py",
+    ],
+    deps = [
+        "//caffe2:torch",
+    ],
+)
+
 python_binary(
     name = "llama",
     main_function = "executorch.examples.qualcomm.oss_scripts.llama.llama.main",
@@ -42,6 +52,7 @@ python_binary(
     ],
     deps = [
         ":llama_lib",
+        "//executorch/examples/qualcomm/oss_scripts/llama:range_setting_pt2e",
     ],
 )
 
@@ -55,6 +66,7 @@ python_binary(
     deps = [
         ":llama_lib",
         "//executorch/examples/models/llama:eval_library",
+        "//executorch/examples/qualcomm/oss_scripts/llama:range_setting_pt2e",
         "fbsource//third-party/pypi/lm-eval:lm-eval",
     ],
 )
```

examples/qualcomm/oss_scripts/llama/eval_llama_qnn.py

Lines changed: 64 additions & 41 deletions
```diff
@@ -46,14 +46,18 @@
     LlamaModel,
     ModelArgs,
 )
-
-from executorch.examples.qualcomm.utils import make_quantizer
+from executorch.examples.qualcomm.oss_scripts.llama.range_setting_pt2e import (
+    reverse_quantize_module_swap,
+    WrappedLlamaModel,
+    compute_scales,
+    set_scales,
+    make_custom_quantizer,
+)
 
 from lm_eval.evaluator import simple_evaluate
 
 from pytorch_tokenizers import get_tokenizer
 
-from torchao.quantization.pt2e import MinMaxObserver
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
 from torchao.quantization.pt2e.quantizer import QuantizationSpec
 
@@ -87,7 +91,6 @@ def forward(
         )
         return self.model.forward(tokens, self.atten_mask)
 
-
 def add_mse_weight_observer(quant_dtype, quantizer):
     weight_dtype = (
         torch.int4
@@ -118,21 +121,14 @@ def add_mse_weight_observer(quant_dtype, quantizer):
 def gen_eval_wrapper(model_name, args):
     tokenizer = get_tokenizer(args.tokenizer_path)
     with open(args.params) as f:
-        kv_config = ModelArgs(**json.load(f))
+        prefill_config = ModelArgs(**json.load(f))
     # TODO: support batch inputs if necessary
-    kv_config.max_batch_size = 1
-    kv_config.max_seq_len = args.max_seq_length
-    kv_config.use_kv_cache = True
-
-    prefill_config = copy.copy(kv_config)
+    prefill_config.max_batch_size = 1
     prefill_config.max_seq_len = args.max_seq_length
-    prefill_config.use_kv_cache = (
-        False if args.max_seq_length == args.prefill_ar_len else True
-    )
-    config = prefill_config
+    prefill_config.use_kv_cache = False
     use_i64_token = args.embedding_quantize is not None
     model = LlamaModel(
-        config,
+        prefill_config,
         ar_len=args.prefill_ar_len,
         output_new_cache_only=True,
         output_cache=False,
@@ -173,20 +169,32 @@ def permute(w, heads):
     if "model" in state_dict:
         state_dict = state_dict["model"]
 
+    tokens, atten_mask = model.get_example_inputs(use_kv_cache=False)
+    tokens = tokens.to(device=args.device)
+    atten_mask = atten_mask.to(device=args.device)
+    atten_mask = atten_mask.to(dtype=torch.float)
+    inputs = (tokens, atten_mask)
+
+    model = model.to(dtype=torch.float)
+    model = model.to(device=args.device)
+
+    scales_state_dict = dict()
+    if args.range_setting == "mse_with_act_loss":
+        wrapped_model = WrappedLlamaModel(
+            model, atten_mask, args.use_kv_cache, args.max_seq_length, args.device
+        )
+        scales_state_dict = compute_scales(wrapped_model, tokens, 1600) # want to use different tokens for calibration!
+        torch.save(scales_state_dict, "scales_state_dict.pth")
+        logging.info("Saved scales to scales_state_dict.pth!")
+        reverse_quantize_module_swap(wrapped_model)
+
     for layer in model.layers:
         if getattr(layer.attention, "prepare_sha", None):
             layer.attention.prepare_sha()
         if getattr(layer.feed_forward, "prepare_feedfoward_conv", None):
             layer.feed_forward.prepare_feedfoward_conv()
 
-    model.to(dtype=torch.float)
-    model.to(device=args.device)
-
-    tokens, atten_mask = model.get_example_inputs(use_kv_cache=False)
-    tokens = tokens.to(device=args.device)
-    atten_mask = atten_mask.to(device=args.device)
-    atten_mask = atten_mask.to(dtype=torch.float)
-    inputs = (tokens, atten_mask)
+    model = model.to(dtype=torch.float)
 
     if args.embedding_quantize:
         model = get_quant_embedding_transform(
@@ -195,35 +203,28 @@ def permute(w, heads):
 
     model = convert_linear_to_conv2d(model)
 
-    if args.ptq:
+    if args.ptq is not None:
         quant_dtype = getattr(QuantDtype, f"use_{args.ptq}")
 
         custom_annotations = (annotate_matmul_16a8w,)
         if args.llama_model == "stories110m":
             custom_annotations = custom_annotations + (
                 annotate_linear_16a8w_in_affine_layer,
             )
-        quantizer = make_quantizer(
-            quant_dtype=quant_dtype,
-            per_channel_conv=True,
-            per_channel_linear=True,
-            act_observer=MinMaxObserver,
-        )
-        quantizer.add_custom_quant_annotations(custom_annotations)
 
-        if args.range_setting == "mse_weight":
-            add_mse_weight_observer(quant_dtype, quantizer)
+        quantizer = make_custom_quantizer(quant_dtype, args.range_setting, custom_annotations, args.quant_linear_only)
 
         with torch.no_grad():
+            logging.info("Starting export...")
             model = torch.export.export(model, inputs, strict=True).module()
            if quant_dtype == QuantDtype.use_16a4w_block:
                 conv_nodes = [n for n in model.graph.nodes if "conv" in n.name]
                 block_size_map = {n.name: (1, 64, 1, 1) for n in conv_nodes}
                 quantizer.set_block_size_map(block_size_map)
-
+            logging.info("Finished export, adding observers (prepare_pt2e)...")
             model = prepare_pt2e(model, quantizer)
 
-        logging.info("Quantizing the model...")
+        logging.info("Observers added, starting calibration...")
 
         calibrate(
             inputs,
@@ -236,7 +237,24 @@ def permute(w, heads):
             use_i64_token=use_i64_token,
         )
 
+        if args.range_setting == "mse_with_act_loss":
+            # scales_state_dict = torch.load("scales_state_dict.pth")
+            set_scales(model, scales_state_dict)
+
+        logging.info("Quantizing the model...")
         model = convert_pt2e(model)
+        logging.info("Quantization complete! Here is some sample generated text:")
+
+        calibrate(
+            inputs,
+            "Could you tell me about Facebook?",
+            model,
+            tokenizer=tokenizer,
+            ar_len=args.prefill_ar_len,
+            max_seq_len=args.max_seq_len,
+            kv_updater=None,
+            use_i64_token=use_i64_token,
+        )
 
     model = WrappedLlamaModel(
         model, atten_mask, args.use_kv_cache, args.max_seq_length, args.device
@@ -248,7 +266,7 @@ def permute(w, heads):
         max_seq_length=args.calibration_seq_length,
         use_kv_cache=args.use_kv_cache,
         generate_full_logits=args.generate_full_logits,
-        enable_dynamic_shape=args.enable_dynamic_shape,
+        enable_dynamic_shape=False,
     )
 
 
@@ -271,7 +289,7 @@ def eval_llama(
         model=eval_wrapper,
         tasks=args.tasks,
         num_fewshot=args.num_fewshot,
-        limit=args.limit,
+        limit=args.fraction,
    )
 
     for task, res in eval_results["results"].items():
@@ -291,13 +309,18 @@ def main() -> None:
     )
     parser.add_argument(
         "--range_setting",
-        help="Choose which range setting method (e.g. mse_weight). If not specified, will do minmax for weights and activations",
+        help="Choose which range setting method for weight quantization (e.g. mse_weight_only or mse_with_act_loss). If not specified, defaults to minmax",
         type=str,
     )
     parser.add_argument(
-        "--limit",
-        help="the number of examples per task (only use this for testing), If <1, limit is a percentage of the total number of examples",
-        type=str,
+        "--fraction",
+        help="the fraction of examples per task (only use this for testing)",
+        type=float,
+    )
+    parser.add_argument(
+        "--quant_linear_only",
+        help="if you select this option we quantize linear layers only. If ptq arg not specified then defaults to 16a4w",
+        action='store_true',
     )
 
     args = parser.parse_args()
```
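
Pulling the `eval_llama_qnn.py` hunks together, the new `mse_with_act_loss` path roughly runs in the order sketched below. This is a condensed, non-runnable outline assembled from the diff for readability (surrounding setup, imports, and most `calibrate` arguments are omitted), not additional code from the commit:

```python
# Condensed outline of the mse_with_act_loss flow in eval_llama_qnn.py (from the hunks above).

# 1. On the float model, pre-compute scales with the activation-loss objective,
#    save them for reuse, and undo the temporary module swap.
scales_state_dict = dict()
if args.range_setting == "mse_with_act_loss":
    wrapped_model = WrappedLlamaModel(
        model, atten_mask, args.use_kv_cache, args.max_seq_length, args.device
    )
    scales_state_dict = compute_scales(wrapped_model, tokens, 1600)
    torch.save(scales_state_dict, "scales_state_dict.pth")
    reverse_quantize_module_swap(wrapped_model)

# 2. Build the quantizer (optionally restricted to linear/conv nodes), export the
#    model, and insert observers.
quantizer = make_custom_quantizer(
    quant_dtype, args.range_setting, custom_annotations, args.quant_linear_only
)
model = torch.export.export(model, inputs, strict=True).module()
model = prepare_pt2e(model, quantizer)

# 3. Calibrate, override the observed ranges with the precomputed scales, then
#    convert to the quantized graph.
calibrate(inputs, ...)  # remaining calibrate(...) arguments elided, as in the hunk above
if args.range_setting == "mse_with_act_loss":
    set_scales(model, scales_state_dict)
model = convert_pt2e(model)
```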

examples/qualcomm/oss_scripts/llama/llama.py

Lines changed: 48 additions & 14 deletions
```diff
@@ -63,6 +63,14 @@
     LlamaModel,
     ModelArgs,
 )
+from executorch.examples.qualcomm.oss_scripts.llama.range_setting_pt2e import (
+    reverse_quantize_module_swap,
+    WrappedLlamaModel,
+    compute_scales,
+    set_scales,
+    make_custom_quantizer
+)
+
 from executorch.examples.qualcomm.utils import (
     make_output_dir,
     make_quantizer,
@@ -380,15 +388,9 @@ def _tag_ios(self, node, fixed_point_type):
 
         return quant_io_type
 
-    def quantize(self, quant_dtype, args, tokenizer, custom_annotations=()):
+    def quantize(self, quant_dtype, args, tokenizer, custom_annotations=(), scales_state_dict=None):
         self.quant_dtype = quant_dtype
-        quantizer = make_quantizer(
-            quant_dtype=quant_dtype,
-            per_channel_conv=True,
-            per_channel_linear=True,
-            act_observer=MinMaxObserver,
-        )
-        quantizer.add_custom_quant_annotations(custom_annotations)
+        quantizer = make_custom_quantizer(quant_dtype, args.range_setting, custom_annotations)
 
         self.has_quant_io = True
         fx_graph_module = None
@@ -408,6 +410,7 @@ def quantize(self, quant_dtype, args, tokenizer, custom_annotations=()):
         fx_graph_module = prepare_pt2e(fx_graph_module, quantizer)
 
         logging.info("Quantizing the model...")
+
         calibrate(
             self.get_example_inputs(self.llama_meta["get_use_kv_cache"]),
             args.prompt[0],
@@ -419,6 +422,9 @@ def quantize(self, quant_dtype, args, tokenizer, custom_annotations=()):
             use_i64_token=args.embedding_quantize is not None,
         )
 
+        if scales_state_dict:
+            set_scales(fx_graph_module, scales_state_dict)
+
         self.llama_graph_module = convert_pt2e(fx_graph_module)
 
     def lowering_modules(
@@ -597,6 +603,27 @@ def permute(w, heads):
     end_load_ts = time.time()
     logging.info(f"Time for loading checkpoint: {end_load_ts - start_ts}")
 
+    scales_state_dict = dict()
+    if args.range_setting == "mse_with_act_loss":
+        try:
+            scales_state_dict = torch.load("scales_state_dict.pth", map_location=torch.device('cpu'))
+            logging.info("Loaded scales_state_dict from file")
+        except:
+            logging.info("Computing scales using activation loss range setting")
+            model = llama_instance_list[1]
+            model.to(torch.float)
+            ar_len = model.ar_len
+            model.ar_len = model.max_seq_len
+            tokens, atten_mask = model.get_example_inputs(use_kv_cache=False)
+            atten_mask.to(torch.float)
+            print(atten_mask.shape)
+            wrapped_model = WrappedLlamaModel(
+                model, atten_mask, model.use_kv_cache, args.max_seq_len, args.device
+            )
+            scales_state_dict = compute_scales(wrapped_model, tokens, 1600) # want to use different tokens for calibration!
+            reverse_quantize_module_swap(wrapped_model)
+            model.ar_len = ar_len
+
     for llama_instance in llama_instance_list:
         for layer in llama_instance.layers:
             if getattr(layer.attention, "prepare_sha", None):
@@ -658,6 +685,7 @@ def permute(w, heads):
             args=args,
             tokenizer=tokenizer,
             custom_annotations=custom_annotations,
+            scales_state_dict=scales_state_dict,
         )
         # If hybrid and lookahead mode, we store kv output quant_attrs and apply to prefill output quant_attrs later
         if i == 0 and args.model_mode in ["hybrid", "lookahead"]:
@@ -668,12 +696,12 @@ def permute(w, heads):
                     kv_quant_attrs[output_indices] = output.args[1:]
                     output_indices += 1
                     break
-            custom_annotations = custom_annotations + (
-                partial(
-                    annotate_prefill_kv_output,
-                    kv_quant_attrs=kv_quant_attrs,
-                ),
-            )
+            # custom_annotations = custom_annotations + (
+            #     partial(
+            #         annotate_prefill_kv_output,
+            #         kv_quant_attrs=kv_quant_attrs,
+            #     ),
+            # ) # temporarily remove annotate_prefill_kv_output
             llama_instance.passes_job[TagQuantIO][QCOM_PASS_ACTIVATE_KEY] = True
             llama_instance.passes_job[TagQuantIO][QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY][
                 "get_quant_io_dtype_fn"
@@ -1062,6 +1090,12 @@ def _build_parser():
         type=int,
     )
 
+    parser.add_argument(
+        "--range_setting",
+        help="Choose which range setting method for weight quantization (e.g. mse_weight_only or mse_with_act_loss). If not specified, defaults to minmax",
+        type=str,
+    )
+
     parser.add_argument("-v", "--verbose", action="store_true")
 
     return parser
```
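
One usage note implied by the `llama.py` hunk above: on the `mse_with_act_loss` path, `llama.py` first tries to load `scales_state_dict.pth` from the working directory (the file `eval_llama_qnn.py` saves) and only recomputes scales when that load fails, so running the eval script first skips the extra scale-computation pass. A minimal, generic sketch of that load-or-compute pattern (the helper name and the narrowed exception handling are illustrative; the committed code uses a bare `except:`):

```python
# Hypothetical helper illustrating the cache-then-recompute pattern used above.
import logging

import torch


def load_or_compute_scales(compute_fn, path="scales_state_dict.pth"):
    """Reuse cached scales if the file is present, otherwise recompute them."""
    try:
        scales = torch.load(path, map_location=torch.device("cpu"))
        logging.info("Loaded scales from %s", path)
    except (FileNotFoundError, RuntimeError):
        logging.info("No usable cache at %s; recomputing scales", path)
        scales = compute_fn()  # e.g. compute_scales(wrapped_model, tokens, 1600)
    return scales
```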

examples/qualcomm/oss_scripts/llama/model/static_llama.py

Lines changed: 4 additions & 2 deletions
```diff
@@ -25,8 +25,10 @@ def apply_rotary_emb_single(
     x_r, x_i = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
     # broadcast for batch_prefill mode input x
     if x.dim() == 4:
-        freqs_cos = freqs_cos[None, None, :, :]
-        freqs_sin = freqs_sin[None, None, :, :]
+        # freqs_cos = freqs_cos[None, None, :, :]
+        # freqs_sin = freqs_sin[None, None, :, :]
+        freqs_cos = freqs_cos[None, :, None, :]
+        freqs_sin = freqs_sin[None, :, None, :]
     x_out_r = x_r * freqs_cos - x_i * freqs_sin
     x_out_i = x_r * freqs_sin + x_i * freqs_cos
 
```
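
The `static_llama.py` change swaps which axes the rotary frequency tensors broadcast over for 4-D batch-prefill inputs. Below is a minimal shape check of the new indexing, under the assumption (inferred from the indexing, not stated in the diff) that the activations are laid out as `(batch, seq_len, n_heads, head_dim)`:

```python
# Shape check for the new rotary broadcast; the (batch, seq_len, n_heads, head_dim)
# layout is an assumption inferred from the [None, :, None, :] indexing.
import torch

batch, seq_len, n_heads, head_dim = 1, 128, 8, 64
x_r = torch.randn(batch, seq_len, n_heads, head_dim // 2)  # "real" half of x
freqs_cos = torch.randn(seq_len, head_dim // 2)            # per-position frequencies

# New indexing: shape (1, seq_len, 1, head_dim // 2) broadcasts over batch and heads.
out = x_r * freqs_cos[None, :, None, :]
assert out.shape == x_r.shape

# The old indexing, (1, 1, seq_len, head_dim // 2), only lines up if the sequence
# dimension sits in position 2, i.e. a (batch, n_heads, seq_len, head_dim) layout.
```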
