I found the following description:
A8W8 Float8 Dynamic Quantization with Rowwise Scaling
for torch 2.5+
from torchao.quantization import quantize_, PerRow, Float8DynamicActivationFloat8WeightConfig
quantize_(model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))
Per-row scaling is only supported for bfloat16 weight and activation. This API is only tested on H100. Hardware with CUDA compute capability 8.9 or greater is required.
which says "CUDA compute capability 8.9 or greater is required." However, I found that PerRow() actually requires CUDA compute capability >= 9.0, as shown in the code:
File "/opt/conda/lib/python3.11/site-packages/torchao/quantization/quant_api.py", line 1475, in _normalize_granularity
assert is_sm_at_least_90() or is_MI300(), (
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: PerRow quantization only works for CUDA>=9.0 and MI300+
I am using torchao==0.11.0, so is this a typo in the documentation, or is the code wrong?
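For reference, here is a minimal sketch (not an official torchao recipe) of how the granularity could be chosen based on the detected compute capability, so the assertion is avoided on GPUs below SM 9.0. It assumes PerTensor is exported from torchao.quantization alongside PerRow.

import torch
from torchao.quantization import (
    quantize_,
    PerRow,
    PerTensor,  # assumed to be exported next to PerRow
    Float8DynamicActivationFloat8WeightConfig,
)

def apply_float8_dynamic_quant(model: torch.nn.Module) -> None:
    # quant_api.py asserts SM >= 9.0 (or MI300) for PerRow, so fall back
    # to per-tensor scaling on older GPUs such as SM 8.9 (Ada).
    major, minor = torch.cuda.get_device_capability()
    granularity = PerRow() if (major, minor) >= (9, 0) else PerTensor()
    quantize_(model, Float8DynamicActivationFloat8WeightConfig(granularity=granularity))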