
Commit 92a01f5

Fix Per Row scaling for inference
stack-info: PR: #2253, branch: drisspg/stack/56
1 parent a776b1f commit 92a01f5

File tree

6 files changed: +306 −162 lines

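Before the per-file diffs, a minimal sketch (not the torchao implementation) of what "Per Row" float8 scaling means for a linear weight: each output row gets its own scale, so an (out_features, in_features) weight carries a scale of shape (out_features, 1). The `quantize_fp8_per_row` helper and its amax/fp8-max scale convention below are illustrative assumptions; the (out_features, 1) scale shape is the property the updated test asserts.

```python
import torch

def quantize_fp8_per_row(weight: torch.Tensor, float8_dtype=torch.float8_e4m3fn):
    """Illustrative per-row float8 quantization: one scale per output row."""
    fp8_max = torch.finfo(float8_dtype).max
    # One amax per output row; keepdim so the scale broadcasts over in_features.
    amax = weight.abs().amax(dim=1, keepdim=True).clamp(min=1e-12)
    scale = amax / fp8_max                                  # shape (out_features, 1)
    weight_fp8 = (weight / scale).clamp(-fp8_max, fp8_max).to(float8_dtype)
    return weight_fp8, scale

w = torch.randn(1024, 512)
w_fp8, w_scale = quantize_fp8_per_row(w)
assert w_scale.shape == (1024, 1)
# Dequantize to sanity-check the round trip.
w_dq = w_fp8.to(torch.float32) * w_scale
```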

test/dtypes/test_affine_quantized_float.py

Lines changed: 47 additions & 13 deletions
@@ -297,21 +297,55 @@ def test_fp8_weight_dimension_warning(self):
     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
-    def test_mm_float8dq(self):
+    @common_utils.parametrize(
+        "in_features,out_features", [(512, 1024), (256, 768), (1024, 512)]
+    )
+    @common_utils.parametrize(
+        "leading_shape", [(1,), (8,), (16,), (2, 8,), (2, 2, 16,)]
+    )  # fmt: skip
+    @common_utils.parametrize("bias", [True, False])
+    def test_mm_float8dq_per_row(
+        self, in_features, out_features, leading_shape, bias: bool
+    ):
         device = "cuda"
         dtype = torch.bfloat16
-        weight = torch.randn(512, 1024).to(device).to(dtype)
-        weight = weight.t()
-
-        l = torch.nn.Linear(512, 1024).to(device).to(dtype)
-        l.weight = torch.nn.Parameter(weight)
-        quantize_(l, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))
-        # weight shape: 1024 x 512
-        weight = l.weight
-
-        input = torch.randn(1, 512, device=device, dtype=dtype)
-        # make sure it runs
-        torch.nn.functional.linear(input, weight)
+        input_shape = leading_shape + (in_features,)
+
+        ref_linear = (
+            torch.nn.Linear(in_features, out_features, bias=bias).to(device).to(dtype)
+        )
+        test_linear = copy.deepcopy(ref_linear)
+        quantize_(
+            test_linear, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
+        )
+
+        quant_weight = test_linear.weight
+
+        self.assertTrue(hasattr(quant_weight, "original_weight_tensor"))
+        weight_impl = quant_weight.original_weight_tensor.tensor_impl
+
+        self.assertTrue(hasattr(weight_impl, "float8_data"))
+        self.assertTrue(hasattr(weight_impl, "scale"))
+        self.assertFalse(weight_impl.transposed)
+
+        # Verify scale shape for row-wise quantization
+        expected_scale_shape = (out_features, 1)
+        actual_scale_shape = weight_impl.scale.shape
+        self.assertEqual(actual_scale_shape, expected_scale_shape)
+
+        self.assertEqual(weight_impl.float8_data.shape, (out_features, in_features))
+
+        input_tensor = torch.randn(*input_shape, device=device, dtype=dtype)
+
+        with torch.no_grad():
+            ref_output = ref_linear(input_tensor)
+            quant_output = torch.nn.functional.linear(input_tensor, quant_weight)
+
+        expected_output_shape = input_tensor.shape[:-1] + (out_features,)
+        self.assertEqual(quant_output.shape, expected_output_shape)
+
+        error = compute_error(ref_output, quant_output)
+        assert error > 20, f"Quantization error is too high got a SQNR of {error}"
 
 
 common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)
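For reference, a hedged usage sketch mirroring the new test: quantize a `Linear` with per-row float8 dynamic quantization and push a 3D-batched input through it. It assumes an FP8-capable GPU (compute capability 8.9+, matching the `skipIf` above); the import paths follow current torchao conventions but may vary between versions.

```python
import copy
import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerRow,
    quantize_,
)

linear = torch.nn.Linear(256, 768, bias=True).to("cuda", torch.bfloat16)
ref = copy.deepcopy(linear)
# Swap the weight for a per-row float8 quantized tensor in place.
quantize_(linear, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))

x = torch.randn(2, 8, 256, device="cuda", dtype=torch.bfloat16)
with torch.no_grad():
    out = linear(x)      # goes through the float8 rowwise path
    ref_out = ref(x)     # bf16 reference for comparison
print(out.shape)         # torch.Size([2, 8, 768])
```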

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 3 additions & 4 deletions
@@ -462,10 +462,10 @@ def from_hp_to_floatx(
         if target_dtype in FP8_TYPES:
             original_shape = input_float.shape
             input_float = _layout.pre_process(input_float)
-
-            scale = choose_qparams_affine_float8(input_float, float8_dtype=target_dtype)
+            scale = choose_qparams_affine_float8(
+                input_float, float8_dtype=target_dtype, block_size=block_size
+            )
             data = quantize_affine_float8(input_float, scale, target_dtype)
-
             data, scale, zero_point = _layout.post_process(
                 data, scale, None, block_size
             )
@@ -503,7 +503,6 @@ def from_hp_to_floatx_static(
                 input_float,
                 scale,
                 target_dtype,
-                scale_dtype,
             )
 
             data, scale, zero_point = _layout.post_process(
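The change above threads `block_size` into the float8 qparam computation, so the scale granularity follows the block size. A rough sketch of that relationship using plain torch and a hypothetical `block_amax_scale` helper (not the torchao API): a block equal to the full tensor yields one scale, while a (1, in_features) block yields one scale per output row.

```python
import torch

def block_amax_scale(x: torch.Tensor, block_size, float8_dtype=torch.float8_e4m3fn):
    """Illustrative: one scale per block, reduced as an amax over each block (2D only)."""
    fp8_max = torch.finfo(float8_dtype).max
    scale_shape = []
    for dim, block in zip(x.shape, block_size):
        assert dim % block == 0
        scale_shape.append(dim // block)
    # Split each dimension into (num_blocks, block) and reduce the block axes.
    x_blocked = x.reshape(scale_shape[0], block_size[0], scale_shape[1], block_size[1])
    amax = x_blocked.abs().amax(dim=(1, 3)).clamp(min=1e-12)
    return amax / fp8_max

w = torch.randn(1024, 512)
per_tensor = block_amax_scale(w, block_size=(1024, 512))  # shape (1, 1)
per_row = block_amax_scale(w, block_size=(1, 512))        # shape (1024, 1), the per-row case
```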

torchao/dtypes/floatx/float8_layout.py

Lines changed: 122 additions & 85 deletions
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 from torch.utils._python_dispatch import (
@@ -26,14 +26,25 @@
 from torchao.utils import _is_float8_type, fill_defaults
 
 aten = torch.ops.aten
+FLOAT8_IMPL_OPS_TABLE: Dict[Any, Any] = {}
+
+
+def implements(aten_ops: List[Any]):
+    """Register aten ops to the float8 op table"""
+
+    def decorator(func):
+        for op in aten_ops:
+            FLOAT8_IMPL_OPS_TABLE[op] = func
+        return func
+
+    return decorator
 
 
 def _same_metadata(self: "Float8AQTTensorImpl", src: "Float8AQTTensorImpl") -> bool:
     # Special handling for transposed attribute
     transposed_match = (self.transposed == src.transposed) or (
         self.transposed is False and src.transposed is None
     )
-
     return (
         isinstance(self, Float8AQTTensorImpl)
         and isinstance(src, Float8AQTTensorImpl)
@@ -160,90 +171,23 @@ def __tensor_unflatten__(
     def __torch_dispatch__(cls, func, types, args, kwargs):
         kwargs = {} if kwargs is None else kwargs
 
-        if func is aten.detach.default:
-            return return_and_correct_aliasing(
-                func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
-            )
-        elif func is aten.clone.default:
-            return return_and_correct_aliasing(
-                func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
-            )
-        elif func is aten.t.default:
-            """we don't need to repack the weight and just rely on external
-            shape being changed and record the status of transpose/no-transpose
-            """
-            args[0].transposed = not args[0].transposed
-            return return_and_correct_aliasing(func, args, kwargs, args[0])
-        elif func is aten.copy_.default:
-            self = args[0]
-            src = args[1]
-            if _same_metadata(self, src):
-                self_tensors = self.__tensor_flatten__()[0]
-                for tensor_name in self_tensors:
-                    getattr(self, tensor_name).copy_(getattr(src, tensor_name))
-                return
-            raise ValueError(
-                f"Not supported args for copy_ due to metadata mistach: {args[0], args[1]}"
-            )
-        elif func in [aten.select.int, aten.index.Tensor]:
-            return return_and_correct_aliasing(
-                func,
-                args,
-                kwargs,
-                args[0]._apply_fn_to_data(lambda x: func(x, *args[1:], **kwargs)),
-            )
-        elif func is aten.slice.Tensor:
-            self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
-            if dim == 0:
-                # TODO: scale replecation should be dependent on block size
-                if self.scale.ndim == 1:
-                    return return_and_correct_aliasing(
-                        func,
-                        args,
-                        kwargs,
-                        args[0]._apply_fn_to_data(
-                            lambda x: aten.slice.Tensor(x, dim, start, end, step)
-                        ),
-                    )
-                elif self.scale.ndim == 0:
-                    return return_and_correct_aliasing(
-                        func,
-                        args,
-                        kwargs,
-                        Float8AQTTensorImpl(
-                            aten.slice.Tensor(self.float8_data, dim, start, end, step),
-                            self.scale,
-                            None,
-                            self._layout,
-                        ),
-                    )
-                else:
-                    raise NotImplementedError(
-                        f"Float8AQTTensorImpl dispatch: attempting to run {func}, with scale ndim={dim}, that is not supported"
-                    )
-            elif dim == 1:
-                return return_and_correct_aliasing(
-                    func,
-                    args,
-                    kwargs,
-                    Float8AQTTensorImpl(
-                        aten.slice.Tensor(
-                            self.float8_data, dim, start, end, step
-                        ).contiguous(),
-                        self.scale,
-                        None,
-                        self._layout,
-                    ),
-                )
-            else:
-                raise NotImplementedError(
-                    f"Float8AQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported"
+        def allowed_subclasses(type):
+            return (
+                issubclass(cls, type)
+                or issubclass(torch._subclasses.fake_tensor.FakeTensor, type)
+                or issubclass(
+                    torch._subclasses.functional_tensor.FunctionalTensor, type
                 )
-        else:
-            raise NotImplementedError(
-                f"Float8AQTTensorImpl dispatch: attempting to run {func}, this is not supported"
             )
 
+        if not all(allowed_subclasses(t) for t in types):
+            return NotImplemented
+
+        if func in FLOAT8_IMPL_OPS_TABLE:
+            return FLOAT8_IMPL_OPS_TABLE[func](func, types, args, kwargs)
+
+        raise NotImplementedError(f"attempting to run {func}, this is not supported")
+
     __torch_function__ = torch._C._disabled_torch_function_impl
 
     def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
@@ -281,6 +225,100 @@ def __repr__(self):
         )
 
 
+##########################
+# Regsiter FP8 Ops
+##########################
+
+
+@implements([aten.detach.default, aten.alias.default, aten.clone.default])
+def _(func, types, args, kwargs):
+    return return_and_correct_aliasing(
+        func, args, kwargs, args[0]._apply_fn_to_data(func)
+    )
+
+
+@implements([aten.t.default])
+def _(func, types, args, kwargs):
+    """we don't need to repack the weight and just rely on external
+    shape being changed and record the status of transpose/no-transpose
+    """
+    args[0].transposed = not args[0].transposed
+    return return_and_correct_aliasing(func, args, kwargs, args[0])
+
+
+@implements([aten.copy_.default])
+def _(func, types, args, kwargs):
+    self = args[0]
+    src = args[1]
+    if _same_metadata(self, src):
+        self_tensors = self.__tensor_flatten__()[0]
+        for tensor_name in self_tensors:
+            getattr(self, tensor_name).copy_(getattr(src, tensor_name))
+        return
+    raise ValueError(
+        f"Not supported args for copy_ due to metadata mismatch: {args[0], args[1]}"
+    )
+
+
+@implements([aten.select.int, aten.index.Tensor])
+def _(func, types, args, kwargs):
+    return return_and_correct_aliasing(
+        func,
+        args,
+        kwargs,
+        args[0]._apply_fn_to_data(lambda x: func(x, *args[1:], **kwargs)),
+    )
+
+
+@implements([aten.slice.Tensor])
+def _(func, types, args, kwargs):
+    self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
+    if dim == 0:
+        if self.scale.numel() == 1:
+            # Per Tensor
+            return return_and_correct_aliasing(
+                func,
+                args,
+                kwargs,
+                Float8AQTTensorImpl(
+                    aten.slice.Tensor(self.float8_data, dim, start, end, step),
+                    self.scale,
+                    self.transposed,
+                    self._layout,
+                ),
+            )
+        elif self.scale.ndim == 2:
+            # TODO: scale replication should be dependent on block size
+            return return_and_correct_aliasing(
+                func,
+                args,
+                kwargs,
+                args[0]._apply_fn_to_data(
+                    lambda x: aten.slice.Tensor(x, dim, start, end, step)
+                ),
+            )
+        else:
+            raise NotImplementedError(
+                f"Float8AQTTensorImpl dispatch: attempting to run {func}, with scale ndim={dim}, that is not supported"
+            )
+    elif dim == 1:
+        return return_and_correct_aliasing(
+            func,
+            args,
+            kwargs,
+            Float8AQTTensorImpl(
+                aten.slice.Tensor(self.float8_data, dim, start, end, step).contiguous(),
+                self.scale,
+                self.transposed,
+                self._layout,
+            ),
+        )
+    else:
+        raise NotImplementedError(
+            f"Float8AQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported"
+        )
+
+
 ##########################
 # Float8 Dispatch Kernels
 ##########################
@@ -333,13 +371,12 @@ def _linear_fp8_act_fp8_weight_impl(
     input_scale = input_tensor.tensor_impl.scale
     # Handle case where input tensor is more than 2D
     inpt_data = inpt_data.reshape(-1, inpt_data.shape[-1])
-
     # Handle rowwise case
     if _is_rowwise_scaled(weight_tensor):
         assert _is_rowwise_scaled(input_tensor), (
            "Input tensor must be rowwise block size"
        )
-        w_scale = w_scale.unsqueeze(-1).T
+        w_scale = w_scale.T
        input_scale = preprocess_scale(input_scale, input_tensor.shape)
 
     # Preprocess data
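The final hunk is the inference fix named in the commit title: the per-row weight scale is now a 2D (out_features, 1) tensor, so a plain `w_scale.T` already gives the (1, out_features) layout the rowwise scaled-mm path wants; the previous `w_scale.unsqueeze(-1).T` appears to have been written for a 1D scale. A plain-torch emulation of the rowwise scaling math (illustrative only; the real path dispatches to a fused scaled-mm kernel):

```python
import torch

# Shapes only; these tensors stand in for the fp8 payloads and their per-row scales.
M, K, N = 16, 512, 1024
x_data = torch.randn(M, K)
w_data = torch.randn(N, K)
x_scale = torch.rand(M, 1) + 0.5   # per-row activation scale, shape (M, 1)
w_scale = torch.rand(N, 1) + 0.5   # per-row weight scale, shape (N, 1)

# Reference: dequantize first, then matmul.
out_ref = (x_data * x_scale) @ (w_data * w_scale).T

# Factored form applied after the raw GEMM: the weight scale must broadcast along
# the output columns, i.e. as w_scale.T with shape (1, N).
out = (x_data @ w_data.T) * x_scale * w_scale.T
assert torch.allclose(out, out_ref, rtol=1e-4, atol=1e-4)
```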
