File tree Expand file tree Collapse file tree 1 file changed +4
-2
lines changed
torchao/quantization/pt2e Expand file tree Collapse file tree 1 file changed +4
-2
lines changed Original file line number Diff line number Diff line change 11
11
if TORCH_VERSION_AT_LEAST_2_7 :
12
12
from .constant_fold import constant_fold
13
13
14
+ from typing import Union
15
+
14
16
from torch .fx import GraphModule , Node
15
17
from torch .fx .passes .infra .pass_manager import PassManager
16
18
39
41
40
42
def prepare_pt2e (
41
43
model : GraphModule ,
42
- quantizer : Quantizer ,
44
+ quantizer : Union [ Quantizer , torch . ao . quantization . quantizer . quantizer . Quantizer ] ,
43
45
) -> GraphModule :
44
46
"""Prepare a model for post training quantization
45
47
@@ -127,7 +129,7 @@ def calibrate(model, data_loader):
127
129
128
130
def prepare_qat_pt2e (
129
131
model : GraphModule ,
130
- quantizer : Quantizer ,
132
+ quantizer : Union [ Quantizer , torch . ao . quantization . quantizer . quantizer . Quantizer ] ,
131
133
) -> GraphModule :
132
134
"""Prepare a model for quantization aware training
133
135
You can’t perform that action at this time.
0 commit comments