Skip to content

Commit a2b9952

Browse files
authored
Do not raise error when quant primitives are left after partitioner (#10573)
Currently _sanity_check_graph_for_non_decomp_ops raises an Exception if a delegate asks an op for preservation, but doesn't lower it. In general, this is a sensible thing to do, but for quant primitives, it is less sensible. Since what gets lowered are patterns involving the quant primitives and FP32 ops. XNNPACK asks that quant primitives be preserved, and so if a quant primitive is not lowered (e.g., it is part of embedding quant), an error is thrown. In this PR, we: * Define a central location for _QUANT_PRIMITIVES (with TODO task of moving this to torchao) * Use these _QUANT_PRIMITIVES to avoid raising an error in _sanity_check_graph_for_non_decomp_ops * Use _QUANT_PRIMITIVES in tracer.py to no decompose during to_edge and const_prop_pass to not constant propagate (this logic existed previously, but is being rewritten using the central _QUANT_PRIMITIVES list).
1 parent 48ad9f6 commit a2b9952

File tree

8 files changed

+48
-44
lines changed

8 files changed

+48
-44
lines changed

exir/TARGETS

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ python_library(
1515
":types",
1616
"//caffe2:torch",
1717
"//executorch/exir/operator:convert",
18+
"//executorch/exir/operator::util",
1819
"//executorch/extension/pytree:pylib",
19-
"//pytorch/ao:torchao",
2020
],
2121
)
2222

exir/operator/TARGETS

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ python_library(
3232
],
3333
deps = [
3434
"//caffe2/torchgen:torchgen",
35+
"//pytorch/ao:torchao",
36+
"//caffe2:torch",
3537
],
3638
)
3739

exir/operator/util.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
# pyre-strict
88

9+
import torch
10+
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401
911
from torchgen.model import FunctionSchema, SchemaKind
1012
from torchgen.native_function_generation import (
1113
functional_to_out_signature,
@@ -39,3 +41,28 @@ def gen_out_variant_schema(func_op_schema: str) -> str:
3941
raise RuntimeError(f"SchemaKind: {func.kind()} is not supported")
4042

4143
return f"{namespace}::{schema}" if namespace else schema
44+
45+
46+
# TODO: move to torchao
47+
_QUANT_PRIMITIVES = [
48+
torch.ops.quantized_decomposed.dequantize_per_channel.default,
49+
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
50+
torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
51+
torch.ops.quantized_decomposed.convert_element_type.no_fuse,
52+
torch.ops.quantized_decomposed.quantize_per_tensor.default,
53+
torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
54+
torch.ops.quantized_decomposed.quantize_per_channel.default,
55+
torch.ops.quantized_decomposed.choose_qparams.tensor,
56+
]
57+
try:
58+
import torchao # noqa: F401
59+
60+
_QUANT_PRIMITIVES.extend(
61+
[
62+
torch.ops.torchao.dequantize_affine.default,
63+
torch.ops.torchao.quantize_affine.default,
64+
torch.ops.torchao.choose_qparams_affine.default,
65+
]
66+
)
67+
except ImportError:
68+
pass

exir/passes/TARGETS

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,8 @@ python_library(
119119
"//caffe2:torch",
120120
"//executorch/exir/dialects:lib",
121121
"//executorch/exir/dialects/edge:lib",
122+
"//executorch/exir/operator::util",
123+
"//executorch/exir/passes:replace_aten_with_edge_pass",
122124
],
123125
)
124126

exir/passes/constant_prop_pass.py

Lines changed: 4 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
import torch
1414
from executorch.exir.dialects._ops import ops as exir_ops
1515
from executorch.exir.dialects.edge._ops import EdgeOpOverload
16+
from executorch.exir.operator.util import _QUANT_PRIMITIVES
17+
from executorch.exir.passes.replace_aten_with_edge_pass import aten_to_edge
1618
from torch._export.utils import (
1719
get_buffer,
1820
get_lifted_tensor_constant,
@@ -25,35 +27,13 @@
2527
from torch.export.exported_program import InputKind, InputSpec, TensorArgument
2628
from torch.utils import _pytree as pytree
2729

28-
2930
# Avoid propagating constants for `exir.ops.edge.aten.full.default`.
3031
# Propagating aten.full can significantly increase compiled model size.
3132
_DEFAULT_SKIP_TARGETS = {exir_ops.edge.aten.full.default}
3233

3334
# Do not const prop quantization primitives
34-
_QDQ_OPS = [
35-
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
36-
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
37-
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
38-
exir_ops.edge.quantized_decomposed.convert_element_type.no_fuse,
39-
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
40-
exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
41-
exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
42-
exir_ops.edge.quantized_decomposed.choose_qparams.tensor,
43-
]
44-
try:
45-
import torchao # noqa: F401
46-
47-
_QDQ_OPS.extend(
48-
[
49-
exir_ops.edge.torchao.dequantize_affine.default,
50-
exir_ops.edge.torchao.quantize_affine.default,
51-
exir_ops.edge.torchao.choose_qparams_affine.default,
52-
]
53-
)
54-
except ImportError:
55-
pass
56-
_DEFAULT_SKIP_TARGETS.update(set(_QDQ_OPS))
35+
_QUANT_PRIMITIVES_EDGE = [aten_to_edge(op) for op in _QUANT_PRIMITIVES]
36+
_DEFAULT_SKIP_TARGETS.update(set(_QUANT_PRIMITIVES_EDGE))
5737

5838

5939
_PRIMITIVE_TYPES = (

exir/program/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ python_library(
3535
"//executorch/exir/capture:config",
3636
"//executorch/exir/emit:emit",
3737
"//executorch/exir/emit:lib",
38+
"//executorch/exir/operator:util",
3839
"//executorch/exir/passes:insert_write_back_for_buffers_pass",
3940
"//executorch/exir/passes:lib",
4041
"//executorch/exir/passes:normalize_view_copy_base_pass",

exir/program/_program.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from executorch.exir.error import ExportError
3333
from executorch.exir.graph_module import get_control_flow_submodules
3434
from executorch.exir.operator.convert import _pybind_schema_to_native_schema
35+
from executorch.exir.operator.util import _QUANT_PRIMITIVES
3536
from executorch.exir.pass_base import PassBase
3637
from executorch.exir.pass_manager import PassType
3738
from executorch.exir.passes import (
@@ -971,10 +972,14 @@ def _sanity_check_graph_for_non_decomp_ops(
971972
ops_set_to_not_decompose = {
972973
aten_to_edge(op) for op in ops_set_to_not_decompose
973974
}.union(ops_set_to_not_decompose)
975+
976+
quant_primitives = {aten_to_edge(op) for op in _QUANT_PRIMITIVES}
974977
for node in program.graph_module.graph.nodes:
975978
is_op_supported = check_op_support(node) if check_op_support else True
976979
if (
977-
node.op == "call_function" and node.target in ops_set_to_not_decompose
980+
node.op == "call_function"
981+
and node.target in ops_set_to_not_decompose
982+
and node.target not in quant_primitives
978983
) and is_op_supported:
979984
warning_str = (
980985
f"Node {node} with op {node.target} was not decomposed or delegated.\n"
@@ -988,7 +993,9 @@ def _sanity_check_graph_for_non_decomp_ops(
988993
for node in submod.graph.nodes:
989994
is_op_supported = check_op_support(node) if check_op_support else True
990995
if (
991-
node.op == "call_function" and node.target in ops_set_to_not_decompose
996+
node.op == "call_function"
997+
and node.target in ops_set_to_not_decompose
998+
and node.target not in quant_primitives
992999
) and is_op_supported:
9931000
warning_str = (
9941001
f"Node {node} with op {node.target} was not decomposed or delegated.\n"

exir/tracer.py

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
from executorch.exir.error import ExportError, ExportErrorType, InternalError
4242
from executorch.exir.graph_module import LeafValue
4343
from executorch.exir.operator.convert import is_out_variant
44+
from executorch.exir.operator.util import _QUANT_PRIMITIVES
4445
from executorch.exir.types import ValueSpec
4546

4647
from torch._C import _EnableTorchFunction, DisableTorchFunctionSubclass # @manual
@@ -54,7 +55,6 @@
5455

5556
from typing_extensions import TypeAlias
5657

57-
5858
Value: TypeAlias = Union[
5959
LeafValue,
6060
Tuple["Value", ...],
@@ -643,22 +643,7 @@ def _default_decomposition_table(
643643
# pyre-fixme[7]: Expected `Dict[OpOverload, typing.Callable[..., executorch.exir....
644644

645645
never_decompose = []
646-
try:
647-
# Do not decompose torchao quant primitives
648-
# They have decompositions registered for inductor/CUDA, but in ExecuTorch we
649-
# just pattern match them and lower to delegates
650-
import torchao # noqa: F401
651-
652-
never_decompose.extend(
653-
[
654-
torch.ops.torchao.quantize_affine.default,
655-
torch.ops.torchao.dequantize_affine.default,
656-
torch.ops.torchao.choose_qparams_affine.default,
657-
]
658-
)
659-
except:
660-
pass
661-
646+
never_decompose.extend(_QUANT_PRIMITIVES)
662647
for op in never_decompose:
663648
decomps.pop(op, None)
664649
return decomps # pyre-fixme[7]

0 commit comments

Comments
 (0)