Add a way to do power of 2 scaling #2256

Open. Wants to merge 1 commit into base branch drisspg/stack/56.
54 changes: 54 additions & 0 deletions test/dtypes/test_affine_quantized_float.py
@@ -48,6 +48,7 @@
quantize_affine_float8,
)
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_8,
is_sm_at_least_89,
is_sm_at_least_90,
)
@@ -355,6 +356,59 @@ def test_mm_float8dq_per_row(
@unittest.skipIf(
not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
)
@common_utils.parametrize(
"granularity", [PerTensor(), PerRow()] if is_sm_at_least_90() else [PerTensor()]
)
@unittest.skipIf(
not TORCH_VERSION_AT_LEAST_2_8, "Requires PyTorch 2.8+ with e8m0 support"
)
def test_fp8_e8m0_scale_dtype(self, granularity):
"""Test float8 quantization with e8m0 scale dtype on PyTorch 2.8+"""
device = "cuda"
dtype = torch.bfloat16
in_features, out_features = 256, 512

# Create model
model = ToyLinearModel(in_features, out_features).to(device).to(dtype)
quant_model = copy.deepcopy(model)

# Create config with e8m0 scale dtype
config = Float8DynamicActivationFloat8WeightConfig(
granularity=granularity, scale_dtype=torch.float8_e8m0fnu
)

# Quantize the model
quantize_(quant_model, config)

# Verify that the scale dtype is correctly set
for layer_name in ["linear1", "linear2"]:
layer = getattr(quant_model, layer_name)
weight_impl = layer.weight.original_weight_tensor.tensor_impl

# Although we specify e8m0, the scale is still cast to fp32
self.assertEqual(weight_impl.scale.dtype, torch.float32)

# Verify scale is power of 2 (requirement for e8m0)
scale_values = weight_impl.scale.float()
log2_scales = torch.log2(scale_values)
self.assertTrue(
torch.allclose(log2_scales, torch.round(log2_scales), atol=0),
"e8m0 scales should be powers of 2",
)

# Test forward pass
input_tensor = torch.randn(32, in_features, device=device, dtype=dtype)

with torch.no_grad():
output = model(input_tensor)
output_quant = quant_model(input_tensor)

# Verify output shape and that computation completes without error
expected_shape = (32, in_features) # ToyLinearModel returns to original size
self.assertEqual(output.shape, expected_shape)
error = compute_error(output, output_quant)
assert error > 20, f"Quantization error is too high; got an SQNR of {error}"

@common_utils.parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
@common_utils.parametrize("output_dtype", [torch.float32, torch.bfloat16])
@common_utils.parametrize("block_size", [None, (1, 32), (2, 16), (4, 8)])
5 changes: 4 additions & 1 deletion torchao/dtypes/affine_quantized_tensor.py
@@ -463,7 +463,10 @@ def from_hp_to_floatx(
original_shape = input_float.shape
input_float = _layout.pre_process(input_float)
scale = choose_qparams_affine_float8(
input_float, float8_dtype=target_dtype, block_size=block_size
input_float,
float8_dtype=target_dtype,
block_size=block_size,
scale_dtype=scale_dtype,
)
data = quantize_affine_float8(input_float, scale, target_dtype)
data, scale, zero_point = _layout.post_process(
7 changes: 5 additions & 2 deletions torchao/quantization/quant_api.py
@@ -1412,7 +1412,7 @@ def _float8_weight_only_quant_tensor(weight, config):
input_float=weight,
block_size=block_size,
target_dtype=config.weight_dtype,
scale_dtype=None,
scale_dtype=torch.float32,
_layout=Float8Layout(mm_config=None),
)
return new_weight
@@ -1519,6 +1519,7 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
only PerTensor and PerRow are supported.
mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.
set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values.
scale_dtype (torch.dtype): Defaults to torch.float32. If a user is on PyTorch 2.8+ and sets it to torch.float8_e8m0fnu, we will ensure power-of-2 scaling.

"""

@@ -1529,6 +1530,7 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
] = None
mm_config: Optional[Float8MMConfig] = None
set_inductor_config: bool = True
scale_dtype: torch.dtype = torch.float32

def __post_init__(self):
if self.mm_config is None:
@@ -1549,6 +1551,7 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
weight_dtype = config.weight_dtype
granularity = config.granularity
mm_config = config.mm_config
scale_dtype = config.scale_dtype

# Ensure works on device
_check_hardware_support(granularity)
@@ -1570,7 +1573,7 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
input_float=weight,
block_size=block_size,
target_dtype=weight_dtype,
scale_dtype=torch.float32,
scale_dtype=scale_dtype,
_layout=Float8Layout(mm_config=mm_config),
)

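A minimal usage sketch of the new `scale_dtype` option, mirroring the test added in this PR. This is hypothetical code, not part of the diff; it assumes these names are importable from `torchao.quantization` and an environment with PyTorch 2.8+ and a GPU of compute capability >= 8.9 (>= 9.0 for PerRow granularity).

```python
import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerRow,
    quantize_,
)

# Any module containing nn.Linear layers works; this toy model is only for illustration.
model = torch.nn.Sequential(torch.nn.Linear(256, 512)).to("cuda", torch.bfloat16)

# Requesting an e8m0 scale dtype snaps the chosen scales to powers of 2;
# per the diff above, they are still stored (and returned) as fp32.
config = Float8DynamicActivationFloat8WeightConfig(
    granularity=PerRow(),
    scale_dtype=torch.float8_e8m0fnu,
)
quantize_(model, config)

out = model(torch.randn(32, 256, device="cuda", dtype=torch.bfloat16))
```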
1 change: 1 addition & 0 deletions torchao/quantization/quant_primitives.py
@@ -2005,6 +2005,7 @@ def choose_qparams_affine_float8(
# Shielding for PyTorch >= 2.8
assert scale_dtype is torch.float8_e8m0fnu, "Only torch.float8_e8m0fnu is supported"
scale = torch.exp2(torch.round(torch.log2(scale)))

return scale.to(dtype=torch.float32)
Contributor Author

IMO this is a really great way to express this; this switch back is the only spooky part.

Contributor

@danielvegamyhre commented May 23, 2025


Hmm, so the API to use power-of-2 scales for inference would be to use float8_e8m0 as the scale dtype, which is all exponent bits and therefore represents only powers of 2, is that right? This is clever, but it does require a level of indirection that may be confusing to users; IMO it would be better to have the API be consistent with training, where it is just a config option, `round_scales_to_powers_of_2`.
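For reference, a minimal standalone sketch (not from this PR) of the rounding that `choose_qparams_affine_float8` applies in the diff above, illustrating why an all-exponent e8m0 scale dtype amounts to power-of-2 scales:

```python
import torch

scale = torch.tensor([0.0007, 0.013, 3.2, 96.0])

# Same operation as in the diff: snap each scale to the nearest power of 2 in log space.
pow2_scale = torch.exp2(torch.round(torch.log2(scale)))
# -> 2**-10, 2**-6, 2**2, 2**7  (0.0009765625, 0.015625, 4.0, 128.0)

log2s = torch.log2(pow2_scale)
assert torch.allclose(log2s, torch.round(log2s)), "every value is an exact power of 2"
```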


