Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions backends/openvino/quantization/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Re-export the data-aware compression entry point. The name must match the
# function actually defined in nncf_compression.py, otherwise importing this
# package fails with ImportError.
from .nncf_compression import apply_nncf_data_aware_compression

__all__ = ["apply_nncf_data_aware_compression"]
79 changes: 79 additions & 0 deletions backends/openvino/quantization/nncf_compression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright (c) Intel Corporation
#
# Licensed under the BSD License (the "License"); you may not use this file
# except in compliance with the License. See the license file found in the
# LICENSE file in the root directory of this source tree.

# mypy: disable-error-code=import-not-found

import torch

try:
    import nncf  # type: ignore[import-untyped]
    from pytorch_tokenizers import get_tokenizer  # type: ignore[import-untyped]
except ImportError as err:
    # Either import above can fail; chain the original error so the traceback
    # still names the package that is actually missing.
    raise ImportError(
        "Please install nncf via backends/openvino/requirements.txt"
    ) from err


def get_calibration_data(
    module: torch.fx.GraphModule, tokenizer, prompts: str, max_len: int
):
    """
    Build calibration samples by greedily decoding from a prompt.

    The prompt is tokenized (with BOS, without EOS) and fed to ``module`` one
    token per call. Once every prompt token has been consumed, the arg-max of
    the returned logits is appended as the next token, so the result contains
    both the prompt tokens and tokens generated by the model itself.

    :param module: Callable invoked as ``module(tokens, {"input_pos": pos})``
        with a ``(1, 1)`` token tensor and a one-element position tensor.
    :param tokenizer: Object exposing ``encode(text, bos=..., eos=...)`` and
        an ``eos_id`` attribute.
    :param prompts: Prompt text used to seed the generation.
    :param max_len: Upper bound on the number of forward passes.
    :return: List of ``(position, token)`` pairs covering every token in the
        final token list.
    """
    # TODO: change criteria & support batch inputs if necessary
    token_list = tokenizer.encode(prompts, bos=True, eos=False)
    # A plain int is enough for counting, list indexing and comparison; the
    # position tensor the model needs is built per call below.
    pos = 0

    with torch.no_grad():
        # Stop once the last token is EOS or max_len forward passes were run.
        while token_list[-1] != tokenizer.eos_id and pos < max_len:
            logits = module(
                torch.full((1, 1), token_list[pos]),
                {"input_pos": torch.tensor((pos,))},
            )
            pos += 1
            # Only extend the sequence after the prompt has been consumed.
            # NOTE(review): `.item()` assumes the arg-max yields exactly one
            # element, i.e. logits for a single position - confirm.
            if pos >= len(token_list):
                token_list.append(torch.argmax(logits[:], dim=-1).item())
    return list(enumerate(token_list))


def transform_fn(token_pos_map: tuple[int, int]):
    """
    Convert one ``(position, token)`` calibration sample into model inputs.

    :param token_pos_map: Pair whose first element is the position and whose
        second element is the token id.
    :return: Tuple of a ``(1, 1)`` tensor holding the token and a dict with
        the position stored as a one-element tensor under ``"input_pos"``.
    """
    position, token = token_pos_map[0], token_pos_map[1]
    return (
        # Two unsqueezes turn the scalar token into a (1, 1) tensor.
        torch.tensor(token).unsqueeze(0).unsqueeze(0),
        {"input_pos": torch.tensor([position])},
    )


def apply_nncf_data_aware_compression(
    builder_exported, quantizers, awq: bool, scale_estimation: bool
):
    """
    Compress the builder's pre-autograd graph module with NNCF using
    calibration data derived from the builder's prompt.

    The builder's ``calibration_data`` is replaced by the ``(position, token)``
    samples produced by ``get_calibration_data``, and its
    ``pre_autograd_graph_module`` is replaced by the result of
    ``compress_pt2e``. Only the first entry of ``quantizers`` is used.

    :param builder_exported: Builder exposing ``tokenizer_path``,
        ``pre_autograd_graph_module``, ``calibration_data`` and ``max_seq_len``.
    :param quantizers: Sequence of quantizers; element 0 is passed to NNCF.
    :param awq: Forwarded to ``compress_pt2e`` as ``awq``.
    :param scale_estimation: Forwarded to ``compress_pt2e`` as
        ``scale_estimation``.
    :return: The same ``builder_exported`` object, updated in place.
    """
    tokenizer = get_tokenizer(builder_exported.tokenizer_path)

    # Turn the stored prompt into per-token calibration samples.
    samples = get_calibration_data(
        builder_exported.pre_autograd_graph_module,
        tokenizer,
        builder_exported.calibration_data,
        builder_exported.max_seq_len,
    )
    builder_exported.calibration_data = samples

    compress_pt2e = nncf.experimental.torch.fx.compress_pt2e
    compressed_module = compress_pt2e(
        builder_exported.pre_autograd_graph_module,
        quantizer=quantizers[0],
        dataset=nncf.Dataset(
            builder_exported.calibration_data,
            transform_func=transform_fn,
        ),
        awq=awq,
        scale_estimation=scale_estimation,
    )
    builder_exported.pre_autograd_graph_module = compressed_module
    return builder_exported
51 changes: 29 additions & 22 deletions backends/openvino/quantizer/quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
INT4WeightObserver,
INT8WeightObserver,
)
from nncf.common.graph.graph import NNCFGraph # type: ignore[import-untyped]
from nncf.common.graph.graph import NNCFGraph, NNCFNode # type: ignore[import-untyped]
from nncf.experimental.common.tensor_statistics.statistics import WCTensorStatistic
from nncf.quantization.algorithms.weight_compression.config import ( # type: ignore[import-untyped]
WeightCompressionParameters,
)
Expand Down Expand Up @@ -96,30 +97,25 @@ def __init__(
"""
self.mode = mode
if self.mode not in OpenVINOQuantizer.WEIGHTS_ONLY_COMPRESSION_MODES:
if mode == QuantizationMode.INT8_SYM:
preset = quantization.structs.QuantizationPreset.PERFORMANCE
model_type = None
elif mode == QuantizationMode.INT8_MIXED:
preset = quantization.structs.QuantizationPreset.MIXED
model_type = None
else:
preset = None
model_type = nncf.parameters.ModelType.TRANSFORMER
self._algo = (
nncf.quantization.algorithms.min_max.algorithm.MinMaxQuantization(
preset=preset, model_type=model_type, **kwargs
**kwargs
)
)
else:
weight_compression_configuration = get_weight_compression_configuration(
mode.value.replace(
"wo", ""
), # Mode value has to match NNCF CompressWeightsMode
**kwargs,
)
mode = mode.value.replace(
"wo", ""
) # Mode value has to match NNCF CompressWeightsMode
subset_size = 1 # Doesn't really matter in this case since it is data-free. Should just be +ve
self.weight_compression_configuration = (
get_weight_compression_configuration(
mode,
**kwargs,
)
)
_weight_compression_configuration = self.weight_compression_configuration
self._algo = nncf.quantization.algorithms.weight_compression.algorithm.WeightCompression(
subset_size=subset_size, **weight_compression_configuration
subset_size=subset_size, **_weight_compression_configuration
)

def set_ignored_scope(
Expand Down Expand Up @@ -157,6 +153,16 @@ def get_nncf_quantization_setup(
self._algo._set_backend_entity(model)
return self._algo.find_quantization_setup(model, nncf_graph)

def get_nncf_weight_compression_parameters(
    self,
    model: torch.fx.GraphModule,
    nncf_graph: NNCFGraph,
) -> tuple[
    list[WeightCompressionParameters], Optional[dict[str, WCTensorStatistic]]
]:
    """
    Return the weight compression parameters computed by the wrapped algorithm.

    :param model: FX graph module handed to the algorithm.
    :param nncf_graph: NNCF graph describing ``model``.
    :return: Whatever ``self._algo.get_weight_compression_parameters`` returns
        for ``model`` and ``nncf_graph``.
    """
    algo = self._algo
    # The backend entity is set before the parameters are requested.
    algo.set_backend_entity(model)
    wc_parameters = algo.get_weight_compression_parameters(model, nncf_graph)
    return wc_parameters

def _annotate_weight_compression(
self,
model: torch.fx.GraphModule,
Expand All @@ -176,12 +182,13 @@ def _annotate_weight_compression(
:param node_vs_torch_annotation: A mapping of FX nodes to quantization annotations.
:return: Updated mapping of FX nodes with weight compression annotations.
"""
self._algo.set_backend_entity(model)
all_wc_params, _ = self._algo.get_weight_compression_parameters(
all_wc_params, *_ = self.get_nncf_weight_compression_parameters(
model, nncf_graph
)

for wc_param in all_wc_params:
if not wc_param.compression_config:
continue
node_with_weight = wc_param.node_with_weight
target_node = nncf_fx.node_utils.get_graph_node_by_name(
graph, node_with_weight.node_name
Expand Down Expand Up @@ -384,7 +391,7 @@ def _get_edge_or_node(
"""
ip = qp.insertion_point
if qp.is_weight_quantization_point():
OpenVINOQuantizer._get_weight_edge(target_node, nncf_graph)
return OpenVINOQuantizer._get_weight_edge(target_node, nncf_graph)

if ip.input_port_id is None:
return target_node
Expand Down Expand Up @@ -588,4 +595,4 @@ def quantize_model(
smooth_quant=smooth_quant,
**kwargs,
)
return quantized_model
return quantized_model
2 changes: 1 addition & 1 deletion backends/openvino/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
transformers
git+https://github.com/openvinotoolkit/nncf@3d753ac#egg=nncf
git+https://github.com/openvinotoolkit/nncf@an/fx/compress_pt2e#egg=nncf
29 changes: 26 additions & 3 deletions examples/models/llama/export_llama_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from typing import Callable, List, Optional, Union

import torch
from executorch.backends.openvino.quantization import apply_nncf_data_aware_compression

from executorch.devtools.backend_debug import print_delegation_info
from executorch.devtools.etrecord import generate_etrecord as generate_etrecord_func
Expand Down Expand Up @@ -247,6 +248,19 @@ def build_args_parser() -> argparse.ArgumentParser:
help="Path to the adapter_config.json file. Used if the model has trained LoRA adapters. Must provide adapter_checkpoint.",
)

parser.add_argument(
"--nncf_awq",
required=False,
action="store_true",
help="Whether to use AWQ from NNCF. Applicable only for the Openvino backend.",
)

parser.add_argument(
"--nncf_scale_estimation",
action="store_true",
help="Whether to use Scale Estimation algorithm from NNCF. Applicable only for the Openvino backend",
)

parser.add_argument(
"--use_qnn_sha",
action="store_true",
Expand Down Expand Up @@ -920,6 +934,8 @@ def _to_edge_and_lower_llama_openvino(
modelname,
quantizers,
additional_passes,
awq,
scale_estimation,
openvino_device: str = "CPU",
verbose: bool = False,
) -> LLMEdgeManager: # noqa: C901
Expand All @@ -933,9 +949,14 @@ def _to_edge_and_lower_llama_openvino(
for partitioner in partitioners:
logging.info(f"--> {partitioner.__class__.__name__}")

builder = builder_exported.pt2e_quantize(quantizers).to_edge_transform_and_lower(
partitioners
)
if awq or scale_estimation:
builder = apply_nncf_data_aware_compression(
builder_exported, quantizers, awq, scale_estimation
)
else:
builder = builder_exported.pt2e_quantize(quantizers)

builder = builder.to_edge_transform_and_lower(partitioners)

if verbose:
print_delegation_info(builder.edge_manager.exported_program().graph_module)
Expand Down Expand Up @@ -1172,6 +1193,8 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager: # noqa: C901
modelname,
quantizers,
additional_passes,
awq=llm_config.backend.openvino.awq,
scale_estimation=llm_config.backend.openvino.scale_estimation,
openvino_device=llm_config.backend.openvino.device,
verbose=llm_config.debug.verbose,
)
Expand Down
11 changes: 8 additions & 3 deletions extension/llm/export/config/llm_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,8 +464,9 @@ class OpenvinoConfig:

enabled: bool = False
device: str = "CPU"
nncf_compression: bool = False
nncf_compression_group_size: int = 32
awq: bool = False
scale_estimation: bool = False


@dataclass
Expand Down Expand Up @@ -667,8 +668,12 @@ def from_args(cls, args: argparse.Namespace) -> "LlmConfig": # noqa: C901
llm_config.backend.openvino.enabled = args.openvino
if hasattr(args, "openvino_device"):
llm_config.backend.openvino.device = args.openvino_device
if hasattr(args, "nncf_compression"):
llm_config.backend.openvino.nncf_compression = args.nncf_compression
if hasattr(args, "nncf_awq"):
    # The OpenvinoConfig fields are `awq` / `scale_estimation`; writing to
    # `nncf_awq` would only create a stray attribute that nothing reads.
    llm_config.backend.openvino.awq = args.nncf_awq
if hasattr(args, "nncf_scale_estimation"):
    llm_config.backend.openvino.scale_estimation = args.nncf_scale_estimation
if hasattr(args, "group_size") and args.group_size:
llm_config.backend.openvino.nncf_compression_group_size = args.group_size

Expand Down