diff --git a/examples/qualcomm/oss_scripts/llama/TARGETS b/examples/qualcomm/oss_scripts/llama/TARGETS
index 9c5dd1ceaf9..264854d9bfc 100644
--- a/examples/qualcomm/oss_scripts/llama/TARGETS
+++ b/examples/qualcomm/oss_scripts/llama/TARGETS
@@ -34,6 +34,16 @@ python_library(
     ],
 )
 
+python_library(
+    name = "range_setting_pt2e",
+    srcs = [
+        "range_setting_pt2e.py",
+    ],
+    deps = [
+        "//caffe2:torch",
+    ],
+)
+
 python_binary(
     name = "llama",
     main_function = "executorch.examples.qualcomm.oss_scripts.llama.llama.main",
@@ -42,6 +52,7 @@ python_binary(
     ],
     deps = [
         ":llama_lib",
+        "//executorch/examples/qualcomm/oss_scripts/llama:range_setting_pt2e",
     ],
 )
 
@@ -55,6 +66,7 @@ python_binary(
     deps = [
         ":llama_lib",
         "//executorch/examples/models/llama:eval_library",
+        "//executorch/examples/qualcomm/oss_scripts/llama:range_setting_pt2e",
         "fbsource//third-party/pypi/lm-eval:lm-eval",
     ],
 )
diff --git a/examples/qualcomm/oss_scripts/llama/eval_llama_qnn.py b/examples/qualcomm/oss_scripts/llama/eval_llama_qnn.py
index 4e092b71892..8d13ed43b7b 100644
--- a/examples/qualcomm/oss_scripts/llama/eval_llama_qnn.py
+++ b/examples/qualcomm/oss_scripts/llama/eval_llama_qnn.py
@@ -46,14 +46,18 @@
     LlamaModel,
     ModelArgs,
 )
-
-from executorch.examples.qualcomm.utils import make_quantizer
+from executorch.examples.qualcomm.oss_scripts.llama.range_setting_pt2e import (
+    reverse_quantize_module_swap,
+    WrappedLlamaModel,
+    compute_scales,
+    set_scales,
+    make_custom_quantizer,
+)
 
 from lm_eval.evaluator import simple_evaluate
 
 from pytorch_tokenizers import get_tokenizer
 
-from torchao.quantization.pt2e import MinMaxObserver
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
 from torchao.quantization.pt2e.quantizer import QuantizationSpec
 
@@ -87,7 +91,6 @@ def forward(
         )
         return self.model.forward(tokens, self.atten_mask)
 
-
 def add_mse_weight_observer(quant_dtype, quantizer):
     weight_dtype = (
         torch.int4
@@ -118,21 +121,14 @@ def add_mse_weight_observer(quant_dtype, quantizer):
 def gen_eval_wrapper(model_name, args):
     tokenizer = get_tokenizer(args.tokenizer_path)
     with open(args.params) as f:
-        kv_config = ModelArgs(**json.load(f))
+        prefill_config = ModelArgs(**json.load(f))
     # TODO: support batch inputs if necessary
-    kv_config.max_batch_size = 1
-    kv_config.max_seq_len = args.max_seq_length
-    kv_config.use_kv_cache = True
-
-    prefill_config = copy.copy(kv_config)
+    prefill_config.max_batch_size = 1
     prefill_config.max_seq_len = args.max_seq_length
-    prefill_config.use_kv_cache = (
-        False if args.max_seq_length == args.prefill_ar_len else True
-    )
-    config = prefill_config
+    prefill_config.use_kv_cache = False
     use_i64_token = args.embedding_quantize is not None
 
     model = LlamaModel(
-        config,
+        prefill_config,
         ar_len=args.prefill_ar_len,
         output_new_cache_only=True,
         output_cache=False,
@@ -173,20 +169,32 @@ def permute(w, heads):
     if "model" in state_dict:
         state_dict = state_dict["model"]
 
+    tokens, atten_mask = model.get_example_inputs(use_kv_cache=False)
+    tokens = tokens.to(device=args.device)
+    atten_mask = atten_mask.to(device=args.device)
+    atten_mask = atten_mask.to(dtype=torch.float)
+    inputs = (tokens, atten_mask)
+
+    model = model.to(dtype=torch.float)
+    model = model.to(device=args.device)
+
+    scales_state_dict = dict()
+    if args.range_setting == "mse_with_act_loss":
+        wrapped_model = WrappedLlamaModel(
+            model, atten_mask, args.use_kv_cache, args.max_seq_length, args.device
+        )
+        scales_state_dict = compute_scales(wrapped_model, tokens, 1600)  # TODO: use different tokens for calibration
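+        # compute_scales returns {quantized linear module name: per-channel weight scale tensor}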
+        torch.save(scales_state_dict, "scales_state_dict.pth")
+        logging.info("Saved scales to scales_state_dict.pth!")
+        reverse_quantize_module_swap(wrapped_model)
+
     for layer in model.layers:
         if getattr(layer.attention, "prepare_sha", None):
             layer.attention.prepare_sha()
         if getattr(layer.feed_forward, "prepare_feedfoward_conv", None):
             layer.feed_forward.prepare_feedfoward_conv()
 
-    model.to(dtype=torch.float)
-    model.to(device=args.device)
-
-    tokens, atten_mask = model.get_example_inputs(use_kv_cache=False)
-    tokens = tokens.to(device=args.device)
-    atten_mask = atten_mask.to(device=args.device)
-    atten_mask = atten_mask.to(dtype=torch.float)
-    inputs = (tokens, atten_mask)
+    model = model.to(dtype=torch.float)
 
     if args.embedding_quantize:
         model = get_quant_embedding_transform(
@@ -195,7 +203,7 @@ def permute(w, heads):
 
     model = convert_linear_to_conv2d(model)
 
-    if args.ptq:
+    if args.ptq is not None:
         quant_dtype = getattr(QuantDtype, f"use_{args.ptq}")
 
         custom_annotations = (annotate_matmul_16a8w,)
@@ -203,27 +211,20 @@ def permute(w, heads):
             custom_annotations = custom_annotations + (
                 annotate_linear_16a8w_in_affine_layer,
            )
-        quantizer = make_quantizer(
-            quant_dtype=quant_dtype,
-            per_channel_conv=True,
-            per_channel_linear=True,
-            act_observer=MinMaxObserver,
-        )
-        quantizer.add_custom_quant_annotations(custom_annotations)
-        if args.range_setting == "mse_weight":
-            add_mse_weight_observer(quant_dtype, quantizer)
+        quantizer = make_custom_quantizer(quant_dtype, args.range_setting, custom_annotations, args.quant_linear_only)
 
         with torch.no_grad():
+            logging.info("Starting export...")
             model = torch.export.export(model, inputs, strict=True).module()
             if quant_dtype == QuantDtype.use_16a4w_block:
                 conv_nodes = [n for n in model.graph.nodes if "conv" in n.name]
                 block_size_map = {n.name: (1, 64, 1, 1) for n in conv_nodes}
                 quantizer.set_block_size_map(block_size_map)
-
+            logging.info("Finished export, adding observers (prepare_pt2e)...")
             model = prepare_pt2e(model, quantizer)
-        logging.info("Quantizing the model...")
+        logging.info("Observers added, starting calibration...")
 
         calibrate(
             inputs,
@@ -236,7 +237,24 @@ def permute(w, heads):
             use_i64_token=use_i64_token,
         )
 
+        if args.range_setting == "mse_with_act_loss":
+            # scales_state_dict = torch.load("scales_state_dict.pth")
+            set_scales(model, scales_state_dict)
+
+        logging.info("Quantizing the model...")
         model = convert_pt2e(model)
+        logging.info("Quantization complete! Here is some sample generated text:")
+
+        calibrate(
+            inputs,
+            "Could you tell me about Facebook?",
+            model,
+            tokenizer=tokenizer,
+            ar_len=args.prefill_ar_len,
+            max_seq_len=args.max_seq_len,
+            kv_updater=None,
+            use_i64_token=use_i64_token,
+        )
 
     model = WrappedLlamaModel(
         model, atten_mask, args.use_kv_cache, args.max_seq_length, args.device
@@ -248,7 +266,7 @@ def permute(w, heads):
         max_seq_length=args.calibration_seq_length,
         use_kv_cache=args.use_kv_cache,
         generate_full_logits=args.generate_full_logits,
-        enable_dynamic_shape=args.enable_dynamic_shape,
+        enable_dynamic_shape=False,
     )
 
 
@@ -271,7 +289,7 @@ def eval_llama(
         model=eval_wrapper,
         tasks=args.tasks,
         num_fewshot=args.num_fewshot,
-        limit=args.limit,
+        limit=args.fraction,
     )
 
     for task, res in eval_results["results"].items():
@@ -291,13 +309,18 @@ def main() -> None:
     )
     parser.add_argument(
         "--range_setting",
-        help="Choose which range setting method (e.g. mse_weight). If not specified, will do minmax for weights and activations",
+        help="Choose which range setting method for weight quantization (e.g. mse_weight_only or mse_with_act_loss). If not specified, defaults to minmax",
         type=str,
     )
     parser.add_argument(
-        "--limit",
-        help="the number of examples per task (only use this for testing), If <1, limit is a percentage of the total number of examples",
-        type=str,
+        "--fraction",
+        help="the fraction of examples per task (only use this for testing)",
+        type=float,
+    )
+    parser.add_argument(
+        "--quant_linear_only",
+        help="quantize linear layers only; if --ptq is not specified, defaults to 16a4w",
+        action="store_true",
     )
 
     args = parser.parse_args()
diff --git a/examples/qualcomm/oss_scripts/llama/llama.py b/examples/qualcomm/oss_scripts/llama/llama.py
index db533986119..42a515bd8ae 100755
--- a/examples/qualcomm/oss_scripts/llama/llama.py
+++ b/examples/qualcomm/oss_scripts/llama/llama.py
@@ -63,6 +63,14 @@
     LlamaModel,
     ModelArgs,
 )
+from executorch.examples.qualcomm.oss_scripts.llama.range_setting_pt2e import (
+    reverse_quantize_module_swap,
+    WrappedLlamaModel,
+    compute_scales,
+    set_scales,
+    make_custom_quantizer,
+)
+
 from executorch.examples.qualcomm.utils import (
     make_output_dir,
     make_quantizer,
@@ -380,15 +388,9 @@ def _tag_ios(self, node, fixed_point_type):
 
         return quant_io_type
 
-    def quantize(self, quant_dtype, args, tokenizer, custom_annotations=()):
+    def quantize(self, quant_dtype, args, tokenizer, custom_annotations=(), scales_state_dict=None):
         self.quant_dtype = quant_dtype
-        quantizer = make_quantizer(
-            quant_dtype=quant_dtype,
-            per_channel_conv=True,
-            per_channel_linear=True,
-            act_observer=MinMaxObserver,
-        )
-        quantizer.add_custom_quant_annotations(custom_annotations)
+        quantizer = make_custom_quantizer(quant_dtype, args.range_setting, custom_annotations)
 
         self.has_quant_io = True
         fx_graph_module = None
@@ -408,6 +410,7 @@ def quantize(self, quant_dtype, args, tokenizer, custom_annotations=()):
         fx_graph_module = prepare_pt2e(fx_graph_module, quantizer)
 
         logging.info("Quantizing the model...")
+
         calibrate(
             self.get_example_inputs(self.llama_meta["get_use_kv_cache"]),
             args.prompt[0],
@@ -419,6 +422,9 @@ def quantize(self, quant_dtype, args, tokenizer, custom_annotations=()):
             use_i64_token=args.embedding_quantize is not None,
         )
 
+        if scales_state_dict:
+            set_scales(fx_graph_module, scales_state_dict)
+
         self.llama_graph_module = convert_pt2e(fx_graph_module)
 
     def lowering_modules(
@@ -597,6 +603,27 @@ def permute(w, heads):
     end_load_ts = time.time()
     logging.info(f"Time for loading checkpoint: {end_load_ts - start_ts}")
 
+    scales_state_dict = dict()
+    if args.range_setting == "mse_with_act_loss":
+        try:
+            scales_state_dict = torch.load("scales_state_dict.pth", map_location=torch.device("cpu"))
+            logging.info("Loaded scales_state_dict from file")
+        except Exception:
+            logging.info("Computing scales using activation loss range setting")
+            model = llama_instance_list[1]
+            model.to(torch.float)
+            ar_len = model.ar_len
+            model.ar_len = model.max_seq_len
+            tokens, atten_mask = model.get_example_inputs(use_kv_cache=False)
+            atten_mask = atten_mask.to(torch.float)
+            logging.info(f"atten_mask shape: {atten_mask.shape}")
+            wrapped_model = WrappedLlamaModel(
+                model, atten_mask, model.use_kv_cache, args.max_seq_len, args.device
+            )
+            scales_state_dict = compute_scales(wrapped_model, tokens, 1600)  # TODO: use different tokens for calibration
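+            # the computed scales are applied to the prepared graph later via set_scales() inside quantize()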
+            reverse_quantize_module_swap(wrapped_model)
+            model.ar_len = ar_len
+
     for llama_instance in llama_instance_list:
         for layer in llama_instance.layers:
             if getattr(layer.attention, "prepare_sha", None):
                 layer.attention.prepare_sha()
@@ -658,6 +685,7 @@ def permute(w, heads):
                 args=args,
                 tokenizer=tokenizer,
                 custom_annotations=custom_annotations,
+                scales_state_dict=scales_state_dict,
             )
             # If hybrid and lookahead mode, we store kv output quant_attrs and apply to prefill output quant_attrs later
             if i == 0 and args.model_mode in ["hybrid", "lookahead"]:
@@ -668,12 +696,12 @@ def permute(w, heads):
                             kv_quant_attrs[output_indices] = output.args[1:]
                             output_indices += 1
                         break
-                custom_annotations = custom_annotations + (
-                    partial(
-                        annotate_prefill_kv_output,
-                        kv_quant_attrs=kv_quant_attrs,
-                    ),
-                )
+                # custom_annotations = custom_annotations + (
+                #     partial(
+                #         annotate_prefill_kv_output,
+                #         kv_quant_attrs=kv_quant_attrs,
+                #     ),
+                # )  # temporarily remove annotate_prefill_kv_output
             llama_instance.passes_job[TagQuantIO][QCOM_PASS_ACTIVATE_KEY] = True
             llama_instance.passes_job[TagQuantIO][QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY][
                 "get_quant_io_dtype_fn"
@@ -1062,6 +1090,12 @@ def _build_parser():
         type=int,
     )
 
+    parser.add_argument(
+        "--range_setting",
+        help="Choose which range setting method for weight quantization (e.g. mse_weight_only or mse_with_act_loss). If not specified, defaults to minmax",
+        type=str,
+    )
+
     parser.add_argument("-v", "--verbose", action="store_true")
 
     return parser
diff --git a/examples/qualcomm/oss_scripts/llama/model/static_llama.py b/examples/qualcomm/oss_scripts/llama/model/static_llama.py
index f7893792e00..4de5c73a2c8 100755
--- a/examples/qualcomm/oss_scripts/llama/model/static_llama.py
+++ b/examples/qualcomm/oss_scripts/llama/model/static_llama.py
@@ -25,8 +25,10 @@ def apply_rotary_emb_single(
     x_r, x_i = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
     # broadcast for batch_prefill mode input x
     if x.dim() == 4:
-        freqs_cos = freqs_cos[None, None, :, :]
-        freqs_sin = freqs_sin[None, None, :, :]
+        # freqs_cos = freqs_cos[None, None, :, :]
+        # freqs_sin = freqs_sin[None, None, :, :]
+        freqs_cos = freqs_cos[None, :, None, :]
+        freqs_sin = freqs_sin[None, :, None, :]
     x_out_r = x_r * freqs_cos - x_i * freqs_sin
     x_out_i = x_r * freqs_sin + x_i * freqs_cos
 
diff --git a/examples/qualcomm/oss_scripts/llama/range_setting_pt2e.py b/examples/qualcomm/oss_scripts/llama/range_setting_pt2e.py
new file mode 100644
index 00000000000..d876d08e4a4
--- /dev/null
+++ b/examples/qualcomm/oss_scripts/llama/range_setting_pt2e.py
@@ -0,0 +1,296 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+"""
+The goal of this is to allow range setting methods from TorchAO (formerly Quanty)
+to be incorporated into the PT2E flow.
+
+We implement the two main range setting methods:
+1) MSE weight range setting
+2) Activation loss weight range setting
+
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from executorch.backends.qualcomm.quantizer.annotators import OP_ANNOTATOR
+from executorch.backends.qualcomm.quantizer.observers.per_channel_param_observer import (
+    PerChannelParamObserver,
+)
+
+from executorch.backends.qualcomm.quantizer.qconfig import (
+    _derived_bias_quant_spec,
+    QuantizationConfig,
+)
+from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
+
+from executorch.examples.qualcomm.utils import make_quantizer
+
+from torchao.prototype.quantization.module_swap import (
+    QuantizationRecipe,
+    quantize_module_swap,
+    QuantizedLinear,
+)
+from torchao.prototype.quantization.module_swap.module_swap import (
+    get_layer_parent_by_name,
+)
+from torchao.prototype.quantization.module_swap.quantized_modules import (
+    QuantizedEmbedding,
+)
+from torchao.prototype.quantization.module_swap.range_setting_methods import (
+    set_weight_range_activation_loss,
+)
+
+from torchao.quantization.pt2e import (
+    MinMaxObserver,
+    PerChannelMinMaxObserver,
+    UniformQuantizationObserverBase,
+)
+from torchao.quantization.pt2e.quantizer import QuantizationSpec
+
+
+class WrappedLlamaModel(nn.Module):
+    def __init__(
+        self, model, atten_mask, use_kv_cache=False, max_seq_len=512, device="cuda"
+    ):
+        super(WrappedLlamaModel, self).__init__()
+        self.model = model
+        self.max_seq_len = max_seq_len
+        self.use_kv_cache = use_kv_cache
+        self.device = device
+        self.atten_mask = atten_mask
+
+    def forward(
+        self,
+        tokens: torch.Tensor,
+        *args,
+    ):
+        # Pad input if necessary, since LlamaModel requires static shape
+        if tokens.shape[1] != self.max_seq_len:
+            tokens = torch.nn.functional.pad(
+                tokens, (0, self.max_seq_len - tokens.shape[1])
+            )
+        return self.model.forward(tokens, self.atten_mask)
+
+
+class PerChannelMSEObserver(PerChannelParamObserver):
+
+    def forward(self, x_orig):
+        # since params are static, one calibration is enough
+        if not self.calibrated:
+            x = x_orig.detach().to(self.min_val.dtype)
+            self.min_val, self.max_val = self.line_search(x)
+            self.calibrated = True
+
+        return x_orig
+
+
+class PerChannelFixedQParamsObserver(PerChannelMinMaxObserver):
+    r"""
+    Fixed scale that you set manually (for per-channel quantization).
+    Symmetric quantization, so the zero point is always zero.
+    If the scale is not set, defaults to minmax.
+    """
+
+    def __init__(
+        self,
+        ch_axis=0,
+        dtype=torch.quint8,
+        qscheme=torch.per_channel_symmetric,
+        quant_min=0,
+        quant_max=255,
+        is_dynamic=False,
+        **kwargs,
+    ):
+        super().__init__(
+            ch_axis=ch_axis, dtype=dtype, qscheme=qscheme, is_dynamic=is_dynamic, **kwargs
+        )
+        self.quant_min = quant_min
+        self.quant_max = quant_max
+
+    def set_scale(self, scale):
+        self.register_buffer("scale", scale.clone().detach())
+        self.register_buffer("zero_point", torch.zeros_like(scale))
+
+    def calculate_qparams(self):
+        if hasattr(self, "scale") and hasattr(self, "zero_point"):
+            print("Using precomputed scale")
+            return self.scale, self.zero_point
+        print("Using minmax scale")
+        return self._calculate_qparams(self.min_val, self.max_val)
+
+
+def reverse_quantize_module_swap(model: nn.Module) -> nn.Module:
+    model = reverse_replace_all_linear_with_quantized(model)
+    model = reverse_replace_all_embedding_with_quantized(model)  # if embedding_quantize was false, does nothing
+    return model
+
+
+def reverse_replace_all_embedding_with_quantized(
+    model: nn.Module,
+) -> nn.Module:
+    for name, module in model.named_modules():
+        if isinstance(module, QuantizedEmbedding):
+            embedding = nn.Embedding(
+                num_embeddings=module.num_embeddings,
+                embedding_dim=module.embedding_dim,
+                padding_idx=module.padding_idx,
+                max_norm=module.max_norm,
+                norm_type=module.norm_type,
+                scale_grad_by_freq=module.scale_grad_by_freq,
+                sparse=module.sparse,
+                _weight=module.weight,
+            )
+            attribute_name = name.rsplit(".", 1)[-1]
+            parent_of_module = get_layer_parent_by_name(model, name)
+            setattr(parent_of_module, attribute_name, embedding)
+
+            # logger.info(f"replaced {name} with original embedding")
+    return model
+
+
+def reverse_replace_all_linear_with_quantized(
+    model: nn.Module,
+) -> nn.Module:
+    for name, module in model.named_modules():
+        if isinstance(module, QuantizedLinear):
+            linear = nn.Linear(
+                in_features=module.in_features,
+                out_features=module.out_features,
+                bias=module.bias is not None,
+            )
+            linear.weight = module.weight
+            linear.bias = module.bias
+
+            attribute_name = name.rsplit(".", 1)[-1]
+            parent_of_module = get_layer_parent_by_name(model, name)
+            setattr(parent_of_module, attribute_name, linear)
+
+            # logger.info(f"replaced {name} with original linear")
+    return model
+
+
+def compute_scales(model, data, num_points=100):
+    print("Computing scales!")
+    recipe = QuantizationRecipe(
+        weight_bits=4,  # TODO: should be based on dtype!
+        weight_quantization=True,
+        dynamic_weights=False,
+        weight_group_size="per_channel",
+        activation_bits=16,  # same as above
+        activation_quantization=True,
+        activation_group_size="per_tensor",
+        input_quantization=True,
+        output_quantization=True,
+        dynamic_activations=False,
+    )
+
+    quantized_model = quantize_module_swap(model, recipe)
+
+    set_weight_range_activation_loss(quantized_model, data, 1, num_points)  # batch_size = 1 for us
+    scales_state_dict = dict()
+    for name, module in quantized_model.named_modules():
+        if isinstance(module, QuantizedLinear):
+            scales_state_dict[name] = module.weight_scale.clone().detach()
+
+    return scales_state_dict
+
+
+def make_custom_quantizer(quant_dtype, range_setting=None, custom_annotations=(), linear_only=False):
+    quantizer = make_quantizer(
+        quant_dtype=quant_dtype,
+        per_channel_conv=True,
+        per_channel_linear=True,
+        act_observer=MinMaxObserver,
+    )
+    if range_setting in ("mse_weight_only", "mse_with_act_loss", "na"):
+        if range_setting == "na":
+            observer = PerChannelMinMaxObserver
+        elif range_setting == "mse_weight_only":
+            observer = PerChannelMSEObserver.with_args(**{"steps": 200, "use_mse": True})
+        else:
+            observer = PerChannelFixedQParamsObserver.with_args(**{"eps": 2**-12})
+        weight_dtype = (
+            torch.int4
+            if quant_dtype in (QuantDtype.use_16a4w, QuantDtype.use_16a4w_block)
+            else torch.int8
+        )
+        per_channel_q_config = quantizer.default_quant_config.quant_config
+        weight_qspec = QuantizationSpec(
+            dtype=torch.int8 if weight_dtype == torch.int4 else weight_dtype,
+            quant_min=(
+                -7
+                if weight_dtype == torch.int4
+                else torch.iinfo(weight_dtype).min + 1
+            ),
+            quant_max=(
+                7 if weight_dtype == torch.int4 else torch.iinfo(weight_dtype).max
+            ),
+            qscheme=torch.per_channel_symmetric,
+            ch_axis=0,
+            observer_or_fake_quant_ctr=observer,
+        )
+        quantizer.default_quant_config.per_channel_quant_config = (
+            QuantizationConfig(
+                input_activation=per_channel_q_config.input_activation,
+                output_activation=per_channel_q_config.output_activation,
+                weight=weight_qspec,
+                bias=_derived_bias_quant_spec,
+            )
+        )
+
+    if linear_only:
+        all_keys = set(OP_ANNOTATOR.keys())
+        conv_keys = {
+            op
+            for op in all_keys
+            if op.__name__
+            in ("conv1d.default", "conv2d.default", "conv_transpose2d.input", "linear.default")
+        }
+        quantizer.add_discard_ops(all_keys.difference(conv_keys))
+    else:
+        quantizer.add_custom_quant_annotations(custom_annotations)
+
+    return quantizer
+
+
+def set_scales(prepared_model, scales_state_dict):
+    print("Setting the scales of the manual observers!")
+    num_heads = 32  # should be set by argument
+    head_dim = 2048 // num_heads
+    for node in prepared_model.graph.nodes:
+        if isinstance(node.target, str):
+            l = node.target.split(".")
+            if len(l) > 3 and l[-3] in ("wq_sha", "wk_sha", "wv_sha"):
+                shorter = l[-3][:2]
+                key = ".".join(["model"] + l[:-3] + [shorter])
+                observer_name = str(list(node.users.keys())[0])
+                observer = getattr(prepared_model, observer_name)
+                i = int(l[-2])
+                try:
+                    observer.set_scale(
+                        scales_state_dict[key][head_dim * i : head_dim * (i + 1), :]
+                    )
+                    print("Set scale for", key)
+                except Exception as e:
+                    print("Failed to set scale for", key)
+                    print(e)
+            elif len(l) > 1 and l[-2] in ("wo_sha", "w1_conv", "w2_conv", "w3_conv"):
+                shorter = l[-2][:2]
+                key = ".".join(["model"] + l[:-2] + [shorter])
+                observer_name = str(list(node.users.keys())[0])
+                observer = getattr(prepared_model, observer_name)
+                try:
+                    observer.set_scale(scales_state_dict[key])
+                    print("Set scale for", key)
+                except Exception as e:
+                    print("Failed to set scale for", key)
+                    print(e)
+            # elif len(l) > 0 and l[0] == 'output':
+            #     key = "model.output"
+            #     observer_name = str(list(node.users.keys())[0])
+            #     observer = getattr(prepared_model, observer_name)
+            #     # if type(observer.target) == str and observer.target[:10] == "activation":
+            #     try:
+            #         observer.set_scale(scales_state_dict[key])
+            #         print("Set scale for", key)
+            #     except:
+            #         print("Failed to set scale for ", key)
+
+    print("done")
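
For reviewers, a minimal sketch (not part of the patch) of how the helpers added in range_setting_pt2e.py are intended to compose, mirroring the mse_with_act_loss flow in eval_llama_qnn.py. Here `model`, `atten_mask`, `calibration_tokens`, and `inputs` are placeholders for a LlamaModel instance and its example inputs, not names introduced by this diff:

    # Illustrative sketch only; placeholder names are noted above.
    import torch
    from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
    from executorch.examples.qualcomm.oss_scripts.llama.range_setting_pt2e import (
        WrappedLlamaModel,
        compute_scales,
        make_custom_quantizer,
        reverse_quantize_module_swap,
        set_scales,
    )
    from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

    # 1) Derive per-channel weight scales with the activation-loss method on the
    #    float model (module-swap flow), then undo the swap.
    wrapped = WrappedLlamaModel(model, atten_mask, use_kv_cache=False, max_seq_len=512, device="cpu")
    scales_state_dict = compute_scales(wrapped, calibration_tokens, num_points=1600)
    reverse_quantize_module_swap(wrapped)

    # 2) Build the quantizer with the fixed-qparams weight observer and run PT2E.
    quantizer = make_custom_quantizer(QuantDtype.use_16a4w, "mse_with_act_loss")
    exported = torch.export.export(model, inputs, strict=True).module()
    prepared = prepare_pt2e(exported, quantizer)
    prepared(*inputs)  # calibration forward pass(es)

    # 3) Overwrite the observers' scales with the precomputed ones, then convert.
    set_scales(prepared, scales_state_dict)
    quantized = convert_pt2e(prepared)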