diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index de9a1ef4a7..9d453102cd 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -1270,6 +1270,8 @@ def choose_qparams_affine_with_min_max( if eps is None: eps = torch.finfo(min_val.dtype).eps + scale_device = min_val.device + if preserve_zero: min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) @@ -1316,7 +1318,9 @@ def choose_qparams_affine_with_min_max( scale = torch.clamp(scale, min=eps) else: assert mapping_type == MappingType.ASYMMETRIC - scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min) + scale = (max_val_pos - min_val_neg) / torch.tensor( + float(quant_max - quant_min), dtype=scale_dtype, device=scale_device + ) scale = torch.clamp(scale, min=eps) if zero_point_domain == ZeroPointDomain.NONE: zero_point = None