Commit 01e9845

Authored by pulkital (Pulkit Agrawal) and yifan_shen3
Add CoreMLQuantizer in coremltools.optimize.torch to support PyTorch Export based quantization (#2162)
* Add support for PyTorch Export Quantizer
* add quantization conversion test

---------

Co-authored-by: Pulkit Agrawal <pulkital_agrawal@apple.com>
Co-authored-by: yifan_shen3 <yifan_shen3@apple.com>
1 parent b416f36 commit 01e9845
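
For context, the new `CoreMLQuantizer` plugs into PyTorch's PT2E (export-based) quantization flow. Below is a minimal sketch of that flow, with the config values copied from the test added in this commit; the toy `model` and `example_inputs` are placeholders, not an official recipe:

```python
import torch
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

from coremltools.optimize.torch.quantization._coreml_quantizer import CoreMLQuantizer
from coremltools.optimize.torch.quantization.quantization_config import (
    LinearQuantizerConfig,
    QuantizationScheme,
)

# `model` is any torch.nn.Module; `example_inputs` is a tuple of tensors.
model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())
example_inputs = (torch.randn(1, 3, 32, 32),)

# Capture the pre-autograd ATen graph (torch >= 2.2.0 per this commit).
graph = capture_pre_autograd_graph(model, example_inputs)

# Configure the new CoreMLQuantizer; values mirror the added test.
config = LinearQuantizerConfig.from_dict(
    {
        "global_config": {
            "quantization_scheme": QuantizationScheme.symmetric,
            "milestones": [0, 0, 10, 10],
            "activation_dtype": torch.quint8,
            "weight_dtype": torch.qint8,
            "weight_per_channel": True,
        }
    }
)
quantizer = CoreMLQuantizer(config)

# Post-training quantization: insert observers, calibrate, then convert.
prepared = prepare_pt2e(graph, quantizer)
prepared(*example_inputs)  # calibration pass(es)
quantized_graph = convert_pt2e(prepared)
```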

File tree: 6 files changed, +1880 −12 lines


coremltools/converters/mil/frontend/torch/test/test_executorch_quantization.py

Lines changed: 105 additions & 11 deletions
@@ -3,27 +3,84 @@
 # Use of this source code is governed by a BSD-3-clause license that can be
 # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
 
+import itertools
 import pytest
+from typing import Tuple
+
+from coremltools._deps import _HAS_EXECUTORCH
+
+if not _HAS_EXECUTORCH:
+    pytest.skip(allow_module_level=True, reason="executorch is required")
 
 import torch
 from torch._export import capture_pre_autograd_graph
-from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
+from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e, prepare_qat_pt2e
 from torch.ao.quantization.quantizer.xnnpack_quantizer import (
     get_symmetric_quantization_config,
     XNNPACKQuantizer,
 )
 
-import coremltools as ct
-from coremltools._deps import _HAS_EXECUTORCH
+_TORCH_VERSION = torch.__version__
+_EXPECTED_TORCH_VERSION = "2.2.0"
+if _TORCH_VERSION < _EXPECTED_TORCH_VERSION:
+    pytest.skip(allow_module_level=True, reason=f"PyTorch {_EXPECTED_TORCH_VERSION} or higher is required")
 
-if not _HAS_EXECUTORCH:
-    pytest.skip(allow_module_level=True, reason="executorch is required")
+import coremltools as ct
+from coremltools.optimize.torch.quantization.quantization_config import (
+    LinearQuantizerConfig,
+    QuantizationScheme,
+)
+from coremltools.optimize.torch.quantization._coreml_quantizer import CoreMLQuantizer
 
 from .testing_utils import TorchBaseTest, TorchFrontend
 
 
 class TestExecutorchQuantization(TorchBaseTest):
-    def test_conv_relu(self):
+    @staticmethod
+    def make_torch_quantized_graph(
+        model,
+        example_inputs: Tuple[torch.Tensor],
+        quantizer_name: str,
+        quantization_type: str,
+    ) -> torch.fx.GraphModule:
+        assert quantizer_name in {"XNNPack", "CoreML"}
+        assert quantization_type in {"PTQ", "QAT"}
+
+        pre_autograd_aten_dialect = capture_pre_autograd_graph(model, example_inputs)
+
+        if quantizer_name == "XNNPack":
+            quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
+        elif quantizer_name == "CoreML":
+            quantization_config = LinearQuantizerConfig.from_dict(
+                {
+                    "global_config": {
+                        "quantization_scheme": QuantizationScheme.symmetric,
+                        "milestones": [0, 0, 10, 10],
+                        "activation_dtype": torch.quint8,
+                        "weight_dtype": torch.qint8,
+                        "weight_per_channel": True,
+                    }
+                }
+            )
+            quantizer = CoreMLQuantizer(quantization_config)
+
+        if quantization_type == "PTQ":
+            prepared_graph = prepare_pt2e(pre_autograd_aten_dialect, quantizer)
+        elif quantization_type == "QAT":
+            prepared_graph = prepare_qat_pt2e(pre_autograd_aten_dialect, quantizer)
+
+        prepared_graph(*example_inputs)
+        converted_graph = convert_pt2e(prepared_graph)
+        return converted_graph
+
+    @pytest.mark.parametrize(
+        "quantizer_name, quantization_type",
+        itertools.product(
+            ("XNNPack", "CoreML"),
+            ("PTQ", "QAT")
+        )
+    )
+    def test_conv_relu(self, quantizer_name, quantization_type):
         SHAPE = (1, 3, 256, 256)
 
         class Model(torch.nn.Module):
@@ -40,12 +97,49 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
         model = Model()
 
-        example_args = (torch.randn(SHAPE),)
-        pre_autograd_aten_dialect = capture_pre_autograd_graph(model, example_args)
+        example_inputs = (torch.randn(SHAPE),)
+        converted_graph = self.make_torch_quantized_graph(
+            model,
+            example_inputs,
+            quantizer_name,
+            quantization_type,
+        )
 
-        quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
-        prepared_graph = prepare_pt2e(pre_autograd_aten_dialect, quantizer)
-        converted_graph = convert_pt2e(prepared_graph)
+        self.run_compare_torch(
+            SHAPE,
+            converted_graph,
+            frontend=TorchFrontend.EXIR,
+            backend=("mlprogram", "fp16"),
+            minimum_deployment_target=ct.target.iOS17,
+        )
+
+    @pytest.mark.parametrize(
+        "quantizer_name, quantization_type",
+        itertools.product(
+            ("XNNPack", "CoreML"),
+            ("PTQ", "QAT")
+        )
+    )
+    def test_linear(self, quantizer_name, quantization_type):
+        SHAPE = (1, 5)
+
+        class Model(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.linear = torch.nn.Linear(5, 10)
+
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
+                return self.linear(x)
+
+        model = Model()
+
+        example_inputs = (torch.randn(SHAPE),)
+        converted_graph = self.make_torch_quantized_graph(
+            model,
+            example_inputs,
+            quantizer_name,
+            quantization_type,
+        )
 
         self.run_compare_torch(
             SHAPE,
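
The new helper also exercises the QAT path. A sketch of the only difference from the PTQ flow shown earlier, reusing the `graph`, `quantizer`, and `example_inputs` names from that sketch:

```python
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_qat_pt2e

# QAT: insert fake-quantize modules, (normally) fine-tune, then convert.
prepared = prepare_qat_pt2e(graph, quantizer)
prepared(*example_inputs)  # the test runs a forward pass in place of training
quantized_graph = convert_pt2e(prepared)
```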
Lines changed: 46 additions & 1 deletion
@@ -1,12 +1,57 @@
-# Copyright (c) 2023, Apple Inc. All rights reserved.
+# Copyright (c) 2024, Apple Inc. All rights reserved.
 #
 # Use of this source code is governed by a BSD-3-clause license that can be
 # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
 
+import logging as _logging
+from collections import OrderedDict as _OrderedDict
 from typing import Any as _Any
 
+_logger = _logging.getLogger(__name__)
+
 
 def get_str(val: _Any):
     if isinstance(val, float):
         return f"{val:.5f}"
     return str(val)
+
+
+class RegistryMixin:
+    REGISTRY = None
+
+    @classmethod
+    def register(cls, name: str):
+        if cls.REGISTRY is None:
+            cls.REGISTRY = _OrderedDict()
+
+        def inner_wrapper(wrapped_obj):
+            if name in cls.REGISTRY:
+                _logger.warning(
+                    f"Name: {name} is already registered with object: {cls.REGISTRY[name].__name__} "
+                    f"in registry: {cls.__name__}"
+                    f"Over-writing the name with new class: {wrapped_obj.__name__}"
+                )
+            cls.REGISTRY[name] = wrapped_obj
+            return wrapped_obj
+
+        return inner_wrapper
+
+    @classmethod
+    def _get_object(cls, name: str):
+        if name in cls.REGISTRY:
+            return cls.REGISTRY[name]
+        raise NotImplementedError(
+            f"No object is registered with name: {name} in registry {cls.__name__}."
+        )
+
+
+class ClassRegistryMixin(RegistryMixin):
+    @classmethod
+    def get_class(cls, name: str):
+        return cls._get_object(name)
+
+
+class FunctionRegistryMixin(RegistryMixin):
+    @classmethod
+    def get_function(cls, name: str):
+        return cls._get_object(name)
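
To illustrate the intended use of these mixins, here is a minimal sketch; `ExampleRegistry` and `MyQuantizer` are hypothetical names, and `ClassRegistryMixin` is assumed to be in scope from this module:

```python
# Hypothetical example of the registry pattern added above.

class ExampleRegistry(ClassRegistryMixin):
    pass

# The decorator registers the class under the given name and returns it.
@ExampleRegistry.register("my_quantizer")
class MyQuantizer:
    pass

assert ExampleRegistry.get_class("my_quantizer") is MyQuantizer
# ExampleRegistry.get_class("unknown") raises NotImplementedError.
# Registering "my_quantizer" again logs a warning and overwrites the entry.
```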
Lines changed: 109 additions & 0 deletions
@@ -0,0 +1,109 @@
+# Copyright (c) 2024, Apple Inc. All rights reserved.
+#
+# Use of this source code is governed by a BSD-3-clause license that can be
+# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
+
+from typing import Optional as _Optional
+
+import torch as _torch
+import torch.ao.quantization as _aoquant
+from attr import define as _define
+from torch.ao.quantization.quantizer.quantizer import (
+    QuantizationSpec as _TorchQuantizationSpec,
+)
+
+from coremltools.optimize.torch.quantization.quantization_config import (
+    ModuleLinearQuantizerConfig as _ModuleLinearQuantizerConfig,
+)
+from coremltools.optimize.torch.quantization.quantization_config import ObserverType as _ObserverType
+from coremltools.optimize.torch.quantization.quantization_config import (
+    QuantizationScheme as _QuantizationScheme,
+)
+
+
+@_define
+class AnnotationConfig:
+    """
+    Module/Operator level configuration class for :py:class:`CoreMLQuantizer`.
+
+    For each module/operator, defines the dtype, quantization scheme and observer type
+    for input(s), output and weights (if any).
+    """
+
+    input_activation: _Optional[_TorchQuantizationSpec] = None
+    output_activation: _Optional[_TorchQuantizationSpec] = None
+    weight: _Optional[_TorchQuantizationSpec] = None
+
+    @staticmethod
+    def _normalize_dtype(dtype: _torch.dtype) -> _torch.dtype:
+        """
+        PyTorch export quantizer only supports uint8 and int8 data types,
+        so we map the quantized dtypes to the corresponding supported dtype.
+        """
+        dtype_map = {
+            _torch.quint8: _torch.uint8,
+            _torch.qint8: _torch.int8,
+        }
+        return dtype_map.get(dtype, dtype)
+
+    @classmethod
+    def from_quantization_config(
+        cls,
+        quantization_config: _Optional[_ModuleLinearQuantizerConfig],
+    ) -> _Optional["AnnotationConfig"]:
+        """
+        Creates a :py:class:`AnnotationConfig` from ``ModuleLinearQuantizerConfig``
+        """
+        if (
+            quantization_config is None
+            or quantization_config.weight_dtype == _torch.float32
+        ):
+            return None
+
+        # Activation QSpec
+        if quantization_config.activation_dtype == _torch.float32:
+            output_activation_qspec = None
+        else:
+            activation_qscheme = _QuantizationScheme.get_qscheme(
+                quantization_config.quantization_scheme,
+                is_per_channel=False,
+            )
+            activation_dtype = cls._normalize_dtype(
+                quantization_config.activation_dtype
+            )
+            output_activation_qspec = _TorchQuantizationSpec(
+                observer_or_fake_quant_ctr=_aoquant.FakeQuantize.with_args(
+                    observer=_ObserverType.get_observer(
+                        quantization_config.activation_observer,
+                        is_per_channel=False,
+                    ),
+                    dtype=activation_dtype,
+                    qscheme=activation_qscheme,
+                ),
+                dtype=activation_dtype,
+                qscheme=activation_qscheme,
+            )
+
+        # Weight QSpec
+        weight_qscheme = _QuantizationScheme.get_qscheme(
+            quantization_config.quantization_scheme,
+            is_per_channel=quantization_config.weight_per_channel,
+        )
+        weight_dtype = cls._normalize_dtype(quantization_config.weight_dtype)
+        weight_qspec = _TorchQuantizationSpec(
+            observer_or_fake_quant_ctr=_aoquant.FakeQuantize.with_args(
+                observer=_ObserverType.get_observer(
+                    quantization_config.weight_observer,
+                    is_per_channel=quantization_config.weight_per_channel,
+                ),
+                dtype=weight_dtype,
+                qscheme=weight_qscheme,
+            ),
+            dtype=weight_dtype,
+            qscheme=weight_qscheme,
+        )
+        return AnnotationConfig(
+            input_activation=output_activation_qspec,
+            output_activation=output_activation_qspec,
+            weight=weight_qspec,
+        )
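
A quick sketch of how this file turns an `optimize.torch` config into PT2E quantization specs. The field values mirror the new test; the keyword constructor of `ModuleLinearQuantizerConfig` is assumed (it is an attrs class with these fields), and `AnnotationConfig` refers to the class defined above:

```python
import torch

from coremltools.optimize.torch.quantization.quantization_config import (
    ModuleLinearQuantizerConfig,
    QuantizationScheme,
)

# Constructor fields are assumed from how the config is read above;
# unspecified observers fall back to their defaults.
config = ModuleLinearQuantizerConfig(
    quantization_scheme=QuantizationScheme.symmetric,
    activation_dtype=torch.quint8,
    weight_dtype=torch.qint8,
    weight_per_channel=True,
)

annotation = AnnotationConfig.from_quantization_config(config)

# quint8/qint8 are normalized to uint8/int8 for the export quantizer.
assert annotation.weight.dtype == torch.int8
assert annotation.output_activation.dtype == torch.uint8
# A float32 weight_dtype would instead yield None (no annotation).
```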
