Skip to content

[CPU] Enable DA8W4 on CPU #2128

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 29 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
0581451
[CPU] enable int8_dynamic_activation_int4_weight with Int4CPULayout
Xia-Weiwen Apr 25, 2025
dffbbab
Merge branch 'main' into da8w4_with_int4_cpu_layout
Xia-Weiwen Apr 25, 2025
9fb7f77
Fix format issue
Xia-Weiwen Apr 25, 2025
35ece3b
Merge branch 'main' into da8w4_with_int4_cpu_layout
Xia-Weiwen Apr 28, 2025
c5b6d87
Merge branch 'main' into da8w4_with_int4_cpu_layout
Xia-Weiwen May 12, 2025
8e80d03
Add Int8DynamicActInt4WeightCPULayout
Xia-Weiwen May 14, 2025
51249c3
Merge branch 'main' into da8w4_with_int4_cpu_layout
Xia-Weiwen May 15, 2025
3e20172
remove dispatch for t()
Xia-Weiwen May 16, 2025
e765664
Merge branch 'main' into da8w4_with_int4_cpu_layout
Xia-Weiwen May 21, 2025
4feac3f
Add cpp kernel for weight packing and GEMM
Xia-Weiwen May 23, 2025
0d85183
Register ATQ linear dispatch for da8w4 linear
Xia-Weiwen May 25, 2025
c42abdb
Fix issues with torch.compile
Xia-Weiwen May 26, 2025
e2815ce
Merge branch 'main' into da8w4_with_int4_cpu_layout
Xia-Weiwen May 26, 2025
8c5eebb
Fix DA8W4CPUAQTTensorImpl.get_plain
Xia-Weiwen May 26, 2025
2a26e15
Test DA8W4CPUAQTTensorImpl.get_plain in UT
Xia-Weiwen May 26, 2025
369000f
Skip UT if CPP kernel not built
Xia-Weiwen May 26, 2025
f6e87ba
Add AVX512_VNNI implementation for small M
Xia-Weiwen May 27, 2025
0a87ef0
improve performance
Xia-Weiwen Jun 3, 2025
e05b96a
Support symmetric quantization of activation
Xia-Weiwen Jun 4, 2025
fd6e4b1
Merge branch 'main' into da8w4_with_int4_cpu_layout
Xia-Weiwen Jun 4, 2025
18335c6
Refine code
Xia-Weiwen Jun 4, 2025
66ab77f
Refine code
Xia-Weiwen Jun 5, 2025
2c5a799
Merge branch 'main' into da8w4_with_int4_cpu_layout
Xia-Weiwen Jun 5, 2025
131660e
Merge branch 'main' into da8w4_with_int4_cpu_layout
Xia-Weiwen Jun 11, 2025
75fbd6f
Put in a separate file
Xia-Weiwen Jun 14, 2025
24268fd
Merge branch 'main' into da8w4_with_int4_cpu_layout
Xia-Weiwen Jun 14, 2025
4c0a739
Bug fix
Xia-Weiwen Jun 25, 2025
0815d96
Merge branch 'main' into da8w4_with_int4_cpu_layout
Xia-Weiwen Jun 25, 2025
e3731f7
refine code
Xia-Weiwen Jun 25, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions test/quantization/test_quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
AffineQuantizedTensor,
Int4CPULayout,
Int4XPULayout,
Int8DynamicActInt4WeightCPULayout,
PlainLayout,
QDQLayout,
TensorCoreTiledLayout,
Expand Down Expand Up @@ -875,6 +876,43 @@ def test_int4wo_cpu(self, dtype, x_dim, use_hqq):
assert "_weight_int4pack_mm_for_cpu" in code[0]
assert "aten.mm.default" not in code[0]

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+")
@common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half])
@common_utils.parametrize("x_dim", [2, 3])
def test_8da4w_cpu(self, dtype, x_dim):
device = "cpu"
m = ToyLinearModel().eval().to(dtype).to(device)
m2 = copy.deepcopy(m)
example_inputs = m.example_inputs(dtype=dtype, device=device)
if x_dim == 3:
example_inputs = (example_inputs[0].unsqueeze(0),)

with torch.no_grad():
# Currently, the difference between Int8DynamicActInt4WeightCPULayout and PlainLayout
# is that the former packs two int4 weights into one int8, while the latter does not.
quantize_(
m,
int8_dynamic_activation_int4_weight(
group_size=32, layout=Int8DynamicActInt4WeightCPULayout()
),
)
y, code = torch._inductor.utils.run_and_get_code(
torch.compile(m, fullgraph=True, dynamic=True),
*example_inputs,
)
# ensure the expected op is in the code
assert "shift" in code[0] # unpacking int4 values
assert "extern_kernels.mm" in code[0]
quantize_(
m2,
int8_dynamic_activation_int4_weight(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: can you use the new API: Int8DynamicActivationInt4WeightConfig instead of int8_dynamic_activation_int4_weight?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. Done.

group_size=32, layout=PlainLayout()
),
)
torch._dynamo.reset() # may segfault without this
y2 = torch.compile(m2, fullgraph=True, dynamic=True)(*example_inputs)
assert torch.allclose(y, y2)

# TODO(#1690): move to new config names
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
Expand Down
2 changes: 2 additions & 0 deletions torchao/dtypes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
CutlassInt4PackedLayout,
Int4CPULayout,
Int4XPULayout,
Int8DynamicActInt4WeightCPULayout,
MarlinQQQLayout,
MarlinQQQTensor,
MarlinSparseLayout,
Expand Down Expand Up @@ -61,4 +62,5 @@
"PackedLinearInt8DynamicActivationIntxWeightLayout",
"to_affine_quantized_packed_linear_int8_dynamic_activation_intx_weight",
"Int4XPULayout",
"Int8DynamicActInt4WeightCPULayout",
]
2 changes: 2 additions & 0 deletions torchao/dtypes/uintx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
)
from .int4_cpu_layout import (
Int4CPULayout,
Int8DynamicActInt4WeightCPULayout,
)
from .int4_xpu_layout import (
Int4XPULayout,
Expand Down Expand Up @@ -48,4 +49,5 @@
"PackedLinearInt8DynamicActivationIntxWeightLayout",
"QDQLayout",
"Int4XPULayout",
"Int8DynamicActInt4WeightCPULayout",
]
137 changes: 134 additions & 3 deletions torchao/dtypes/uintx/int4_cpu_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def to(self, *args, **kwargs):
device = kwargs["device"]
if not is_device(torch.device(self.device).type, device):
raise ValueError(
f"Int4CPUAQTTensorImpl does not support conversion from {self.device} to {device}"
f"{self.__class__.__name__} does not support conversion from {self.device} to {device}"
)
return self.__class__(
self.packed_weight.to(device),
Expand Down Expand Up @@ -214,11 +214,11 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
return return_and_correct_aliasing(func, args, kwargs, sliced)
else:
raise NotImplementedError(
f"Int4CPUAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported"
f"{cls.__name__} dispatch: attempting to run {func}, with dim={dim}, that is not supported"
)

raise NotImplementedError(
f"Int4CPUAQTTensorImpl dispatch: attempting to run {func}, this is not supported"
f"{cls.__name__} dispatch: attempting to run {func}, this is not supported"
)

__torch_function__ = torch._C._disabled_torch_function_impl
Expand Down Expand Up @@ -352,3 +352,134 @@ def _linear_fp_act_uint4_weight_cpu_impl(input_tensor, weight_tensor, bias):
if bias is not None:
y += bias
return y.to(orig_dtype)


@dataclass(frozen=True)
class Int8DynamicActInt4WeightCPULayout(Layout):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it looks like you can just reuse Int4CPULayout

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you move the layout and impl to a separate file?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure. Done.

"""Layout class for da8w4 CPU layout for affine quantized tensor"""

pass


@register_layout(Int8DynamicActInt4WeightCPULayout)
class DA8W4CPUAQTTensorImpl(Int4CPUAQTTensorImpl):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh I see, OK if you need a separate Impl then makes sense to have a separate layout

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. We need a different impl from W16W4 because the ISA (AMX and VNNI) requires different memory formats of weight for computation in BF16 or INT8. Thanks.

"""TensorImpl for da8w4 CPU layout for affine quantized tensor
It stores the original tensor of dimension [n][k] (int32 dtype) as packed weight of 2-d tensor of
dimension: [n][k / 2] (uint8 dtype)
It is similar to Int4CPUAQTTensorImpl but with a different memory layout of weight data
fields:
packed_weight (torch.Tensor): the 2-d packed tensor in a Int4 CPU layout
scales (torch.Tensor): the scales Tensor used to map between floating point tensor to quantized tensor
qzeros (torch.Tensor): the zero_point Tensor used to map between floating point tensor to quantized tensor
"""

def __new__(
cls,
packed_weight: torch.Tensor,
scales: torch.Tensor,
qzeros: torch.Tensor,
transposed: bool,
_layout: Layout,
):
kwargs = {}
kwargs["device"] = packed_weight.device
kwargs["layout"] = (
kwargs.get("layout")
if kwargs.get("layout", False)
else packed_weight.layout
)
kwargs["dtype"] = packed_weight.dtype
kwargs["requires_grad"] = False
shape = packed_weight.shape
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]

def __init__(
self,
packed_weight: torch.Tensor,
scales: torch.Tensor,
qzeros: torch.Tensor,
transposed: bool,
_layout: Layout,
):
self.packed_weight = packed_weight
self.scales = scales
self.qzeros = qzeros
self.transposed = transposed
self._layout = _layout

def __tensor_flatten__(self):
return ["packed_weight", "scales", "qzeros"], [self.transposed, self._layout]

@classmethod
def __tensor_unflatten__(
cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
):
packed_weight, scales, qzeros = (
tensor_data_dict["packed_weight"],
tensor_data_dict["scales"],
tensor_data_dict["qzeros"],
)
(
transposed,
_layout,
) = tensor_attributes
return cls(packed_weight, scales, qzeros, transposed, _layout)

@classmethod
def from_plain(
cls,
int_data: torch.Tensor,
scale: torch.Tensor,
zero_point: Optional[torch.Tensor],
_layout: Layout,
):
assert isinstance(_layout, Int8DynamicActInt4WeightCPULayout)
assert int_data.dtype == torch.int8, "DA8W4 CPU: expects int8 weight"
assert int_data.shape[1] % 2 == 0, "DA8W4 CPU: expects even number of columns"
weight_int4 = ((int_data[..., 1::2] & 0xF) << 4) | (int_data[..., 0::2] & 0xF)
return cls(weight_int4, scale, zero_point, False, _layout)

def _apply_fn_to_data(self, fn):
return self.__class__(
fn(self.packed_weight),
fn(self.scales),
fn(self.qzeros),
self.transposed,
self._layout,
)

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
kwargs = {} if kwargs is None else kwargs
if func is aten.t.default:
"""we don't need to repack the weight and just rely on external
shape being changed and record the status of transpose/no-transpose
"""
transposed = DA8W4CPUAQTTensorImpl(
args[0].packed_weight,
args[0].scales,
args[0].qzeros,
not args[0].transposed,
args[0]._layout,
)
return return_and_correct_aliasing(func, args, kwargs, transposed)
else:
return super().__torch_dispatch__(func, types, args, kwargs)

__torch_function__ = torch._C._disabled_torch_function_impl

@property
def block_size(self):
assert len(self.packed_weight.shape) == 2
weight_shape = self.packed_weight.shape
N = weight_shape[0]
K = weight_shape[1] * 2
groups = self.scales.numel() // N
group_size = K // groups
return (1, group_size)

def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
plain_weight = torch.stack(
((self.packed_weight << 4) >> 4, self.packed_weight >> 4), dim=-1
).view(self.packed_weight.shape[:-1] + (2 * self.packed_weight.shape[-1],))
return plain_weight, self.scales, self.qzeros
Loading