Skip to content

Commit efac465

Browse files
authored
Add backward compatible types to pt2e prepare
Differential Revision: D75248288 Pull Request resolved: #2244
1 parent a776b1f commit efac465

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

torchao/quantization/pt2e/quantize_pt2e.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
if TORCH_VERSION_AT_LEAST_2_7:
1212
from .constant_fold import constant_fold
1313

14+
from typing import Union
15+
1416
from torch.fx import GraphModule, Node
1517
from torch.fx.passes.infra.pass_manager import PassManager
1618

@@ -39,7 +41,7 @@
3941

4042
def prepare_pt2e(
4143
model: GraphModule,
42-
quantizer: Quantizer,
44+
quantizer: Union[Quantizer, torch.ao.quantization.quantizer.quantizer.Quantizer],
4345
) -> GraphModule:
4446
"""Prepare a model for post training quantization
4547
@@ -127,7 +129,7 @@ def calibrate(model, data_loader):
127129

128130
def prepare_qat_pt2e(
129131
model: GraphModule,
130-
quantizer: Quantizer,
132+
quantizer: Union[Quantizer, torch.ao.quantization.quantizer.quantizer.Quantizer],
131133
) -> GraphModule:
132134
"""Prepare a model for quantization aware training
133135

0 commit comments

Comments
 (0)