Commit 5c64718

Add a way to do power of 2 scaling
stack-info: PR: #2256, branch: drisspg/stack/57
1 parent d0b71cc commit 5c64718

4 files changed: +64 -3 lines changed
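For context, here is a minimal usage sketch of the new option (not part of the commit; the import paths and toy model are assumptions, and it needs PyTorch 2.8+ for torch.float8_e8m0fnu plus a float8-capable GPU):

import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerTensor,
    quantize_,
)

# Any nn.Linear-based model works the same way; this toy module is only illustrative.
model = torch.nn.Sequential(torch.nn.Linear(256, 512)).cuda().to(torch.bfloat16)

# Requesting scale_dtype=torch.float8_e8m0fnu rounds the computed float8 scales
# to powers of 2 (see the quant_primitives.py change below); the stored scale
# tensor itself stays float32.
config = Float8DynamicActivationFloat8WeightConfig(
    granularity=PerTensor(), scale_dtype=torch.float8_e8m0fnu
)
quantize_(model, config)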

test/dtypes/test_affine_quantized_float.py

Lines changed: 54 additions & 0 deletions
@@ -48,6 +48,7 @@
     quantize_affine_float8,
 )
 from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_8,
     is_sm_at_least_89,
     is_sm_at_least_90,
 )

@@ -355,6 +356,59 @@ def test_mm_float8dq_per_row(
     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
+    @common_utils.parametrize(
+        "granularity", [PerTensor(), PerRow()] if is_sm_at_least_90() else [PerTensor()]
+    )
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_8, "Requires PyTorch 2.8+ with e8m0 support"
+    )
+    def test_fp8_e8m0_scale_dtype(self, granularity):
+        """Test float8 quantization with e8m0 scale dtype on PyTorch 2.8+"""
+        device = "cuda"
+        dtype = torch.bfloat16
+        in_features, out_features = 256, 512
+
+        # Create model
+        model = ToyLinearModel(in_features, out_features).to(device).to(dtype)
+        quant_model = copy.deepcopy(model)
+
+        # Create config with e8m0 scale dtype
+        config = Float8DynamicActivationFloat8WeightConfig(
+            granularity=granularity, scale_dtype=torch.float8_e8m0fnu
+        )
+
+        # Quantize the model
+        quantize_(quant_model, config)
+
+        # Verify that the scale dtype is correctly set
+        for layer_name in ["linear1", "linear2"]:
+            layer = getattr(quant_model, layer_name)
+            weight_impl = layer.weight.original_weight_tensor.tensor_impl
+
+            # Although we specify e8m0, the scale is still cast to fp32
+            self.assertEqual(weight_impl.scale.dtype, torch.float32)
+
+            # Verify the scale is a power of 2 (requirement for e8m0)
+            scale_values = weight_impl.scale.float()
+            log2_scales = torch.log2(scale_values)
+            self.assertTrue(
+                torch.allclose(log2_scales, torch.round(log2_scales), atol=0),
+                "e8m0 scales should be powers of 2",
+            )
+
+        # Test forward pass
+        input_tensor = torch.randn(32, in_features, device=device, dtype=dtype)
+
+        with torch.no_grad():
+            output = model(input_tensor)
+            output_quant = quant_model(input_tensor)
+
+        # Verify output shape and that computation completes without error
+        expected_shape = (32, in_features)  # ToyLinearModel returns to original size
+        self.assertEqual(output.shape, expected_shape)
+        error = compute_error(output, output_quant)
+        assert error > 20, f"Quantization error is too high, got an SQNR of {error}"
+
     @common_utils.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
     @common_utils.parametrize("output_dtype", [torch.float32, torch.bfloat16])
     @common_utils.parametrize("block_size", [None, (1, 32), (2, 16), (4, 8)])

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 4 additions & 1 deletion
@@ -463,7 +463,10 @@ def from_hp_to_floatx(
         original_shape = input_float.shape
         input_float = _layout.pre_process(input_float)
         scale = choose_qparams_affine_float8(
-            input_float, float8_dtype=target_dtype, block_size=block_size
+            input_float,
+            float8_dtype=target_dtype,
+            block_size=block_size,
+            scale_dtype=scale_dtype,
         )
         data = quantize_affine_float8(input_float, scale, target_dtype)
         data, scale, zero_point = _layout.post_process(

torchao/quantization/quant_api.py

Lines changed: 5 additions & 2 deletions
@@ -1412,7 +1412,7 @@ def _float8_weight_only_quant_tensor(weight, config):
14121412
input_float=weight,
14131413
block_size=block_size,
14141414
target_dtype=config.weight_dtype,
1415-
scale_dtype=None,
1415+
scale_dtype=torch.float32,
14161416
_layout=Float8Layout(mm_config=None),
14171417
)
14181418
return new_weight
@@ -1519,6 +1519,7 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
15191519
only PerTensor and PerRow are supported.
15201520
mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.
15211521
set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values.
1522+
scale_dtype: By default we set to fp32, if a user is on 12.8 and sets it to e8m0 we well ensure power of 2 scaling
15221523
15231524
"""
15241525

@@ -1529,6 +1530,7 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
15291530
] = None
15301531
mm_config: Optional[Float8MMConfig] = None
15311532
set_inductor_config: bool = True
1533+
scale_dtype: torch.dtype = torch.float32
15321534

15331535
def __post_init__(self):
15341536
if self.mm_config is None:
@@ -1549,6 +1551,7 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
15491551
weight_dtype = config.weight_dtype
15501552
granularity = config.granularity
15511553
mm_config = config.mm_config
1554+
scale_dtype = config.scale_dtype
15521555

15531556
# Ensure works on device
15541557
_check_hardware_support(granularity)
@@ -1570,7 +1573,7 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
15701573
input_float=weight,
15711574
block_size=block_size,
15721575
target_dtype=weight_dtype,
1573-
scale_dtype=torch.float32,
1576+
scale_dtype=scale_dtype,
15741577
_layout=Float8Layout(mm_config=mm_config),
15751578
)
15761579

torchao/quantization/quant_primitives.py

Lines changed: 1 addition & 0 deletions
@@ -2005,6 +2005,7 @@ def choose_qparams_affine_float8(
         # Shielding for Version > 2.8
         assert scale_dtype is torch.float8_e8m0fnu, "Only float8_e8m0fnu is supported"
         scale = torch.exp2(torch.round(torch.log2(scale)))
+
     return scale.to(dtype=torch.float32)
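To make the new rounding concrete, here is a small standalone sketch in plain PyTorch (an illustration only, not the torchao helper itself):

import torch

# Arbitrary positive scales, of the kind choose_qparams_affine_float8 would produce.
scale = torch.tensor([0.0007, 0.013, 1.7])

# Round each scale to the nearest power of 2, mirroring
# scale = torch.exp2(torch.round(torch.log2(scale))) in the diff above.
pow2_scale = torch.exp2(torch.round(torch.log2(scale)))
print(pow2_scale)  # tensor([0.0010, 0.0156, 2.0000]), i.e. 2**-10, 2**-6, 2**1

# log2 of every rounded scale is now an integer, which is exactly what the
# new test_fp8_e8m0_scale_dtype test asserts.
log2_scales = torch.log2(pow2_scale)
assert torch.allclose(log2_scales, torch.round(log2_scales), atol=0)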