Implemented range setting in QNN llama flow #12377

Open · wants to merge 1 commit into base: main
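The range-setting idea, in brief: rather than taking a tensor's raw min/max as its quantization range, search for the clipping range that minimizes reconstruction error. A minimal sketch of that search (illustration only; the PR's actual implementation is `compute_scales`/`set_scales` in `range_setting_pt2e.py`, which this diff references but does not include, and `mse_scale` here is a hypothetical helper):

```python
import torch

def mse_scale(w: torch.Tensor, n_bits: int = 4, n_grid: int = 100) -> torch.Tensor:
    """Grid-search a symmetric per-tensor scale that minimizes quantization
    MSE, instead of deriving it from the raw min/max range."""
    qmax = 2 ** (n_bits - 1) - 1              # e.g. 7 for int4
    max_abs = w.abs().max().clamp_min(1e-8)   # guard against all-zero tensors
    best_scale, best_err = max_abs / qmax, float("inf")
    for i in range(1, n_grid + 1):
        scale = (max_abs * i / n_grid) / qmax            # candidate clipping range
        w_q = torch.clamp(torch.round(w / scale), -qmax - 1, qmax) * scale
        err = torch.mean((w - w_q) ** 2).item()          # reconstruction error
        if err < best_err:
            best_scale, best_err = scale, err
    return best_scale
```

On weight tensors with heavy-tailed outliers, a search like this typically settles on a tighter clipping range than a plain min/max observer would.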
12 changes: 12 additions & 0 deletions examples/qualcomm/oss_scripts/llama/TARGETS
@@ -34,6 +34,16 @@ python_library(
],
)

python_library(
name = "range_setting_pt2e",
srcs = [
"range_setting_pt2e.py",
],
deps = [
"//caffe2:torch",
],
)

python_binary(
name = "llama",
main_function = "executorch.examples.qualcomm.oss_scripts.llama.llama.main",
@@ -42,6 +52,7 @@
],
deps = [
":llama_lib",
"//executorch/examples/qualcomm/oss_scripts/llama:range_setting_pt2e",
],
)

@@ -55,6 +66,7 @@
deps = [
":llama_lib",
"//executorch/examples/models/llama:eval_library",
"//executorch/examples/qualcomm/oss_scripts/llama:range_setting_pt2e",
"fbsource//third-party/pypi/lm-eval:lm-eval",
],
)
105 changes: 64 additions & 41 deletions examples/qualcomm/oss_scripts/llama/eval_llama_qnn.py
@@ -46,14 +46,18 @@
LlamaModel,
ModelArgs,
)

from executorch.examples.qualcomm.utils import make_quantizer
from executorch.examples.qualcomm.oss_scripts.llama.range_setting_pt2e import (
reverse_quantize_module_swap,
WrappedLlamaModel,
compute_scales,
set_scales,
make_custom_quantizer,
)

from lm_eval.evaluator import simple_evaluate

from pytorch_tokenizers import get_tokenizer

from torchao.quantization.pt2e import MinMaxObserver
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
from torchao.quantization.pt2e.quantizer import QuantizationSpec

@@ -87,7 +91,6 @@ def forward(
)
return self.model.forward(tokens, self.atten_mask)


def add_mse_weight_observer(quant_dtype, quantizer):
weight_dtype = (
torch.int4
@@ -118,21 +121,14 @@ def add_mse_weight_observer(quant_dtype, quantizer):
def gen_eval_wrapper(model_name, args):
tokenizer = get_tokenizer(args.tokenizer_path)
with open(args.params) as f:
kv_config = ModelArgs(**json.load(f))
prefill_config = ModelArgs(**json.load(f))
# TODO: support batch inputs if necessary
kv_config.max_batch_size = 1
kv_config.max_seq_len = args.max_seq_length
kv_config.use_kv_cache = True

prefill_config = copy.copy(kv_config)
prefill_config.max_batch_size = 1
prefill_config.max_seq_len = args.max_seq_length
prefill_config.use_kv_cache = (
False if args.max_seq_length == args.prefill_ar_len else True
)
config = prefill_config
prefill_config.use_kv_cache = False
use_i64_token = args.embedding_quantize is not None
model = LlamaModel(
config,
prefill_config,
ar_len=args.prefill_ar_len,
output_new_cache_only=True,
output_cache=False,
@@ -173,20 +169,32 @@ def permute(w, heads):
if "model" in state_dict:
state_dict = state_dict["model"]

tokens, atten_mask = model.get_example_inputs(use_kv_cache=False)
tokens = tokens.to(device=args.device)
atten_mask = atten_mask.to(device=args.device)
atten_mask = atten_mask.to(dtype=torch.float)
inputs = (tokens, atten_mask)

model = model.to(dtype=torch.float)
model = model.to(device=args.device)

scales_state_dict = dict()
if args.range_setting == "mse_with_act_loss":
wrapped_model = WrappedLlamaModel(
model, atten_mask, args.use_kv_cache, args.max_seq_length, args.device
)
scales_state_dict = compute_scales(wrapped_model, tokens, 1600) # want to use different tokens for calibration!
torch.save(scales_state_dict, "scales_state_dict.pth")
logging.info("Saved scales to scales_state_dict.pth!")
reverse_quantize_module_swap(wrapped_model)

for layer in model.layers:
if getattr(layer.attention, "prepare_sha", None):
layer.attention.prepare_sha()
if getattr(layer.feed_forward, "prepare_feedfoward_conv", None):
layer.feed_forward.prepare_feedfoward_conv()

model.to(dtype=torch.float)
model.to(device=args.device)

tokens, atten_mask = model.get_example_inputs(use_kv_cache=False)
tokens = tokens.to(device=args.device)
atten_mask = atten_mask.to(device=args.device)
atten_mask = atten_mask.to(dtype=torch.float)
inputs = (tokens, atten_mask)
model = model.to(dtype=torch.float)

if args.embedding_quantize:
model = get_quant_embedding_transform(
@@ -195,35 +203,28 @@

model = convert_linear_to_conv2d(model)

if args.ptq:
if args.ptq is not None:
quant_dtype = getattr(QuantDtype, f"use_{args.ptq}")

custom_annotations = (annotate_matmul_16a8w,)
if args.llama_model == "stories110m":
custom_annotations = custom_annotations + (
annotate_linear_16a8w_in_affine_layer,
)
quantizer = make_quantizer(
quant_dtype=quant_dtype,
per_channel_conv=True,
per_channel_linear=True,
act_observer=MinMaxObserver,
)
quantizer.add_custom_quant_annotations(custom_annotations)

if args.range_setting == "mse_weight":
add_mse_weight_observer(quant_dtype, quantizer)
quantizer = make_custom_quantizer(quant_dtype, args.range_setting, custom_annotations, args.quant_linear_only)

with torch.no_grad():
logging.info("Starting export...")
model = torch.export.export(model, inputs, strict=True).module()
if quant_dtype == QuantDtype.use_16a4w_block:
conv_nodes = [n for n in model.graph.nodes if "conv" in n.name]
block_size_map = {n.name: (1, 64, 1, 1) for n in conv_nodes}
quantizer.set_block_size_map(block_size_map)

logging.info("Finished export, adding observers (prepare_pt2e)...")
model = prepare_pt2e(model, quantizer)

logging.info("Quantizing the model...")
logging.info("Observers added, starting calibration...")

calibrate(
inputs,
@@ -236,7 +237,24 @@
use_i64_token=use_i64_token,
)

if args.range_setting == "mse_with_act_loss":
# scales_state_dict = torch.load("scales_state_dict.pth")
set_scales(model, scales_state_dict)

logging.info("Quantizing the model...")
model = convert_pt2e(model)
logging.info("Quantization complete! Here is some sample generated text:")

calibrate(
inputs,
"Could you tell me about Facebook?",
model,
tokenizer=tokenizer,
ar_len=args.prefill_ar_len,
max_seq_len=args.max_seq_len,
kv_updater=None,
use_i64_token=use_i64_token,
)

model = WrappedLlamaModel(
model, atten_mask, args.use_kv_cache, args.max_seq_length, args.device
@@ -248,7 +266,7 @@
max_seq_length=args.calibration_seq_length,
use_kv_cache=args.use_kv_cache,
generate_full_logits=args.generate_full_logits,
enable_dynamic_shape=args.enable_dynamic_shape,
enable_dynamic_shape=False,
)


@@ -271,7 +289,7 @@ def eval_llama(
model=eval_wrapper,
tasks=args.tasks,
num_fewshot=args.num_fewshot,
limit=args.limit,
limit=args.fraction,
)

for task, res in eval_results["results"].items():
@@ -291,13 +309,18 @@ def main() -> None:
)
parser.add_argument(
"--range_setting",
help="Choose which range setting method (e.g. mse_weight). If not specified, will do minmax for weights and activations",
help="Choose which range setting method for weight quantization (e.g. mse_weight_only or mse_with_act_loss). If not specified, defaults to minmax",
type=str,
)
parser.add_argument(
"--limit",
help="the number of examples per task (only use this for testing), If <1, limit is a percentage of the total number of examples",
type=str,
"--fraction",
help="the fraction of examples per task (only use this for testing)",
type=float,
)
parser.add_argument(
"--quant_linear_only",
help="if you select this option we quantize linear layers only. If ptq arg not specified then defaults to 16a4w",
action='store_true',
)

args = parser.parse_args()
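Condensing the eval flow above: for `--range_setting mse_with_act_loss`, scales are computed on the float model before export, cached to disk, and re-applied after calibration but before `convert_pt2e`. A sketch of that ordering (assuming this PR's `range_setting_pt2e` module layout; `run_calibration` stands in for the script's `calibrate(...)` call):

```python
import torch
from executorch.examples.qualcomm.oss_scripts.llama.range_setting_pt2e import (
    compute_scales,
    reverse_quantize_module_swap,
    set_scales,
)
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

def quantize_with_act_loss_scales(model, wrapped_model, quantizer,
                                  tokens, inputs, run_calibration):
    # 1. Derive scales on the float model and cache them for later runs.
    scales_state_dict = compute_scales(wrapped_model, tokens, 1600)
    torch.save(scales_state_dict, "scales_state_dict.pth")
    reverse_quantize_module_swap(wrapped_model)  # undo the helper's module swaps

    # 2. Standard PT2E flow: export, insert observers, calibrate.
    graph_module = torch.export.export(model, inputs, strict=True).module()
    graph_module = prepare_pt2e(graph_module, quantizer)
    run_calibration(graph_module)

    # 3. Override the observer-derived scales, then materialize quant ops.
    set_scales(graph_module, scales_state_dict)
    return convert_pt2e(graph_module)
```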
62 changes: 48 additions & 14 deletions examples/qualcomm/oss_scripts/llama/llama.py
@@ -63,6 +63,14 @@
LlamaModel,
ModelArgs,
)
from executorch.examples.qualcomm.oss_scripts.llama.range_setting_pt2e import (
reverse_quantize_module_swap,
WrappedLlamaModel,
compute_scales,
set_scales,
make_custom_quantizer
)

from executorch.examples.qualcomm.utils import (
make_output_dir,
make_quantizer,
@@ -380,15 +388,9 @@ def _tag_ios(self, node, fixed_point_type):

return quant_io_type

def quantize(self, quant_dtype, args, tokenizer, custom_annotations=()):
def quantize(self, quant_dtype, args, tokenizer, custom_annotations=(), scales_state_dict=None):
self.quant_dtype = quant_dtype
quantizer = make_quantizer(
quant_dtype=quant_dtype,
per_channel_conv=True,
per_channel_linear=True,
act_observer=MinMaxObserver,
)
quantizer.add_custom_quant_annotations(custom_annotations)
quantizer = make_custom_quantizer(quant_dtype, args.range_setting, custom_annotations)

self.has_quant_io = True
fx_graph_module = None
@@ -408,6 +410,7 @@ def quantize(self, quant_dtype, args, tokenizer, custom_annotations=()):
fx_graph_module = prepare_pt2e(fx_graph_module, quantizer)

logging.info("Quantizing the model...")

calibrate(
self.get_example_inputs(self.llama_meta["get_use_kv_cache"]),
args.prompt[0],
Expand All @@ -419,6 +422,9 @@ def quantize(self, quant_dtype, args, tokenizer, custom_annotations=()):
use_i64_token=args.embedding_quantize is not None,
)

if scales_state_dict:
set_scales(fx_graph_module, scales_state_dict)

self.llama_graph_module = convert_pt2e(fx_graph_module)

def lowering_modules(
@@ -597,6 +603,27 @@ def permute(w, heads):
end_load_ts = time.time()
logging.info(f"Time for loading checkpoint: {end_load_ts - start_ts}")

scales_state_dict = dict()
if args.range_setting == "mse_with_act_loss":
try:
scales_state_dict = torch.load("scales_state_dict.pth", map_location=torch.device('cpu'))
logging.info("Loaded scales_state_dict from file")
except:
logging.info("Computing scales using activation loss range setting")
model = llama_instance_list[1]
model.to(torch.float)
ar_len = model.ar_len
model.ar_len = model.max_seq_len
tokens, atten_mask = model.get_example_inputs(use_kv_cache=False)
atten_mask.to(torch.float)
print(atten_mask.shape)
[Review comment · Contributor] Removing debugging line

wrapped_model = WrappedLlamaModel(
model, atten_mask, model.use_kv_cache, args.max_seq_len, args.device
)
scales_state_dict = compute_scales(wrapped_model, tokens, 1600) # want to use different tokens for calibration!
reverse_quantize_module_swap(wrapped_model)
model.ar_len = ar_len

for llama_instance in llama_instance_list:
for layer in llama_instance.layers:
if getattr(layer.attention, "prepare_sha", None):
@@ -658,6 +685,7 @@ def permute(w, heads):
args=args,
tokenizer=tokenizer,
custom_annotations=custom_annotations,
scales_state_dict=scales_state_dict,
)
# If hybrid and lookahead mode, we store kv output quant_attrs and apply to prefill output quant_attrs later
if i == 0 and args.model_mode in ["hybrid", "lookahead"]:
@@ -668,12 +696,12 @@
kv_quant_attrs[output_indices] = output.args[1:]
output_indices += 1
break
custom_annotations = custom_annotations + (
partial(
annotate_prefill_kv_output,
kv_quant_attrs=kv_quant_attrs,
),
)
# custom_annotations = custom_annotations + (
[Review comment · Contributor] Actually I need to have a separate PR for this.
# partial(
# annotate_prefill_kv_output,
# kv_quant_attrs=kv_quant_attrs,
# ),
# ) # temporarily remove annotate_prefill_kv_output
llama_instance.passes_job[TagQuantIO][QCOM_PASS_ACTIVATE_KEY] = True
llama_instance.passes_job[TagQuantIO][QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY][
"get_quant_io_dtype_fn"
@@ -1062,6 +1090,12 @@ def _build_parser():
type=int,
)

parser.add_argument(
"--range_setting",
help="Choose which range setting method for weight quantization (e.g. mse_weight_only or mse_with_act_loss). If not specified, defaults to minmax",
type=str,
)

parser.add_argument("-v", "--verbose", action="store_true")

return parser
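The load-or-compute step added to `llama.py` caches `scales_state_dict.pth` across runs; since the bare `except:` there also swallows unrelated failures (e.g. a corrupt file), a narrower variant would look like this (an editor's sketch, not part of the PR; `compute_fn` is a hypothetical callback):

```python
import logging
import torch

def load_or_compute_scales(compute_fn, path="scales_state_dict.pth"):
    """Reuse cached range-setting scales when present; otherwise recompute."""
    try:
        scales = torch.load(path, map_location=torch.device("cpu"))
        logging.info("Loaded scales_state_dict from file")
    except FileNotFoundError:
        logging.info("Computing scales using activation loss range setting")
        scales = compute_fn()
        torch.save(scales, path)  # cache for the next run
    return scales
```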
6 changes: 4 additions & 2 deletions examples/qualcomm/oss_scripts/llama/model/static_llama.py
@@ -25,8 +25,10 @@ def apply_rotary_emb_single(
x_r, x_i = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
# broadcast for batch_prefill mode input x
if x.dim() == 4:
freqs_cos = freqs_cos[None, None, :, :]
freqs_sin = freqs_sin[None, None, :, :]
# freqs_cos = freqs_cos[None, None, :, :]
# freqs_sin = freqs_sin[None, None, :, :]
freqs_cos = freqs_cos[None, :, None, :]
freqs_sin = freqs_sin[None, :, None, :]
x_out_r = x_r * freqs_cos - x_i * freqs_sin
x_out_i = x_r * freqs_sin + x_i * freqs_cos
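The two-line fix above moves the sequence axis of the rotary tables from position 2 to position 1 so they broadcast correctly against a 4-D prefill input. A quick shape check (illustrative only; the `(bsz, seq_len, n_heads, head_dim)` layout is inferred from the surrounding code):

```python
import torch

bsz, seq_len, n_heads, half_dim = 2, 8, 4, 32
x_r = torch.randn(bsz, seq_len, n_heads, half_dim)   # real half of the input
freqs_cos = torch.randn(seq_len, half_dim)           # per-position rotary table

# Old indexing, (1, 1, seq_len, half_dim), lines seq_len up against the head
# axis and only broadcasts at all when seq_len == n_heads.
# New indexing aligns seq_len with dim 1 and broadcasts over batch and heads:
out = x_r * freqs_cos[None, :, None, :]
assert out.shape == (bsz, seq_len, n_heads, half_dim)
```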
