
Commit c4ad425

Fix QAT range learning, ensure scales get gradients
**Summary:** The previous `_GenericFakeQuantize` nulled all gradients except the ones for the input. This is problematic for range learning, because scales and zero points are now `nn.Parameter`s and actually require gradients. This commit fixes this by reducing the scope of the `autograd.Function` to `torch.round` only, so QAT can call the fake quantization primitives directly.

Note: part of the dequantize math previously cast the inputs and the zero points to int32. Autograd does not work with integer math, and that part of the code path is now visible to autograd, so this commit also removes the dtype cast.

Note: this change means we no longer use the cachemask variant, so our numerics no longer match those of pytorch/pytorch's fake quantization ops.

**Test Plan:** Updated the following test to check that scales and weights are updated:

python test/quantization/test_qat.py -k test_qat_range_learning
1 parent dd43f16 commit c4ad425
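To make the fix concrete before diving into the diff: the idea is to keep the custom `autograd.Function` limited to rounding (the `_Round` STE added below in `quant_primitives.py`) and keep the rest of the fake-quant math in floating point, so a learnable scale stays on the autograd tape. The sketch below is minimal and self-contained; the `fake_quant` helper and the shapes are illustrative only, not torchao APIs.

```python
# Minimal standalone sketch of the straight-through-estimator (STE) round this commit
# moves into quant_primitives as `_Round`. The `fake_quant` helper below is a toy,
# not a torchao API.
import torch


class _Round(torch.autograd.Function):
    """Round with a straight-through gradient: forward rounds, backward is identity."""

    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        return torch.round(x)

    @staticmethod
    def backward(ctx, gy: torch.Tensor) -> torch.Tensor:
        return gy


def fake_quant(x: torch.Tensor, scale: torch.Tensor, qmin: int, qmax: int) -> torch.Tensor:
    # Keep all math in float so autograd can reach `scale`; no integer dtype casts.
    q = torch.clamp(_Round.apply(x / scale), qmin, qmax)
    return q * scale


x = torch.randn(8, 16)
scale = torch.nn.Parameter(torch.full((8, 1), 0.1))
fake_quant(x, scale, -8, 7).sum().backward()
# The learnable scale now gets a (generally nonzero) gradient, which is what
# range learning needs.
assert scale.grad is not None and torch.count_nonzero(scale.grad) > 0
```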

File tree

5 files changed: +65 −135 lines changed


test/quantization/test_qat.py

Lines changed: 20 additions & 43 deletions
@@ -46,7 +46,6 @@
 from torchao.quantization.qat.utils import (
     _fake_quantize_per_channel_group,
     _fake_quantize_per_token,
-    _GenericFakeQuantize,
     _get_qmin_qmax,
 )
 from torchao.quantization.quant_api import (
@@ -582,42 +581,6 @@ def test_qat_8da4w_quantizer_gradients(self):
         quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=16)
         self._test_qat_quantized_gradients(quantizer)

-    @unittest.skipIf(
-        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
-    )
-    def test_qat_generic_fake_quantize(self):
-        """
-        Test that the generic fake quantize used in 8da4w QAT matches
-        the numerics of existing fake quantize ops in Pytorch in both
-        the forward and the backward passes.
-        """
-        (qmin, qmax) = _get_qmin_qmax(4)
-        py_input = torch.randn(16, 64).float().requires_grad_()
-        py_s = torch.randn(16).float()
-        py_zp = torch.randint(qmax, size=(16,), dtype=torch.int32)
-        py_out = torch.fake_quantize_per_channel_affine(
-            py_input, py_s, py_zp, 0, qmin, qmax
-        )
-        py_out.sum().backward()
-
-        ao_input = copy.deepcopy(py_input)
-        ao_input.grad.data.zero_()
-        block_size = (1, ao_input.shape[-1])
-        ao_s = copy.deepcopy(py_s)
-        ao_zp = copy.deepcopy(py_zp)
-        ao_out = _GenericFakeQuantize.apply(
-            ao_input, block_size, ao_s, ao_zp, qmin, qmax
-        )
-        ao_out.sum().backward()
-
-        torch.testing.assert_close(py_out, ao_out, atol=0, rtol=0)
-
-        # Test that gradients are close enough
-        num_grads = py_input.grad.numel()
-        num_equal_grads = torch.eq(py_input.grad, ao_input.grad).flatten().sum().item()
-        num_equal_grad_threshold = 0.8
-        self.assertGreaterEqual(num_equal_grads / num_grads, num_equal_grad_threshold)
-
     def _assert_close_4w(self, val, ref):
         # Note: for int4 weight-only quantization, we do not expect exact match
         # because torch._weight_int4pack_mm and torch.mm do not match exactly.
@@ -1697,16 +1660,30 @@ def test_qat_range_learning(self):
         m(*example_inputs)

         # Simulate training
+        num_steps = 10
         optimizer = torch.optim.SGD(
             m.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5
         )
         loss_fn = torch.nn.CrossEntropyLoss()
-        target = torch.randn(1, 512).float()
-        out = m(*example_inputs)
-        loss = loss_fn(out, target)
-        optimizer.zero_grad()
-        loss.backward()
-        optimizer.step()
+        for i in range(num_steps):
+            prev_scale = copy.deepcopy(m.linear1.weight_fake_quantizer.scale)
+            prev_weight = copy.deepcopy(m.linear1.weight)
+            optimizer.zero_grad()
+            target = torch.randn(1, 512).float()
+            out = m(*example_inputs)
+            loss = loss_fn(out, target)
+            loss.backward()
+            optimizer.step()
+            # Assert that scales have valid gradients and are being updated
+            new_scale = m.linear1.weight_fake_quantizer.scale
+            self.assertIsNotNone(new_scale.grad)
+            self.assertNotEqual(torch.count_nonzero(new_scale.grad), 0)
+            self.assertFalse(torch.equal(new_scale, prev_scale))
+            # Assert that weights have valid gradients and are being updated
+            new_weight = m.linear1.weight
+            self.assertIsNotNone(new_weight.grad)
+            self.assertNotEqual(torch.count_nonzero(new_weight.grad), 0)
+            self.assertFalse(torch.equal(new_weight, prev_weight))


 if __name__ == "__main__":

torchao/quantization/qat/affine_fake_quantized_tensor.py

Lines changed: 6 additions & 5 deletions
@@ -16,11 +16,11 @@
     choose_qparams_affine,
     choose_qparams_affine_dont_preserve_zero,
     choose_qparams_affine_tinygemm,
+    fake_quantize_affine,
 )
 from torchao.utils import TorchAOBaseTensor

 from .utils import (
-    _GenericFakeQuantize,
     _UnwrapAffineFakeQuantizedTensor,
 )

@@ -90,14 +90,15 @@ def apply_fake_quant_fn(t: torch.Tensor):
             scale_dtype,
             zero_point_dtype,
         )
-        fq = _GenericFakeQuantize.apply(
+        fq = fake_quantize_affine(
             t,
             block_size,
             scale,
             zero_point,
-            qmin,
-            qmax,
-            zero_point_domain,
+            quant_dtype=torch.int32,
+            quant_min=qmin,
+            quant_max=qmax,
+            zero_point_domain=zero_point_domain,
         )
         return fq

torchao/quantization/qat/fake_quantizer.py

Lines changed: 1 addition & 1 deletion
@@ -17,6 +17,7 @@
     _DTYPE_TO_BIT_WIDTH,
     _DTYPE_TO_QVALUE_BOUNDS,
     MappingType,
+    _Round,
     choose_qparams_affine,
 )
 from torchao.quantization.utils import (
@@ -31,7 +32,6 @@
 from .utils import (
     _fake_quantize_per_channel_group,
     _fake_quantize_per_token,
-    _Round,
 )


torchao/quantization/qat/utils.py

Lines changed: 11 additions & 72 deletions
@@ -4,68 +4,19 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import List

 import torch

 from torchao.quantization.quant_primitives import (
     ZeroPointDomain,
-    fake_quantize_affine_cachemask,
+    fake_quantize_affine,
 )
 from torchao.quantization.utils import (
     _get_per_token_block_size,
 )


-class _GenericFakeQuantize(torch.autograd.Function):
-    """
-    Implementation of generic fake quantize with backward STE.
-
-    With the appropriate input tensor shape, this can be used to express
-    grouped per channel fake quantize or per token fake quantize.
-    """
-
-    @staticmethod
-    def forward(
-        ctx: torch.autograd.function.FunctionCtx,
-        input: torch.Tensor,
-        block_size: List[int],
-        scales: torch.Tensor,
-        zero_points: torch.Tensor,
-        quant_min: int,
-        quant_max: int,
-        zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
-    ) -> torch.Tensor:
-        # avoid circular dependencies
-        from torchao.quantization.qat.affine_fake_quantized_tensor import (
-            AffineFakeQuantizedTensor,
-        )
-
-        if isinstance(input, AffineFakeQuantizedTensor):
-            _input = input.original_tensor
-        else:
-            _input = input
-
-        (fq, mask) = fake_quantize_affine_cachemask(
-            _input,
-            block_size,
-            scales,
-            zero_points,
-            torch.int32,
-            quant_min,
-            quant_max,
-            zero_point_domain,
-        )
-
-        ctx.save_for_backward(mask)
-        return fq
-
-    @staticmethod
-    def backward(ctx, gy):
-        (mask,) = ctx.saved_tensors
-        return gy * mask, None, None, None, None, None, None
-
-
+# TODO: delete?
 class _UnwrapAffineFakeQuantizedTensor(torch.autograd.Function):
     """
     Helper autograd function to unwrap `AffineFakeQuantizedTensor` while ensuring
@@ -91,20 +42,6 @@ def backward(ctx, gy):
         return (gy,)


-class _Round(torch.autograd.Function):
-    """
-    Implementation of generic round operation with backward STE.
-    """
-
-    @staticmethod
-    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
-        return torch.round(x)
-
-    @staticmethod
-    def backward(ctx, gy: torch.Tensor) -> torch.Tensor:
-        return gy
-
-
 def _fake_quantize_per_channel_group(
     input: torch.Tensor,
     scales: torch.Tensor,
@@ -118,14 +55,15 @@ def _fake_quantize_per_channel_group(
     assert input.shape[-1] % group_size == 0
     assert input.dim() == 2
     block_size = (1, group_size)
-    return _GenericFakeQuantize.apply(
+    return fake_quantize_affine(
         input,
         block_size,
         scales,
         zero_points,
-        quant_min,
-        quant_max,
-        zero_point_domain,
+        quant_dtype=torch.int32,
+        quant_min=quant_min,
+        quant_max=quant_max,
+        zero_point_domain=zero_point_domain,
    )


@@ -140,13 +78,14 @@ def _fake_quantize_per_token(

     _per_token_quant_qparam_dim_check(input, scales, zero_points)
     block_size = _get_per_token_block_size(input)
-    fq = _GenericFakeQuantize.apply(
+    fq = fake_quantize_affine(
         input,
         block_size,
         scales,
         zero_points,
-        quant_min,
-        quant_max,
+        quant_dtype=torch.int32,
+        quant_min=quant_min,
+        quant_max=quant_max,
     )
     return fq.reshape_as(input).to(input.dtype)
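With `_GenericFakeQuantize` removed, the helpers above call `fake_quantize_affine` directly. A hedged sketch of the behavior this relies on (and which the updated range-learning test asserts): a scale passed to the primitive should now receive a gradient instead of being cut off by the old custom backward. The shapes, qparams, and zero-point layout below are illustrative, not prescribed by torchao.

```python
# Hedged sketch: after this commit, scales passed to fake_quantize_affine stay on the
# autograd tape (the old _GenericFakeQuantize.backward returned None for them).
import torch
from torchao.quantization.quant_primitives import fake_quantize_affine

x = torch.randn(16, 64)
# One scale / zero point per 64-element group (shapes are an assumption for this demo).
scales = torch.full((16, 1), 0.05, requires_grad=True)
zero_points = torch.zeros(16, 1)

out = fake_quantize_affine(
    x,
    (1, 64),
    scales,
    zero_points,
    quant_dtype=torch.int32,
    quant_min=-8,
    quant_max=7,
)
out.sum().backward()
assert scales.grad is not None  # gradients should now reach the scales
```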

torchao/quantization/quant_primitives.py

Lines changed: 27 additions & 14 deletions
@@ -212,6 +212,20 @@ class TorchAODType(Enum):
 register_custom_op = _register_custom_op(quant_lib)


+class _Round(torch.autograd.Function):
+    """
+    Implementation of generic round operation with backward STE.
+    """
+
+    @staticmethod
+    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
+        return torch.round(x)
+
+    @staticmethod
+    def backward(ctx, gy: torch.Tensor) -> torch.Tensor:
+        return gy
+
+
 # TODO: decide on if we want to allow custom quant_min/quant_max here
 def _get_and_check_qmin_qmax(dtype, quant_min, quant_max):
     """Get quant_min and quant_max args based on dtype and also
@@ -407,7 +421,7 @@ def _quantize_affine_no_dtype_cast(
         zero_point = None

     quant = torch.clamp(
-        torch.round(input * (1.0 / scale)) + zero_point, quant_min, quant_max
+        _Round.apply(input * (1.0 / scale)) + zero_point, quant_min, quant_max
     )
     quant = quant.view(original_shape)

@@ -493,7 +507,7 @@ def _quantize_affine_float_zero_point_no_dtype_cast(

     mid_point = (quant_max + quant_min + 1) / 2
     min_val = zero_point - scale * mid_point
-    quant = torch.clamp(torch.round((input - min_val) / scale), quant_min, quant_max)
+    quant = torch.clamp(_Round.apply((input - min_val) / scale), quant_min, quant_max)
     quant = quant.view(original_shape)

     return quant
@@ -577,7 +591,7 @@ def _quantize_affine_no_zero_point_no_dtype_cast(
     # with numel=0 which we handle by unifying the two
     zero_point = None

-    quant = torch.clamp(torch.round(input * (1.0 / scale)), quant_min, quant_max)
+    quant = torch.clamp(_Round.apply(input * (1.0 / scale)), quant_min, quant_max)
     quant = quant.view(original_shape)

     return quant
@@ -692,10 +706,9 @@ def _dequantize_affine_no_dtype_check(

     # Force a copy to avoid input modification due
     # to upcoming in-place operations.
-    dequant = input.to(torch.int32, copy=True)
+    dequant = input.to(output_dtype, copy=True)
     if zero_point is not None:
-        dequant = dequant - zero_point.to(torch.int32)
-        dequant = dequant.to(output_dtype)
+        dequant = dequant - zero_point.to(output_dtype)
     dequant = dequant * scale

     return dequant.view(original_shape).to(output_dtype)
@@ -1202,7 +1215,7 @@ def choose_qparams_affine_dont_preserve_zero(
     scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
     scale = torch.clamp(scale, min=eps)
     # Zero point is int
-    zero_point = quant_min - torch.round(min_val_neg / scale)
+    zero_point = quant_min - _Round.apply(min_val_neg / scale)
     zero_point = torch.clamp(zero_point, quant_min, quant_max)
     if zero_point_dtype is None:
         zero_point_dtype = torch.int32
@@ -1308,7 +1321,7 @@ def choose_qparams_affine_with_min_max(
     if zero_point_domain == ZeroPointDomain.NONE:
         zero_point = None
     elif zero_point_domain == ZeroPointDomain.INT:
-        zero_point = quant_min - torch.round(min_val_neg / scale)
+        zero_point = quant_min - _Round.apply(min_val_neg / scale)
         zero_point = torch.clamp(zero_point, quant_min, quant_max)
         if zero_point_dtype is None:
             zero_point_dtype = torch.int32
@@ -1400,7 +1413,7 @@ def _choose_qparams_affine(
         assert mapping_type == MappingType.ASYMMETRIC.name
         scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
         scale = torch.clamp(scale, min=eps)
-        zero_point = quant_min - torch.round(min_val_neg / scale)
+        zero_point = quant_min - _Round.apply(min_val_neg / scale)
         zero_point = torch.clamp(zero_point, quant_min, quant_max)
         if zero_point_dtype is None:
             zero_point_dtype = torch.int32
@@ -1434,7 +1447,7 @@ def choose_qparams_and_quantize_affine_qqq(
         s_group *= 2 / max_q_val  # 2 => symmetric

         # Quantize
-        q_w = torch.round(w / s_group).int()
+        q_w = _Round.apply(w / s_group).int()
         q_w += half_q_val
         q_w = torch.clamp(q_w, 0, max_q_val)
         # Compute ref (dequantized)
@@ -1467,7 +1480,7 @@ def reshape_w(w):
         s_channel /= max_q_val

         # Quantize
-        q_w = torch.round(w / s_channel).int()
+        q_w = _Round.apply(w / s_channel).int()
         q_w = torch.clamp(q_w, -max_q_val, max_q_val)
         # Compute ref (dequantized)
         w_ref = q_w.half() * s_channel
@@ -1871,7 +1884,7 @@ def choose_qparams_and_quantize_affine_hqq(

     # Round zero as in: https://github.com/casper-hansen/AutoAWQ/blob/main/awq/quantize/quantizer.py#L42C9-L42C14
     if nbits in [4]:
-        zero = torch.round(zero)
+        zero = _Round.apply(zero)

     # Fine-tune weights
     if optimize:
@@ -1887,7 +1900,7 @@ def choose_qparams_and_quantize_affine_hqq(
     else:
         zero = zero.to(compute_dtype)
         scale = scale.to(compute_dtype)
-    W_q = torch.round(W * scale + zero).clamp(min_max[0], min_max[1])
+    W_q = _Round.apply(W * scale + zero).clamp(min_max[0], min_max[1])

     # Store meta-data (we invert the scale for dequantization)
     scale = 1.0 / scale
@@ -2004,7 +2017,7 @@ def choose_qparams_affine_float8(
     if scale_dtype is not torch.float32:
         # Shielding for Version > 2.8
         assert scale_dtype is torch.float8_e8m0fnu, "Only float8_e8m0fnuz is supported"
-        scale = torch.exp2(torch.round(torch.log2(scale)))
+        scale = torch.exp2(_Round.apply(torch.log2(scale)))
     return scale.to(dtype=torch.float32)

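As a quick sanity check of the STE contract `_Round` provides (assuming this commit is applied): plain `torch.round` back-propagates zero almost everywhere, while `_Round` passes the upstream gradient through unchanged.

```python
# Sketch: compare gradients through _Round (STE) and torch.round.
import torch
from torchao.quantization.quant_primitives import _Round

x = torch.randn(5, requires_grad=True)
_Round.apply(x).sum().backward()
print(torch.equal(x.grad, torch.ones_like(x)))  # True: backward is the identity

y = torch.randn(5, requires_grad=True)
torch.round(y).sum().backward()
print(torch.count_nonzero(y.grad).item())  # 0: round's true derivative is zero a.e.
```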
