-
Notifications
You must be signed in to change notification settings - Fork 622
Implemented range setting in QNN llama flow #12377
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
rohansjoshi
wants to merge
1
commit into
pytorch:main
Choose a base branch
from
rohansjoshi:export-D78127727
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -63,6 +63,14 @@ | |
LlamaModel, | ||
ModelArgs, | ||
) | ||
from executorch.examples.qualcomm.oss_scripts.llama.range_setting_pt2e import ( | ||
reverse_quantize_module_swap, | ||
WrappedLlamaModel, | ||
compute_scales, | ||
set_scales, | ||
make_custom_quantizer | ||
) | ||
|
||
from executorch.examples.qualcomm.utils import ( | ||
make_output_dir, | ||
make_quantizer, | ||
|
@@ -380,15 +388,9 @@ def _tag_ios(self, node, fixed_point_type): | |
|
||
return quant_io_type | ||
|
||
def quantize(self, quant_dtype, args, tokenizer, custom_annotations=()): | ||
def quantize(self, quant_dtype, args, tokenizer, custom_annotations=(), scales_state_dict=None): | ||
self.quant_dtype = quant_dtype | ||
quantizer = make_quantizer( | ||
quant_dtype=quant_dtype, | ||
per_channel_conv=True, | ||
per_channel_linear=True, | ||
act_observer=MinMaxObserver, | ||
) | ||
quantizer.add_custom_quant_annotations(custom_annotations) | ||
quantizer = make_custom_quantizer(quant_dtype, args.range_setting, custom_annotations) | ||
|
||
self.has_quant_io = True | ||
fx_graph_module = None | ||
|
@@ -408,6 +410,7 @@ def quantize(self, quant_dtype, args, tokenizer, custom_annotations=()): | |
fx_graph_module = prepare_pt2e(fx_graph_module, quantizer) | ||
|
||
logging.info("Quantizing the model...") | ||
|
||
calibrate( | ||
self.get_example_inputs(self.llama_meta["get_use_kv_cache"]), | ||
args.prompt[0], | ||
|
@@ -419,6 +422,9 @@ def quantize(self, quant_dtype, args, tokenizer, custom_annotations=()): | |
use_i64_token=args.embedding_quantize is not None, | ||
) | ||
|
||
if scales_state_dict: | ||
set_scales(fx_graph_module, scales_state_dict) | ||
|
||
self.llama_graph_module = convert_pt2e(fx_graph_module) | ||
|
||
def lowering_modules( | ||
|
@@ -597,6 +603,27 @@ def permute(w, heads): | |
end_load_ts = time.time() | ||
logging.info(f"Time for loading checkpoint: {end_load_ts - start_ts}") | ||
|
||
scales_state_dict = dict() | ||
if args.range_setting == "mse_with_act_loss": | ||
try: | ||
scales_state_dict = torch.load("scales_state_dict.pth", map_location=torch.device('cpu')) | ||
logging.info("Loaded scales_state_dict from file") | ||
except: | ||
logging.info("Computing scales using activation loss range setting") | ||
model = llama_instance_list[1] | ||
model.to(torch.float) | ||
ar_len = model.ar_len | ||
model.ar_len = model.max_seq_len | ||
tokens, atten_mask = model.get_example_inputs(use_kv_cache=False) | ||
atten_mask.to(torch.float) | ||
print(atten_mask.shape) | ||
wrapped_model = WrappedLlamaModel( | ||
model, atten_mask, model.use_kv_cache, args.max_seq_len, args.device | ||
) | ||
scales_state_dict = compute_scales(wrapped_model, tokens, 1600) # want to use different tokens for calibration! | ||
reverse_quantize_module_swap(wrapped_model) | ||
model.ar_len = ar_len | ||
|
||
for llama_instance in llama_instance_list: | ||
for layer in llama_instance.layers: | ||
if getattr(layer.attention, "prepare_sha", None): | ||
|
@@ -658,6 +685,7 @@ def permute(w, heads): | |
args=args, | ||
tokenizer=tokenizer, | ||
custom_annotations=custom_annotations, | ||
scales_state_dict=scales_state_dict, | ||
) | ||
# If hybrid and lookahead mode, we store kv output quant_attrs and apply to prefill output quant_attrs later | ||
if i == 0 and args.model_mode in ["hybrid", "lookahead"]: | ||
|
@@ -668,12 +696,12 @@ def permute(w, heads): | |
kv_quant_attrs[output_indices] = output.args[1:] | ||
output_indices += 1 | ||
break | ||
custom_annotations = custom_annotations + ( | ||
partial( | ||
annotate_prefill_kv_output, | ||
kv_quant_attrs=kv_quant_attrs, | ||
), | ||
) | ||
# custom_annotations = custom_annotations + ( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually I need to have a separate PR for this. |
||
# partial( | ||
# annotate_prefill_kv_output, | ||
# kv_quant_attrs=kv_quant_attrs, | ||
# ), | ||
# ) # temporarily remove annotate_prefill_kv_output | ||
llama_instance.passes_job[TagQuantIO][QCOM_PASS_ACTIVATE_KEY] = True | ||
llama_instance.passes_job[TagQuantIO][QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY][ | ||
"get_quant_io_dtype_fn" | ||
|
@@ -1062,6 +1090,12 @@ def _build_parser(): | |
type=int, | ||
) | ||
|
||
parser.add_argument( | ||
"--range_setting", | ||
help="Choose which range setting method for weight quantization (e.g. mse_weight_only or mse_with_act_loss). If not specified, defaults to minmax", | ||
type=str, | ||
) | ||
|
||
parser.add_argument("-v", "--verbose", action="store_true") | ||
|
||
return parser | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Removing debugging line