1 file changed: +7 -1 lines changed
Original file line number | Diff line number | Diff line change
@@ -873,7 +873,9 @@ def _choose_qparams_affine(
873873 f"Only symmetric quantization is supported for FP8 types, got { mapping_type } "
874874 )
875875
876+ scale_device = None
876877 if input is not None :
878+ scale_device = input .device
877879 if scale_dtype is None :
878880 scale_dtype = input .dtype
879881 if eps is None :
@@ -902,6 +904,8 @@ def _choose_qparams_affine(
902904 if eps is None :
903905 eps = torch .finfo (min_val .dtype ).eps
904906
907+ scale_device = min_val .device
908+
905909 if preserve_zero :
906910 min_val_neg = torch .min (min_val , torch .zeros_like (min_val ))
907911 max_val_pos = torch .max (max_val , torch .zeros_like (max_val ))
@@ -948,7 +952,9 @@ def _choose_qparams_affine(
948952 scale = torch .clamp (scale , min = eps )
949953 else :
950954 assert mapping_type == MappingType .ASYMMETRIC .name
951- scale = (max_val_pos - min_val_neg ) / float (quant_max - quant_min )
955+ scale = (max_val_pos - min_val_neg ) / torch .tensor (
956+ float (quant_max - quant_min ), dtype = scale_dtype , device = scale_device
957+ )
952958 scale = torch .clamp (scale , min = eps )
953959 if zero_point_domain == ZeroPointDomain .NONE .name :
954960 zero_point = None
You can’t perform that action at this time.
0 commit comments