
Commit e19a0f9

Update callsite for pt2e quant
Summary: pt2e quantization was removed from pytorch/pytorch in pytorch/pytorch#169151. This PR updates the pt2e quantization callsites to use torchao instead.
1 parent 0434e85 commit e19a0f9
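The change is mechanical: every `torch.ao.quantization` pt2e import moves to its new home under `torchao.quantization.pt2e` (and the XNNPACK quantizer to `executorch`). For downstream code that must run against both layouts, a compatibility shim along these lines (a sketch, not part of this commit) keeps the legacy path as a fallback:

```python
# Sketch only: prefer the new torchao home for the pt2e entry points and
# fall back to the legacy torch.ao location on older PyTorch installs.
try:
    from torchao.quantization.pt2e.quantize_pt2e import (
        convert_pt2e,
        prepare_pt2e,
        prepare_qat_pt2e,
    )
except ImportError:
    from torch.ao.quantization.quantize_pt2e import (
        convert_pt2e,
        prepare_pt2e,
        prepare_qat_pt2e,
    )
```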

File tree: 6 files changed (+40, -23 lines)
coremltools/converters/mil/frontend/torch/test/test_torch_export_quantization.py

Lines changed: 4 additions & 4 deletions

```diff
@@ -29,10 +29,10 @@
 )
 
 from torch.export import export_for_training
-from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e, prepare_qat_pt2e
-from torch.ao.quantization.quantizer.xnnpack_quantizer import (
-    XNNPACKQuantizer,
-    get_symmetric_quantization_config,
+from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e, prepare_qat_pt2e
+from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
+    get_symmetric_quantization_config,
+    XNNPACKQuantizer,
 )
 
 import coremltools as ct
```
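For context, these imports feed the standard pt2e flow. A minimal post-training quantization sketch under the new paths (the toy model and calibration input are placeholders, not from this repo):

```python
import torch
from torch.export import export_for_training
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

# Placeholder model and calibration data for illustration.
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU()).eval()
example_inputs = (torch.randn(1, 8),)

# Capture the training-time graph, insert observers, calibrate, then lower.
exported = export_for_training(model, example_inputs).module()
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
prepared = prepare_pt2e(exported, quantizer)
prepared(*example_inputs)  # calibration pass
quantized = convert_pt2e(prepared)
```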

coremltools/optimize/torch/quantization/_annotation_config.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -8,7 +8,7 @@
 import torch as _torch
 import torch.ao.quantization as _aoquant
 from attr import define as _define
-from torch.ao.quantization.quantizer.quantizer import (
+from torchao.quantization.pt2e.quantizer import (
     QuantizationSpec as _TorchQuantizationSpec,
 )
```

coremltools/optimize/torch/quantization/_coreml_quantizer.py

Lines changed: 4 additions & 4 deletions

```diff
@@ -9,8 +9,8 @@
 from typing import Optional as _Optional
 
 import torch as _torch
-from torch.ao.quantization.quantizer.quantizer import Quantizer as _TorchQuantizer
-from torch.ao.quantization.quantizer.xnnpack_quantizer import _get_module_name_filter
+from torchao.quantization.pt2e.quantizer import Quantizer as _TorchQuantizer
+from torchao.quantization.pt2e.quantizer import _get_module_name_filter
 from torch.fx import Node as _Node
 
 import coremltools.optimize.torch.quantization._coreml_quantizer_utils as _annotation_utils
@@ -519,7 +519,7 @@ class CoreMLQuantizer(_TorchQuantizer):
 """
 Annotates all recognized patterns using ``config``.
 
-Extends py:class:`torch.ao.quantization.quantizer.quantizer.Quantizer` to
+Extends py:class:`torchao.quantization.pt2e.quantizer.Quantizer` to
 add support for quantization patterns supported by Core ML.
 
 Use it in conjunction with PyTorch 2.0 ``prepare_pt2e`` and ``prepare_qat_pt2e`` APIs
@@ -532,7 +532,7 @@ class CoreMLQuantizer(_TorchQuantizer):
 
 import torch.nn as nn
 from torch.export import export_for_training
-from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_qat_pt2e
+from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_qat_pt2e
 
 from coremltools.optimize.torch.quantization._coreml_quantizer import CoreMLQuantizer
```
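The docstring example this hunk touches pairs `CoreMLQuantizer` with the QAT entry points. Below is a condensed sketch under the new import path; the `LinearQuantizerConfig` construction is an assumption based on coremltools' public quantization API, not text from this diff:

```python
import torch
import torch.nn as nn
from torch.export import export_for_training
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_qat_pt2e

from coremltools.optimize.torch.quantization import LinearQuantizerConfig
from coremltools.optimize.torch.quantization._coreml_quantizer import CoreMLQuantizer

# Toy model and input, placeholders for illustration only.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
example_inputs = (torch.randn(1, 3, 16, 16),)

# Assumed: a default LinearQuantizerConfig drives the annotation rules.
quantizer = CoreMLQuantizer(LinearQuantizerConfig())
exported = export_for_training(model, example_inputs).module()
prepared = prepare_qat_pt2e(exported, quantizer)
# ... fine-tune `prepared` here, then lower to quantized ops ...
quantized = convert_pt2e(prepared)
```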

coremltools/optimize/torch/quantization/_coreml_quantizer_utils.py

Lines changed: 29 additions & 12 deletions

```diff
@@ -17,26 +17,22 @@
 if _IS_TORCH_OLDER_THAN_2_4:
     from torch.ao.quantization.pt2e.utils import get_aten_graph_module
 else:
-    from torch.ao.quantization.pt2e.utils import _get_aten_graph_module_for_pattern
+    from torchao.quantization.pt2e.utils import _get_aten_graph_module_for_pattern
 
-from torch.ao.quantization.quantizer.quantizer import (
+from torchao.quantization.pt2e.quantizer import (
     FixedQParamsQuantizationSpec as _FixedQParamsQuantizationSpec,
 )
-from torch.ao.quantization.quantizer.quantizer import (
+from torchao.quantization.pt2e.quantizer import (
     QuantizationAnnotation as _QuantizationAnnotation,
 )
-from torch.ao.quantization.quantizer.quantizer import QuantizationSpec as _TorchQuantizationSpec
-from torch.ao.quantization.quantizer.quantizer import (
+from torchao.quantization.pt2e.quantizer import QuantizationSpec as _TorchQuantizationSpec
+from torchao.quantization.pt2e.quantizer import (
     QuantizationSpecBase as _TorchQuantizationSpecBase,
 )
-from torch.ao.quantization.quantizer.quantizer import (
+from torchao.quantization.pt2e.quantizer import (
     SharedQuantizationSpec as _SharedQuantizationSpec,
 )
-from torch.ao.quantization.quantizer.xnnpack_quantizer import _get_module_name_filter
-from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
-    _is_annotated,
-    _mark_nodes_as_annotated,
-)
+from torchao.quantization.pt2e.quantizer import get_module_name_filter
 from torch.fx import Node as _Node
 from torch.fx.passes.utils.matcher_with_name_node_map_utils import (
     SubgraphMatcherWithNameNodeMap as _SubgraphMatcherWithNameNodeMap,
@@ -49,6 +45,27 @@
     AnnotationConfig as _AnnotationConfig,
 )
 
+def _is_annotated(nodes: list[_Node]):
+    """
+    Given a list of nodes (representing an operator pattern),
+    return True if any of the nodes is annotated,
+    otherwise return False.
+    """
+    for node in nodes:
+        if (
+            "quantization_annotation" in node.meta
+            and node.meta["quantization_annotation"]._annotated
+        ):
+            return True
+    return False
+
+def _mark_nodes_as_annotated(nodes: list[_Node]):
+    for node in nodes:
+        if node is not None:
+            if "quantization_annotation" not in node.meta:
+                node.meta["quantization_annotation"] = _QuantizationAnnotation()
+            node.meta["quantization_annotation"]._annotated = True
+
 # All activations recognized for conv-act/conv-bn-act patterns
 _supported_activations = (
     _F.relu,
@@ -185,7 +202,7 @@ def get_not_object_type_or_name_filter(
     type in ``tp_list``.
     """
     object_type_filters = [get_object_type_filter(tp) for tp in tp_list]
-    module_name_list_filters = [_get_module_name_filter(m) for m in module_name_list]
+    module_name_list_filters = [get_module_name_filter(m) for m in module_name_list]
 
     def not_object_type_or_name_filter(n: _Node) -> bool:
         return not any(f(n) for f in object_type_filters + module_name_list_filters)
```
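The two vendored helpers replace the ones previously imported from `xnnpack_quantizer_utils`. They are typically used as a guard-then-mark pair inside each annotation routine; a sketch of that pattern follows (the `_annotate_pattern` helper and its arguments are hypothetical):

```python
from torch.fx import Node as _Node

def _annotate_pattern(nodes: list[_Node], annotation) -> None:
    """Hypothetical usage sketch: skip patterns that already carry
    annotations, otherwise attach one and mark every node as handled."""
    if _is_annotated(nodes):
        return  # another pass already claimed this pattern
    for node in nodes:
        node.meta["quantization_annotation"] = annotation
    _mark_nodes_as_annotated(nodes)
```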

coremltools/test/optimize/torch/quantization/test_coreml_quantizer.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -20,7 +20,7 @@
 from coremltools._deps import _HAS_TORCH_EXPORT_API
 if _HAS_TORCH_EXPORT_API:
     from torch.export import export_for_training
-    from torch.ao.quantization.quantize_pt2e import (
+    from torchao.quantization.pt2e.quantize_pt2e import (
         convert_pt2e,
         prepare_pt2e,
         prepare_qat_pt2e,
```

reqs/pytorch.pip

Lines changed: 1 addition & 1 deletion

```diff
@@ -13,4 +13,4 @@ torchsr==1.0.4; platform_machine == "arm64"
 # TODO (rdar://141476729) support a more recent timm
 timm==0.6.13; platform_machine == "arm64"
 
-torchao==0.10.0; platform_machine == "arm64" and python_version >= '3.10'
+torchao==0.15.0; platform_machine == "arm64" and python_version >= '3.10'
```
