Skip to content

Commit f1b7738

Browse files
authored
Temp fix to unblock diff train
Differential Revision: D75966594 Pull Request resolved: #11361
1 parent 35754d1 commit f1b7738

File tree

1 file changed

+21
-1
lines changed

1 file changed

+21
-1
lines changed

backends/vulkan/_passes/int4_weight_only_quantizer.py

Lines changed: 21 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -7,11 +7,31 @@
77
import torch
88
import torch.nn.functional as F
99

10-
from torchao.quantization.GPTQ.GPTQ import _check_linear_int4_k
1110
from torchao.quantization.unified import Quantizer
1211
from torchao.quantization.utils import groupwise_affine_quantize_tensor
1312

1413

14+
# TODO: import from from torchao.quantization.GPTQ.GPTQ import _check_linear_int4_k
15+
# Once diff train catches up
16+
def _check_linear_int4_k(k, group_size=1, inner_k_tiles=None):
17+
"""
18+
Check if the dimensions are compatible with int4 quantization.
19+
20+
Args:
21+
k: The dimension size to check
22+
group_size: The group size for quantization
23+
inner_k_tiles: The inner k tiles size
24+
25+
Returns:
26+
bool: Whether the dimensions are compatible
27+
"""
28+
k_divisible_by_group_size = k % group_size == 0
29+
if inner_k_tiles is not None:
30+
k_divisible_by_16_times_inner_k_tiles = k % (inner_k_tiles * 16) == 0
31+
return k_divisible_by_group_size and k_divisible_by_16_times_inner_k_tiles
32+
return k_divisible_by_group_size
33+
34+
1535
# This module is copied from torchao.quantization.GPTQ.WeightOnlyInt4Linear with
1636
# changes at the annotated lines.
1737
class VkWeightOnlyInt4Linear(torch.nn.Module):

0 commit comments

Comments
 (0)