Commit 5f20d3d

Fix QAT range learning, ensure scales get gradients
TBD
1 parent dd43f16 commit 5f20d3d
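
Why this makes the scales trainable, in brief: a hard `torch.round` has a zero derivative almost everywhere, so a learned quantization scale whose only path to the loss goes through a hard round gets no useful gradient. Routing the rounding step through an autograd function whose backward is a straight-through estimator (STE), as the `_Round` class added to `quant_primitives.py` below does, lets gradients flow back to the scale. A minimal standalone sketch of the effect (illustration only, not code from this commit):

import torch


class RoundSTE(torch.autograd.Function):
    """Round in the forward pass, identity (straight-through) in the backward pass."""

    @staticmethod
    def forward(ctx, x):
        return torch.round(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: pass the upstream gradient through unchanged.
        return grad_output


x = torch.randn(8)
scale = torch.tensor(0.1, requires_grad=True)

# Hard round: the rounding step itself contributes no gradient to `scale`.
(torch.round(x / scale) * scale).sum().backward()
print(scale.grad)  # only the outer `* scale` term contributes

scale.grad = None

# STE round: the fake-quantized output is differentiable w.r.t. `scale`.
(RoundSTE.apply(x / scale) * scale).sum().backward()
print(scale.grad)  # now also includes the path through `x / scale`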

File tree

5 files changed: +49 -124 lines changed


test/quantization/test_qat.py

Lines changed: 14 additions & 43 deletions
@@ -46,7 +46,6 @@
 from torchao.quantization.qat.utils import (
     _fake_quantize_per_channel_group,
     _fake_quantize_per_token,
-    _GenericFakeQuantize,
     _get_qmin_qmax,
 )
 from torchao.quantization.quant_api import (
@@ -582,42 +581,6 @@ def test_qat_8da4w_quantizer_gradients(self):
         quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=16)
         self._test_qat_quantized_gradients(quantizer)

-    @unittest.skipIf(
-        not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
-    )
-    def test_qat_generic_fake_quantize(self):
-        """
-        Test that the generic fake quantize used in 8da4w QAT matches
-        the numerics of existing fake quantize ops in Pytorch in both
-        the forward and the backward passes.
-        """
-        (qmin, qmax) = _get_qmin_qmax(4)
-        py_input = torch.randn(16, 64).float().requires_grad_()
-        py_s = torch.randn(16).float()
-        py_zp = torch.randint(qmax, size=(16,), dtype=torch.int32)
-        py_out = torch.fake_quantize_per_channel_affine(
-            py_input, py_s, py_zp, 0, qmin, qmax
-        )
-        py_out.sum().backward()
-
-        ao_input = copy.deepcopy(py_input)
-        ao_input.grad.data.zero_()
-        block_size = (1, ao_input.shape[-1])
-        ao_s = copy.deepcopy(py_s)
-        ao_zp = copy.deepcopy(py_zp)
-        ao_out = _GenericFakeQuantize.apply(
-            ao_input, block_size, ao_s, ao_zp, qmin, qmax
-        )
-        ao_out.sum().backward()
-
-        torch.testing.assert_close(py_out, ao_out, atol=0, rtol=0)
-
-        # Test that gradients are close enough
-        num_grads = py_input.grad.numel()
-        num_equal_grads = torch.eq(py_input.grad, ao_input.grad).flatten().sum().item()
-        num_equal_grad_threshold = 0.8
-        self.assertGreaterEqual(num_equal_grads / num_grads, num_equal_grad_threshold)
-
     def _assert_close_4w(self, val, ref):
         # Note: for int4 weight-only quantization, we do not expect exact match
         # because torch._weight_int4pack_mm and torch.mm do not match exactly.
@@ -1697,16 +1660,24 @@ def test_qat_range_learning(self):
         m(*example_inputs)

         # Simulate training
+        num_steps = 10
         optimizer = torch.optim.SGD(
             m.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5
         )
         loss_fn = torch.nn.CrossEntropyLoss()
-        target = torch.randn(1, 512).float()
-        out = m(*example_inputs)
-        loss = loss_fn(out, target)
-        optimizer.zero_grad()
-        loss.backward()
-        optimizer.step()
+        for i in range(num_steps):
+            prev_scale = copy.deepcopy(m.linear1.weight_fake_quantizer.scale)
+            optimizer.zero_grad()
+            target = torch.randn(1, 512).float()
+            out = m(*example_inputs)
+            loss = loss_fn(out, target)
+            loss.backward()
+            optimizer.step()
+            # Assert that scales have valid gradients and are being updated
+            new_scale = m.linear1.weight_fake_quantizer.scale
+            self.assertIsNotNone(new_scale.grad)
+            self.assertNotEqual(torch.count_nonzero(new_scale.grad), 0)
+            self.assertFalse(torch.equal(new_scale, prev_scale))


 if __name__ == "__main__":
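
The reworked `test_qat_range_learning` checks three things each step: the scale has a gradient, the gradient is not identically zero, and the scale value actually moves after `optimizer.step()`. That snapshot-step-compare pattern is generic; a toy sketch of it on a plain `nn.Linear`, purely for illustration:

import copy

import torch

m = torch.nn.Linear(8, 4)
optimizer = torch.optim.SGD(m.parameters(), lr=0.1)

prev_weight = copy.deepcopy(m.weight)
optimizer.zero_grad()
m(torch.randn(2, 8)).sum().backward()
optimizer.step()

assert m.weight.grad is not None                # gradient was populated
assert torch.count_nonzero(m.weight.grad) != 0  # and is not identically zero
assert not torch.equal(m.weight, prev_weight)   # the optimizer actually moved it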

torchao/quantization/qat/affine_fake_quantized_tensor.py

Lines changed: 3 additions & 2 deletions
@@ -16,11 +16,11 @@
     choose_qparams_affine,
     choose_qparams_affine_dont_preserve_zero,
     choose_qparams_affine_tinygemm,
+    fake_quantize_affine,
 )
 from torchao.utils import TorchAOBaseTensor

 from .utils import (
-    _GenericFakeQuantize,
     _UnwrapAffineFakeQuantizedTensor,
 )

@@ -90,11 +90,12 @@ def apply_fake_quant_fn(t: torch.Tensor):
             scale_dtype,
             zero_point_dtype,
         )
-        fq = _GenericFakeQuantize.apply(
+        fq = fake_quantize_affine(
             t,
             block_size,
             scale,
             zero_point,
+            torch.int32,
             qmin,
             qmax,
             zero_point_domain,

torchao/quantization/qat/fake_quantizer.py

Lines changed: 1 addition & 1 deletion
@@ -17,6 +17,7 @@
     _DTYPE_TO_BIT_WIDTH,
     _DTYPE_TO_QVALUE_BOUNDS,
     MappingType,
+    _Round,
     choose_qparams_affine,
 )
 from torchao.quantization.utils import (
@@ -31,7 +32,6 @@
 from .utils import (
     _fake_quantize_per_channel_group,
     _fake_quantize_per_token,
-    _Round,
 )

torchao/quantization/qat/utils.py

Lines changed: 6 additions & 67 deletions
@@ -4,68 +4,19 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import List

 import torch

 from torchao.quantization.quant_primitives import (
     ZeroPointDomain,
-    fake_quantize_affine_cachemask,
+    fake_quantize_affine,
 )
 from torchao.quantization.utils import (
     _get_per_token_block_size,
 )


-class _GenericFakeQuantize(torch.autograd.Function):
-    """
-    Implementation of generic fake quantize with backward STE.
-
-    With the appropriate input tensor shape, this can be used to express
-    grouped per channel fake quantize or per token fake quantize.
-    """
-
-    @staticmethod
-    def forward(
-        ctx: torch.autograd.function.FunctionCtx,
-        input: torch.Tensor,
-        block_size: List[int],
-        scales: torch.Tensor,
-        zero_points: torch.Tensor,
-        quant_min: int,
-        quant_max: int,
-        zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
-    ) -> torch.Tensor:
-        # avoid circular dependencies
-        from torchao.quantization.qat.affine_fake_quantized_tensor import (
-            AffineFakeQuantizedTensor,
-        )
-
-        if isinstance(input, AffineFakeQuantizedTensor):
-            _input = input.original_tensor
-        else:
-            _input = input
-
-        (fq, mask) = fake_quantize_affine_cachemask(
-            _input,
-            block_size,
-            scales,
-            zero_points,
-            torch.int32,
-            quant_min,
-            quant_max,
-            zero_point_domain,
-        )
-
-        ctx.save_for_backward(mask)
-        return fq
-
-    @staticmethod
-    def backward(ctx, gy):
-        (mask,) = ctx.saved_tensors
-        return gy * mask, None, None, None, None, None, None
-
-
+# TODO: delete?
 class _UnwrapAffineFakeQuantizedTensor(torch.autograd.Function):
     """
     Helper autograd function to unwrap `AffineFakeQuantizedTensor` while ensuring
@@ -91,20 +42,6 @@ def backward(ctx, gy):
         return (gy,)


-class _Round(torch.autograd.Function):
-    """
-    Implementation of generic round operation with backward STE.
-    """
-
-    @staticmethod
-    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
-        return torch.round(x)
-
-    @staticmethod
-    def backward(ctx, gy: torch.Tensor) -> torch.Tensor:
-        return gy
-
-
 def _fake_quantize_per_channel_group(
     input: torch.Tensor,
     scales: torch.Tensor,
@@ -118,11 +55,12 @@ def _fake_quantize_per_channel_group(
     assert input.shape[-1] % group_size == 0
     assert input.dim() == 2
     block_size = (1, group_size)
-    return _GenericFakeQuantize.apply(
+    return fake_quantize_affine(
         input,
         block_size,
         scales,
         zero_points,
+        torch.int32,
         quant_min,
         quant_max,
         zero_point_domain,
@@ -140,11 +78,12 @@ def _fake_quantize_per_token(

     _per_token_quant_qparam_dim_check(input, scales, zero_points)
     block_size = _get_per_token_block_size(input)
-    fq = _GenericFakeQuantize.apply(
+    fq = fake_quantize_affine(
         input,
         block_size,
         scales,
         zero_points,
+        torch.int32,
         quant_min,
         quant_max,
     )
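
The removal above is the core of the fix: `_GenericFakeQuantize.backward` returned `gy * mask, None, None, ...`, i.e. a gradient for the input only and `None` for the scales and zero points, so learned quantization parameters could never receive gradients through it. Call sites now go through the differentiable `fake_quantize_affine` instead. A quick sanity sketch of the new path (the shapes, group size, and qmin/qmax below are illustrative assumptions, not values from the commit):

import torch

from torchao.quantization.quant_primitives import fake_quantize_affine

weight = torch.randn(32, 64)
# One (scale, zero_point) pair per group of 16 columns: 32 rows x 4 groups.
scale = (torch.rand(32, 4) + 1e-3).requires_grad_()
zero_point = torch.zeros(32, 4, dtype=torch.int32)
block_size = (1, 16)

# Same positional argument order as the updated call sites:
# (input, block_size, scale, zero_point, quant_dtype, quant_min, quant_max)
fq = fake_quantize_affine(weight, block_size, scale, zero_point, torch.int32, -8, 7)
fq.sum().backward()

assert scale.grad is not None  # the scale now receives a gradient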

torchao/quantization/quant_primitives.py

Lines changed: 25 additions & 11 deletions
@@ -212,6 +212,20 @@ class TorchAODType(Enum):
 register_custom_op = _register_custom_op(quant_lib)


+class _Round(torch.autograd.Function):
+    """
+    Implementation of generic round operation with backward STE.
+    """
+
+    @staticmethod
+    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
+        return torch.round(x)
+
+    @staticmethod
+    def backward(ctx, gy: torch.Tensor) -> torch.Tensor:
+        return gy
+
+
 # TODO: decide on if we want to allow custom quant_min/quant_max here
 def _get_and_check_qmin_qmax(dtype, quant_min, quant_max):
     """Get quant_min and quant_max args based on dtype and also
@@ -407,7 +421,7 @@ def _quantize_affine_no_dtype_cast(
         zero_point = None

     quant = torch.clamp(
-        torch.round(input * (1.0 / scale)) + zero_point, quant_min, quant_max
+        _Round.apply(input * (1.0 / scale)) + zero_point, quant_min, quant_max
     )
     quant = quant.view(original_shape)

@@ -493,7 +507,7 @@ def _quantize_affine_float_zero_point_no_dtype_cast(

     mid_point = (quant_max + quant_min + 1) / 2
     min_val = zero_point - scale * mid_point
-    quant = torch.clamp(torch.round((input - min_val) / scale), quant_min, quant_max)
+    quant = torch.clamp(_Round.apply((input - min_val) / scale), quant_min, quant_max)
     quant = quant.view(original_shape)

     return quant
@@ -577,7 +591,7 @@ def _quantize_affine_no_zero_point_no_dtype_cast(
     # with numel=0 which we handle by unifying the two
     zero_point = None

-    quant = torch.clamp(torch.round(input * (1.0 / scale)), quant_min, quant_max)
+    quant = torch.clamp(_Round.apply(input * (1.0 / scale)), quant_min, quant_max)
     quant = quant.view(original_shape)

     return quant
@@ -1202,7 +1216,7 @@ def choose_qparams_affine_dont_preserve_zero(
     scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
     scale = torch.clamp(scale, min=eps)
     # Zero point is int
-    zero_point = quant_min - torch.round(min_val_neg / scale)
+    zero_point = quant_min - _Round.apply(min_val_neg / scale)
     zero_point = torch.clamp(zero_point, quant_min, quant_max)
     if zero_point_dtype is None:
         zero_point_dtype = torch.int32
@@ -1308,7 +1322,7 @@ def choose_qparams_affine_with_min_max(
     if zero_point_domain == ZeroPointDomain.NONE:
         zero_point = None
     elif zero_point_domain == ZeroPointDomain.INT:
-        zero_point = quant_min - torch.round(min_val_neg / scale)
+        zero_point = quant_min - _Round.apply(min_val_neg / scale)
         zero_point = torch.clamp(zero_point, quant_min, quant_max)
         if zero_point_dtype is None:
             zero_point_dtype = torch.int32
@@ -1400,7 +1414,7 @@ def _choose_qparams_affine(
         assert mapping_type == MappingType.ASYMMETRIC.name
         scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
         scale = torch.clamp(scale, min=eps)
-        zero_point = quant_min - torch.round(min_val_neg / scale)
+        zero_point = quant_min - _Round.apply(min_val_neg / scale)
         zero_point = torch.clamp(zero_point, quant_min, quant_max)
         if zero_point_dtype is None:
             zero_point_dtype = torch.int32
@@ -1434,7 +1448,7 @@ def choose_qparams_and_quantize_affine_qqq(
         s_group *= 2 / max_q_val  # 2 => symmetric

         # Quantize
-        q_w = torch.round(w / s_group).int()
+        q_w = _Round.apply(w / s_group).int()
         q_w += half_q_val
         q_w = torch.clamp(q_w, 0, max_q_val)
         # Compute ref (dequantized)
@@ -1467,7 +1481,7 @@ def reshape_w(w):
         s_channel /= max_q_val

         # Quantize
-        q_w = torch.round(w / s_channel).int()
+        q_w = _Round.apply(w / s_channel).int()
         q_w = torch.clamp(q_w, -max_q_val, max_q_val)
         # Compute ref (dequantized)
         w_ref = q_w.half() * s_channel
@@ -1871,7 +1885,7 @@ def choose_qparams_and_quantize_affine_hqq(

     # Round zero as in: https://github.com/casper-hansen/AutoAWQ/blob/main/awq/quantize/quantizer.py#L42C9-L42C14
     if nbits in [4]:
-        zero = torch.round(zero)
+        zero = _Round.apply(zero)

     # Fine-tune weights
     if optimize:
@@ -1887,7 +1901,7 @@ def choose_qparams_and_quantize_affine_hqq(
     else:
         zero = zero.to(compute_dtype)
         scale = scale.to(compute_dtype)
-        W_q = torch.round(W * scale + zero).clamp(min_max[0], min_max[1])
+        W_q = _Round.apply(W * scale + zero).clamp(min_max[0], min_max[1])

     # Store meta-data (we invert the scale for dequantization)
     scale = 1.0 / scale
@@ -2004,7 +2018,7 @@ def choose_qparams_affine_float8(
     if scale_dtype is not torch.float32:
         # Shielding for Version > 2.8
         assert scale_dtype is torch.float8_e8m0fnu, "Only float8_e8m0fnuz is supported"
-        scale = torch.exp2(torch.round(torch.log2(scale)))
+        scale = torch.exp2(_Round.apply(torch.log2(scale)))
     return scale.to(dtype=torch.float32)

0 commit comments
