[CPU] Enable DA8W4 on CPU #2128
base: main
Changes from 7 commits
@@ -147,7 +147,7 @@ def to(self, *args, **kwargs):
         device = kwargs["device"]
         if not is_device(torch.device(self.device).type, device):
             raise ValueError(
-                f"Int4CPUAQTTensorImpl does not support conversion from {self.device} to {device}"
+                f"{self.__class__.__name__} does not support conversion from {self.device} to {device}"
             )
         return self.__class__(
             self.packed_weight.to(device),
@@ -214,11 +214,11 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
                 return return_and_correct_aliasing(func, args, kwargs, sliced)
             else:
                 raise NotImplementedError(
-                    f"Int4CPUAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported"
+                    f"{cls.__name__} dispatch: attempting to run {func}, with dim={dim}, that is not supported"
                 )

         raise NotImplementedError(
-            f"Int4CPUAQTTensorImpl dispatch: attempting to run {func}, this is not supported"
+            f"{cls.__name__} dispatch: attempting to run {func}, this is not supported"
         )

     __torch_function__ = torch._C._disabled_torch_function_impl
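The only substantive change in the two hunks above is that the hard-coded "Int4CPUAQTTensorImpl" in the error messages becomes the runtime class name, so subclasses (such as the DA8W4CPUAQTTensorImpl added below) report themselves correctly. A minimal sketch with hypothetical classes, not torchao code:

```python
# Minimal sketch with hypothetical classes; not the torchao code itself.
class BaseImpl:
    def to(self, device):
        # Using self.__class__.__name__ makes the message name the subclass.
        raise ValueError(
            f"{self.__class__.__name__} does not support conversion to {device}"
        )

class DerivedImpl(BaseImpl):
    pass

try:
    DerivedImpl().to("cuda")
except ValueError as e:
    print(e)  # DerivedImpl does not support conversion to cuda
```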
@@ -352,3 +352,134 @@ def _linear_fp_act_uint4_weight_cpu_impl(input_tensor, weight_tensor, bias):
     if bias is not None:
         y += bias
     return y.to(orig_dtype)

The rest of the hunk is new code:

@dataclass(frozen=True)
class Int8DynamicActInt4WeightCPULayout(Layout):
    """Layout class for da8w4 CPU layout for affine quantized tensor"""

    pass

Review comments (on Int8DynamicActInt4WeightCPULayout):
  Reviewer: it looks like you can just reuse Int4CPULayout
  Reviewer: can you move the layout and impl to a separate file?
  Author: Sure. Done.
@register_layout(Int8DynamicActInt4WeightCPULayout)
class DA8W4CPUAQTTensorImpl(Int4CPUAQTTensorImpl):
    """TensorImpl for the da8w4 CPU layout of an affine quantized tensor.
    It stores the original tensor of dimension [n][k] (int32 dtype) as a packed
    2-d weight tensor of dimension [n][k / 2] (uint8 dtype).
    It is similar to Int4CPUAQTTensorImpl but uses a different memory layout for the weight data.
    Fields:
      packed_weight (torch.Tensor): the 2-d packed tensor in an int4 CPU layout
      scales (torch.Tensor): the scales Tensor used to map between a floating point tensor and the quantized tensor
      qzeros (torch.Tensor): the zero_point Tensor used to map between a floating point tensor and the quantized tensor
    """

Review comments (on DA8W4CPUAQTTensorImpl):
  Reviewer: oh I see, OK if you need a separate Impl then makes sense to have a separate layout
  Author: Yes. We need a different impl from W16W4 because the ISA (AMX and VNNI) requires different memory formats of weight for computation in BF16 or INT8. Thanks.
    def __new__(
        cls,
        packed_weight: torch.Tensor,
        scales: torch.Tensor,
        qzeros: torch.Tensor,
        transposed: bool,
        _layout: Layout,
    ):
        kwargs = {}
        kwargs["device"] = packed_weight.device
        kwargs["layout"] = (
            kwargs.get("layout")
            if kwargs.get("layout", False)
            else packed_weight.layout
        )
        kwargs["dtype"] = packed_weight.dtype
        kwargs["requires_grad"] = False
        shape = packed_weight.shape
        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]

    def __init__(
        self,
        packed_weight: torch.Tensor,
        scales: torch.Tensor,
        qzeros: torch.Tensor,
        transposed: bool,
        _layout: Layout,
    ):
        self.packed_weight = packed_weight
        self.scales = scales
        self.qzeros = qzeros
        self.transposed = transposed
        self._layout = _layout

    def __tensor_flatten__(self):
        return ["packed_weight", "scales", "qzeros"], [self.transposed, self._layout]

    @classmethod
    def __tensor_unflatten__(
        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
    ):
        packed_weight, scales, qzeros = (
            tensor_data_dict["packed_weight"],
            tensor_data_dict["scales"],
            tensor_data_dict["qzeros"],
        )
        (
            transposed,
            _layout,
        ) = tensor_attributes
        return cls(packed_weight, scales, qzeros, transposed, _layout)
    @classmethod
    def from_plain(
        cls,
        int_data: torch.Tensor,
        scale: torch.Tensor,
        zero_point: Optional[torch.Tensor],
        _layout: Layout,
    ):
        assert isinstance(_layout, Int8DynamicActInt4WeightCPULayout)
        assert int_data.dtype == torch.int8, "DA8W4 CPU: expects int8 weight"
        assert int_data.shape[1] % 2 == 0, "DA8W4 CPU: expects even number of columns"
        weight_int4 = ((int_data[..., 1::2] & 0xF) << 4) | (int_data[..., 0::2] & 0xF)
        return cls(weight_int4, scale, zero_point, False, _layout)
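As an aside, here is what the nibble packing in from_plain does for a single pair of int4 values (made-up numbers, plain Python for illustration): even columns go to the low nibble, odd columns to the high nibble.

```python
# Illustration only: how two signed int4 values end up in one byte.
lo, hi = 3, -4                       # int4 values in [-8, 7]
byte = ((hi & 0xF) << 4) | (lo & 0xF)
print(hex(byte))                     # 0xc3: high nibble 0xc (= -4), low nibble 0x3 (= 3)
```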
    def _apply_fn_to_data(self, fn):
        return self.__class__(
            fn(self.packed_weight),
            fn(self.scales),
            fn(self.qzeros),
            self.transposed,
            self._layout,
        )
    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        kwargs = {} if kwargs is None else kwargs
        if func is aten.t.default:
            """we don't need to repack the weight and just rely on external
            shape being changed and record the status of transpose/no-transpose
            """
            transposed = DA8W4CPUAQTTensorImpl(
                args[0].packed_weight,
                args[0].scales,
                args[0].qzeros,
                not args[0].transposed,
                args[0]._layout,
            )
            return return_and_correct_aliasing(func, args, kwargs, transposed)
        else:
            return super().__torch_dispatch__(func, types, args, kwargs)

    __torch_function__ = torch._C._disabled_torch_function_impl
    @property
    def block_size(self):
        assert len(self.packed_weight.shape) == 2
        weight_shape = self.packed_weight.shape
        N = weight_shape[0]
        K = weight_shape[1] * 2
        groups = self.scales.numel() // N
        group_size = K // groups
        return (1, group_size)
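The group size is recovered from the stored shapes rather than kept as an attribute. A worked example with made-up sizes:

```python
# Made-up sizes, mirroring the arithmetic in the block_size property:
# weight is [N, K] = [64, 256] int4, packed as [64, 128] bytes, with 4 groups per row.
N, K_packed = 64, 128
K = K_packed * 2             # 256
scales_numel = 64 * 4        # one scale per (row, group)
groups = scales_numel // N   # 4
group_size = K // groups     # 64
print((1, group_size))       # (1, 64) -> block_size reported to the affine quantized tensor
```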
    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        plain_weight = torch.stack(
            ((self.packed_weight << 4) >> 4, self.packed_weight >> 4), dim=-1
        ).view(self.packed_weight.shape[:-1] + (2 * self.packed_weight.shape[-1],))
        return plain_weight, self.scales, self.qzeros
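A quick, standalone sanity check of the pack/unpack pair (made-up values, not the PR's test code). The left-then-right shift in get_plain sign-extends the low nibble, and the arithmetic right shift sign-extends the high nibble, so signed int4 values round-trip:

```python
import torch

w = torch.tensor([[1, -2, 3, -4], [5, -6, 7, -8]], dtype=torch.int8)

# Pack as in from_plain: odd columns -> high nibble, even columns -> low nibble.
packed = ((w[..., 1::2] & 0xF) << 4) | (w[..., 0::2] & 0xF)

# Unpack as in get_plain: shift left then right to sign-extend the low nibble,
# arithmetic right shift to sign-extend the high nibble, then interleave.
unpacked = torch.stack(((packed << 4) >> 4, packed >> 4), dim=-1).view(w.shape)

assert torch.equal(unpacked, w)
```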
Review comments:
  Reviewer: nit: can you use the new API Int8DynamicActivationInt4WeightConfig instead of int8_dynamic_activation_int4_weight?
  Author: Thanks. Done.
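For context, usage with the new config API the reviewer is asking for would look roughly like the sketch below. This is an assumption-laden sketch, not code from this PR: it assumes Int8DynamicActivationInt4WeightConfig accepts group_size and layout arguments the same way int8_dynamic_activation_int4_weight did, and that Int8DynamicActInt4WeightCPULayout is importable from torchao.dtypes; the toy model and sizes are arbitrary.

```python
import torch
from torchao.quantization import quantize_, Int8DynamicActivationInt4WeightConfig
from torchao.dtypes import Int8DynamicActInt4WeightCPULayout  # assumed import path

# Arbitrary toy model; DA8W4 targets Linear weights on CPU.
model = torch.nn.Sequential(torch.nn.Linear(256, 256)).eval()

quantize_(
    model,
    Int8DynamicActivationInt4WeightConfig(
        group_size=64,  # assumed to be supported, mirroring the old functional API
        layout=Int8DynamicActInt4WeightCPULayout(),
    ),
)

with torch.no_grad():
    y = model(torch.randn(1, 256))
```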