
FP8 PerRow quantization (CUDA capability>=9.0) #2566

@zzlin-0629

Description


I found the following description in the docs:

A8W8 Float8 Dynamic Quantization with Rowwise Scaling

for torch 2.5+

from torchao.quantization import quantize_, PerRow, Float8DynamicActivationFloat8WeightConfig
quantize_(model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))
Per-row scaling is only supported for bfloat16 weight and activation. This API is only tested on H100. Hardware with CUDA compute capability 8.9 or greater is required.

which says "CUDA compute capability 8.9 or greater is required." But in practice, PerRow() requires CUDA compute capability >= 9.0, as the code shows:

File "/opt/conda/lib/python3.11/site-packages/torchao/quantization/quant_api.py", line 1475, in _normalize_granularity
assert is_sm_at_least_90() or is_MI300(), (
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: PerRow quantization only works for CUDA>=9.0 and MI300+

I am using torchao==0.11.0. Is this a typo in the docs, or is the code wrong?
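For reference, here is a minimal sketch of the capability gate that the assertion in `_normalize_granularity` appears to enforce in torchao 0.11.0 (SM >= 9.0 or MI300+, not SM 8.9). The helper name `per_row_supported` is my own, not a torchao API; on a real machine you would feed it the tuple from `torch.cuda.get_device_capability()`:

```python
def per_row_supported(capability):
    """Return True if the device capability passes the PerRow check.

    Mirrors the torchao 0.11.0 assertion, which requires CUDA compute
    capability >= 9.0 (e.g. H100). The MI300 path is omitted here.
    `capability` is a (major, minor) tuple, e.g. from
    torch.cuda.get_device_capability().
    """
    major, minor = capability
    return (major, minor) >= (9, 0)

# H100 is SM 9.0, so PerRow is allowed:
print(per_row_supported((9, 0)))  # True
# SM 8.9 devices (e.g. L40S / RTX 4090) fail the assertion,
# despite the docs saying "8.9 or greater":
print(per_row_supported((8, 9)))  # False
```

This illustrates why FP8 matmul support (available from SM 8.9) and PerRow scaling support (gated at SM 9.0 in this version) are not the same threshold, which seems to be the source of the confusion.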
