
Commit a07c9e2

Add a way to do power of 2 scaling
stack-info: PR: #2256, branch: drisspg/stack/57
1 parent b10c3cc commit a07c9e2

File tree: 4 files changed (+73 / -4 lines)

test/dtypes/test_affine_quantized_float.py
torchao/dtypes/affine_quantized_tensor.py
torchao/quantization/quant_api.py
torchao/quantization/quant_primitives.py

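At a glance, the commit threads a new scale_dtype option through the float8 quantization path: requesting torch.float8_e8m0fnu makes torchao round each scale to the nearest power of two while still storing it as fp32. A minimal usage sketch, assuming the usual torchao.quantization exports and a CUDA setup (the toy model below is a stand-in, not part of this commit):

import torch
import torch.nn as nn
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerRow,
    quantize_,
)

model = nn.Sequential(nn.Linear(256, 512)).cuda().to(torch.bfloat16)
config = Float8DynamicActivationFloat8WeightConfig(
    granularity=PerRow(),
    scale_dtype=torch.float8_e8m0fnu,  # request power-of-2 scales (PyTorch 2.8+)
)
quantize_(model, config)
# Scales are still stored as fp32, but their values are exact powers of two.

Per the test parametrization below, PerRow granularity is only exercised on SM 9.0+ GPUs, while PerTensor() is covered from SM 8.9.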

test/dtypes/test_affine_quantized_float.py

Lines changed: 58 additions & 0 deletions
@@ -44,6 +44,7 @@
     choose_qparams_affine,
 )
 from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_8,
     is_sm_at_least_89,
     is_sm_at_least_90,
 )
@@ -347,6 +348,63 @@ def test_mm_float8dq_per_row(
         error = compute_error(ref_output, quant_output)
         assert error > 20, f"Quantization error is too high got a SQNR of {error}"

+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
+    @common_utils.parametrize(
+        "granularity", [PerTensor(), PerRow()] if is_sm_at_least_90() else [PerTensor()]
+    )
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_8, "Requires PyTorch 2.8+ with e8m0 support"
+    )
+    def test_fp8_e8m0_scale_dtype(self, granularity):
+        """Test float8 quantization with e8m0 scale dtype on PyTorch 2.8+"""
+        device = "cuda"
+        dtype = torch.bfloat16
+        in_features, out_features = 256, 512
+
+        # Create model
+        model = ToyLinearModel(in_features, out_features).to(device).to(dtype)
+        quant_model = copy.deepcopy(model)
+
+        # Create config with e8m0 scale dtype
+        config = Float8DynamicActivationFloat8WeightConfig(
+            granularity=granularity, scale_dtype=torch.float8_e8m0fnu
+        )
+
+        # Quantize the model
+        quantize_(quant_model, config)
+
+        # Verify that the scale dtype is correctly set
+        for layer_name in ["linear1", "linear2"]:
+            layer = getattr(quant_model, layer_name)
+            weight_impl = layer.weight.original_weight_tensor.tensor_impl
+
+            # Although we specify e8m0, the scale is still cast to fp32
+            self.assertEqual(weight_impl.scale.dtype, torch.float32)
+
+            # Verify the scale is a power of 2 (requirement for e8m0)
+            scale_values = weight_impl.scale.float()
+            log2_scales = torch.log2(scale_values)
+            self.assertTrue(
+                torch.allclose(log2_scales, torch.round(log2_scales), atol=0),
+                "e8m0 scales should be powers of 2",
+            )
+
+        # Test forward pass
+        input_tensor = torch.randn(32, in_features, device=device, dtype=dtype)
+
+        with torch.no_grad():
+            output = model(input_tensor)
+            output_quant = quant_model(input_tensor)
+
+        # Verify output shape and that computation completes without error
+        expected_shape = (32, in_features)  # ToyLinearModel returns to original size
+        self.assertEqual(output.shape, expected_shape)
+        error = compute_error(output, output_quant)
+        assert error > 20, f"Quantization error is too high got a SQNR of {error}"
+

 common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)
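
The new test can presumably be run in isolation with pytest (the file and test names come from this diff; the exact invocation is an assumption):

pytest test/dtypes/test_affine_quantized_float.py -k test_fp8_e8m0_scale_dtype -v

Per its decorators it is skipped without CUDA, without an SM 8.9+ GPU, or on PyTorch older than 2.8.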

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 4 additions & 1 deletion
@@ -463,7 +463,10 @@ def from_hp_to_floatx(
         original_shape = input_float.shape
         input_float = _layout.pre_process(input_float)
         scale = choose_qparams_affine_float8(
-            input_float, float8_dtype=target_dtype, block_size=block_size
+            input_float,
+            float8_dtype=target_dtype,
+            block_size=block_size,
+            scale_dtype=scale_dtype,
         )
         data = quantize_affine_float8(input_float, scale, target_dtype)
         data, scale, zero_point = _layout.post_process(
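
To see the effect of the newly threaded scale_dtype without building a model, one can call choose_qparams_affine_float8 directly. A hedged sketch: the import path mirrors torchao/quantization/quant_primitives.py (the last file in this diff) where the function is defined, while the fp8 target dtype and the use of the full tensor shape as a per-tensor block_size are assumptions, not part of this commit:

import torch
from torchao.quantization.quant_primitives import choose_qparams_affine_float8

w = torch.randn(256, 512, dtype=torch.bfloat16)
scale = choose_qparams_affine_float8(
    w,
    float8_dtype=torch.float8_e4m3fn,  # assumed fp8 target dtype
    block_size=tuple(w.shape),         # assumption: full shape == per-tensor scaling
    scale_dtype=torch.float8_e8m0fnu,  # request power-of-2 scales (PyTorch 2.8+)
)
# As in the quant_primitives.py change below, the scale comes back as fp32
# with its values snapped to exact powers of two.
assert scale.dtype == torch.float32
assert torch.allclose(torch.log2(scale), torch.round(torch.log2(scale)), atol=0)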

torchao/quantization/quant_api.py

Lines changed: 5 additions & 2 deletions
@@ -1412,7 +1412,7 @@ def _float8_weight_only_quant_tensor(weight, config):
         input_float=weight,
         block_size=block_size,
         target_dtype=config.weight_dtype,
-        scale_dtype=None,
+        scale_dtype=torch.float32,
         _layout=Float8Layout(mm_config=None),
     )
     return new_weight
@@ -1519,6 +1519,7 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
             only PerTensor and PerRow are supported.
         mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.
         set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values.
+        scale_dtype (torch.dtype): Defaults to fp32. If a user is on PyTorch 2.8+ and sets it to e8m0, we will ensure power-of-2 scaling.

     """
@@ -1529,6 +1530,7 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
     ] = None
     mm_config: Optional[Float8MMConfig] = None
     set_inductor_config: bool = True
+    scale_dtype: torch.dtype = torch.float32

     def __post_init__(self):
         if self.mm_config is None:
@@ -1549,6 +1551,7 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
     weight_dtype = config.weight_dtype
     granularity = config.granularity
     mm_config = config.mm_config
+    scale_dtype = config.scale_dtype

     # Ensure works on device
     _check_hardware_support(granularity)
@@ -1570,7 +1573,7 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
         input_float=weight,
         block_size=block_size,
         target_dtype=weight_dtype,
-        scale_dtype=torch.float32,
+        scale_dtype=scale_dtype,
         _layout=Float8Layout(mm_config=mm_config),
     )

torchao/quantization/quant_primitives.py

Lines changed: 6 additions & 1 deletion
@@ -2001,7 +2001,12 @@ def choose_qparams_affine_float8(
     ]
     scale = scale.reshape(output_shape)

-    return scale.to(dtype=scale_dtype)
+    if scale_dtype is not torch.float32:
+        # Shielding for PyTorch >= 2.8 (e8m0 support)
+        assert scale_dtype is torch.float8_e8m0fnu, "Only float8_e8m0fnu is supported"
+        scale = torch.exp2(torch.round(torch.log2(scale)))
+
+    return scale.to(dtype=torch.float32)


 def quantize_affine_float8(
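
To make the exp2(round(log2(...))) step above concrete, here is a small illustrative sketch; the numbers are made up, not taken from the commit. e8m0 has an 8-bit exponent and no mantissa bits, so it can only represent powers of two, which is why the scale is snapped this way even though it is ultimately stored as fp32:

import torch

scales = torch.tensor([0.75, 1.3, 3.0, 0.02])
pow2_scales = torch.exp2(torch.round(torch.log2(scales)))
# 0.75 -> 2**round(-0.415) = 2**0  = 1.0
# 1.3  -> 2**round( 0.379) = 2**0  = 1.0
# 3.0  -> 2**round( 1.585) = 2**2  = 4.0
# 0.02 -> 2**round(-5.644) = 2**-6 = 0.015625
print(pow2_scales)  # tensor([1.0000, 1.0000, 4.0000, 0.0156])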