
Commit 6af4a4c

Pulkit Agrawal (pulkital) authored and committed
Add support for PyTorch Export Quantizer
1 parent b416f36 commit 6af4a4c

5 files changed: +1765 -1 lines changed

Lines changed: 46 additions & 1 deletion
@@ -1,12 +1,57 @@
-# Copyright (c) 2023, Apple Inc. All rights reserved.
+# Copyright (c) 2024, Apple Inc. All rights reserved.
 #
 # Use of this source code is governed by a BSD-3-clause license that can be
 # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause

+import logging as _logging
+from collections import OrderedDict as _OrderedDict
 from typing import Any as _Any

+_logger = _logging.getLogger(__name__)
+

 def get_str(val: _Any):
     if isinstance(val, float):
         return f"{val:.5f}"
     return str(val)
+
+
+class RegistryMixin:
+    REGISTRY = None
+
+    @classmethod
+    def register(cls, name: str):
+        if cls.REGISTRY is None:
+            cls.REGISTRY = _OrderedDict()
+
+        def inner_wrapper(wrapped_obj):
+            if name in cls.REGISTRY:
+                _logger.warning(
+                    f"Name: {name} is already registered with object: {cls.REGISTRY[name].__name__} "
+                    f"in registry: {cls.__name__}. "
+                    f"Over-writing the name with new class: {wrapped_obj.__name__}"
+                )
+            cls.REGISTRY[name] = wrapped_obj
+            return wrapped_obj
+
+        return inner_wrapper
+
+    @classmethod
+    def _get_object(cls, name: str):
+        if name in cls.REGISTRY:
+            return cls.REGISTRY[name]
+        raise NotImplementedError(
+            f"No object is registered with name: {name} in registry {cls.__name__}."
+        )
+
+
+class ClassRegistryMixin(RegistryMixin):
+    @classmethod
+    def get_class(cls, name: str):
+        return cls._get_object(name)
+
+
+class FunctionRegistryMixin(RegistryMixin):
+    @classmethod
+    def get_function(cls, name: str):
+        return cls._get_object(name)
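
For reference, a minimal usage sketch of the registry mixins added above (not part of the commit). ActivationRegistry and ReLUActivation are hypothetical names used only to illustrate the register-then-lookup flow, assuming ClassRegistryMixin is importable from the module patched above.

# Hypothetical example: subclass ClassRegistryMixin to get a named registry,
# register classes under string keys, and retrieve them later by name.
class ActivationRegistry(ClassRegistryMixin):
    pass


@ActivationRegistry.register("relu")
class ReLUActivation:
    pass


# get_class returns the registered class; an unknown name raises
# NotImplementedError via RegistryMixin._get_object.
assert ActivationRegistry.get_class("relu") is ReLUActivation

Registering a second object under an existing name is allowed but logs a warning and overwrites the earlier entry, per the register implementation above.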
Lines changed: 109 additions & 0 deletions
@@ -0,0 +1,109 @@
+# Copyright (c) 2024, Apple Inc. All rights reserved.
+#
+# Use of this source code is governed by a BSD-3-clause license that can be
+# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
+
+from typing import Optional as _Optional
+
+import torch as _torch
+import torch.ao.quantization as _aoquant
+from attr import define as _define
+from torch.ao.quantization.quantizer.quantizer import (
+    QuantizationSpec as _TorchQuantizationSpec,
+)
+
+from coremltools.optimize.torch.quantization.quantization_config import (
+    ModuleLinearQuantizerConfig as _ModuleLinearQuantizerConfig,
+)
+from coremltools.optimize.torch.quantization.quantization_config import ObserverType as _ObserverType
+from coremltools.optimize.torch.quantization.quantization_config import (
+    QuantizationScheme as _QuantizationScheme,
+)
+
+
+@_define
+class AnnotationConfig:
+    """
+    Module/Operator level configuration class for :py:class:`CoreMLQuantizer`.
+
+    For each module/operator, defines the dtype, quantization scheme and observer type
+    for input(s), output and weights (if any).
+    """
+
+    input_activation: _Optional[_TorchQuantizationSpec] = None
+    output_activation: _Optional[_TorchQuantizationSpec] = None
+    weight: _Optional[_TorchQuantizationSpec] = None
+
+    @staticmethod
+    def _normalize_dtype(dtype: _torch.dtype) -> _torch.dtype:
+        """
+        PyTorch export quantizer only supports uint8 and int8 data types,
+        so we map the quantized dtypes to the corresponding supported dtype.
+        """
+        dtype_map = {
+            _torch.quint8: _torch.uint8,
+            _torch.qint8: _torch.int8,
+        }
+        return dtype_map.get(dtype, dtype)
+
+    @classmethod
+    def from_quantization_config(
+        cls,
+        quantization_config: _Optional[_ModuleLinearQuantizerConfig],
+    ) -> _Optional["AnnotationConfig"]:
+        """
+        Creates a :py:class:`AnnotationConfig` from ``ModuleLinearQuantizerConfig``
+        """
+        if (
+            quantization_config is None
+            or quantization_config.weight_dtype == _torch.float32
+        ):
+            return None
+
+        # Activation QSpec
+        if quantization_config.activation_dtype == _torch.float32:
+            output_activation_qspec = None
+        else:
+            activation_qscheme = _QuantizationScheme.get_qscheme(
+                quantization_config.quantization_scheme,
+                is_per_channel=False,
+            )
+            activation_dtype = cls._normalize_dtype(
+                quantization_config.activation_dtype
+            )
+            output_activation_qspec = _TorchQuantizationSpec(
+                observer_or_fake_quant_ctr=_aoquant.FakeQuantize.with_args(
+                    observer=_ObserverType.get_observer(
+                        quantization_config.activation_observer,
+                        is_per_channel=False,
+                    ),
+                    dtype=activation_dtype,
+                    qscheme=activation_qscheme,
+                ),
+                dtype=activation_dtype,
+                qscheme=activation_qscheme,
+            )
+
+        # Weight QSpec
+        weight_qscheme = _QuantizationScheme.get_qscheme(
+            quantization_config.quantization_scheme,
+            is_per_channel=quantization_config.weight_per_channel,
+        )
+        weight_dtype = cls._normalize_dtype(quantization_config.weight_dtype)
+        weight_qspec = _TorchQuantizationSpec(
+            observer_or_fake_quant_ctr=_aoquant.FakeQuantize.with_args(
+                observer=_ObserverType.get_observer(
+                    quantization_config.weight_observer,
+                    is_per_channel=quantization_config.weight_per_channel,
+                ),
+                dtype=weight_dtype,
+                qscheme=weight_qscheme,
+            ),
+            dtype=weight_dtype,
+            qscheme=weight_qscheme,
+        )
+        return AnnotationConfig(
+            input_activation=output_activation_qspec,
+            output_activation=output_activation_qspec,
+            weight=weight_qspec,
+        )
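
To illustrate how the new helper is meant to be driven, here is a hedged sketch of turning a ModuleLinearQuantizerConfig into an AnnotationConfig. It assumes the default ModuleLinearQuantizerConfig constructor (quantized weight and activation dtypes) and that AnnotationConfig is importable from the new module added in this commit; the dtypes noted in the comments are indicative only.

from coremltools.optimize.torch.quantization.quantization_config import (
    ModuleLinearQuantizerConfig,
)

# Assumption: the default config quantizes both weights and activations,
# so from_quantization_config returns a populated AnnotationConfig.
config = ModuleLinearQuantizerConfig()
annotation = AnnotationConfig.from_quantization_config(config)

if annotation is not None:
    # The same spec object is reused for input and output activations,
    # as in from_quantization_config above.
    print(annotation.weight.dtype)             # e.g. torch.int8 after _normalize_dtype
    print(annotation.output_activation.dtype)  # e.g. torch.uint8

# A config with float32 weights, or no config at all, yields no annotation.
assert AnnotationConfig.from_quantization_config(None) is None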
