Skip to content

Commit 0db5fc2

Browse files
YIWENX14facebook-github-bot
authored andcommitted
Fix device and dtype discrepancy in _choose_qparams_affine (#2210)
Summary: Pull Request resolved: #2210 Differential Revision: D74446877
1 parent 554cb60 commit 0db5fc2

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

torchao/quantization/quant_primitives.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -948,7 +948,9 @@ def _choose_qparams_affine(
948948
scale = torch.clamp(scale, min=eps)
949949
else:
950950
assert mapping_type == MappingType.ASYMMETRIC.name
951-
scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
951+
scale = (max_val_pos - min_val_neg) / torch.tensor(
952+
[float(quant_max - quant_min)], dtype=input.dtype, device=input.device
953+
)
952954
scale = torch.clamp(scale, min=eps)
953955
if zero_point_domain == ZeroPointDomain.NONE.name:
954956
zero_point = None

0 commit comments

Comments
 (0)