diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py
index 046fb6ab42..e69d68b27f 100644
--- a/test/quantization/test_quant_primitives.py
+++ b/test/quantization/test_quant_primitives.py
@@ -135,6 +135,8 @@ def _groupwise_affine_quantize_tensor_from_qparams(
     if TORCH_VERSION_AT_LEAST_2_5:
         if (not (check_cpu_version(w.device))) and (not (check_xpu_version(w.device))):
             w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8)
+        if check_xpu_version(w.device):
+            w_int4x8 = (w_int4x8[::, 1::2] << 4 | w_int4x8[::, ::2]).to(torch.uint8)
 
     return w_int4x8
 
@@ -730,6 +732,8 @@ def test_groupwise_affine_dequantize_tensor_from_qparams(self):
                 not (check_xpu_version(input.device))
             ):
                 input_tmp = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8)
+            if check_xpu_version(input.device):
+                input_tmp = (input[::, 1::2] << 4 | input[::, ::2]).to(torch.uint8)
             w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(
                 input_tmp, scales, zeros, n_bit, groupsize, zero_point_domain
             )
diff --git a/torchao/dtypes/uintx/int4_xpu_layout.py b/torchao/dtypes/uintx/int4_xpu_layout.py
index 76f5ecb121..722a37bc32 100644
--- a/torchao/dtypes/uintx/int4_xpu_layout.py
+++ b/torchao/dtypes/uintx/int4_xpu_layout.py
@@ -242,14 +242,15 @@ def from_plain(
     ):
         assert isinstance(_layout, Int4XPULayout)
 
-        from torchao.quantization.utils import convert_weight_to_int4pack_xpu
-
         if TORCH_VERSION_AT_LEAST_2_8:
             assert int_data.dtype == torch.int32, (
                 "torch.ops.aten._convert_weight_to_int4pack_for_cpu expects `int32` dtype"
             )
-            packed_weight = convert_weight_to_int4pack_xpu(
-                int_data, zero_point.dtype != scale.dtype
+            packed_weight = (int_data[::, 1::2] << 4 | int_data[::, ::2]).to(
+                torch.uint8
+            )
+            packed_weight = torch.ops.aten._convert_weight_to_int4pack(
+                packed_weight.contiguous(), 8
             )
         else:
             assert False, "INT4 not supported on XPU until 2.8"
@@ -370,8 +371,8 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
 
     def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         from torchao.quantization.quant_primitives import (
-            ZeroPointDomain,
             quantize_affine,
+            quantize_affine_float_zero_point,
         )
         from torchao.quantization.utils import unpack_tinygemm_scales_and_zeros
 
@@ -394,7 +395,6 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         quant_max = 15
         assert len(block_size) == 2 and block_size[0] == 1
         if self.scale_and_zero is None:
-            zero_point_domain = ZeroPointDomain.INT
             dequantized = torch.ops.aten._weight_int4pack_mm_with_scales_and_zeros(
                 torch.eye(eye_shape, device=device, dtype=original_dtype),
                 self.packed_weight,
@@ -411,10 +411,8 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
                 target_dtype,
                 quant_min,
                 quant_max,
-                zero_point_domain,
             )
         else:
-            zero_point_domain = ZeroPointDomain.FLOAT
             dequantized = torch.ops.aten._weight_int4pack_mm(
                 torch.eye(eye_shape, device=device, dtype=original_dtype),
                 self.packed_weight,
@@ -425,7 +423,7 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
             # TODO: move this to `unpack_tinygemm_scales_and_zeros`?
             scale = scale.reshape(scale.shape[:-1]).contiguous()
             zero = zero.reshape(zero.shape[:-1]).contiguous()
-            int_data = quantize_affine(
+            int_data = quantize_affine_float_zero_point(
                 dequantized,
                 block_size,
                 scale,
@@ -433,7 +431,6 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
                 target_dtype,
                 quant_min,
                 quant_max,
-                zero_point_domain,
             )
         return int_data, scale, zero
 
diff --git a/torchao/prototype/awq/api.py b/torchao/prototype/awq/api.py
index 2d6059c057..5806c29ce6 100644
--- a/torchao/prototype/awq/api.py
+++ b/torchao/prototype/awq/api.py
@@ -5,12 +5,15 @@
 # LICENSE file in the root directory of this source tree.
 import types
 from dataclasses import dataclass
+from typing import Optional
 
 import torch
 
 import torchao
 from torchao.core.config import AOBaseConfig
 from torchao.dtypes import (
+    Int4XPULayout,
+    Layout,
     TensorCoreTiledLayout,
     to_affine_quantized_intx,
 )
@@ -105,12 +108,14 @@ class AWQUIntXConfig(AOBaseConfig):
 
     Args:
         quant_dtype: The data type of the quantized weights. Currently only torch.uint4 is intended to be used but can be used with torch.uint1 -> torch.uint8
+        `layout`: layout type for quantized tensor, default is `TensorCoreTiledLayout(inner_k_tiles=8)`
         group_size: Quantization granularity. Use -1 for channel wise quantization
         weight_quant_fn: The quantization function to be used, which takes in the weight and returns the quantized weight. If None, then affine uint4 quantization is used
         set_inductor_config: if True, adjusts `torchinductor` settings to recommended values.
     """
 
     quant_dtype: torch.dtype = torch.uint4
+    layout: Optional[Layout] = TensorCoreTiledLayout(inner_k_tiles=8)
     group_size: int = 64
     use_hqq: bool = False
     set_inductor_config: bool = True
@@ -142,9 +147,13 @@ def _awq_uintx_transform(
         target_dtype = torch.int32
         eps = 1e-6
         preserve_zero = False
-        zero_point_dtype = torch.bfloat16
-        zero_point_domain = ZeroPointDomain.FLOAT
-        _layout = TensorCoreTiledLayout(inner_k_tiles=8)
+        _layout = config.layout
+        if isinstance(_layout, Int4XPULayout):
+            zero_point_dtype = torch.int8
+            zero_point_domain = ZeroPointDomain.INT
+        else:
+            zero_point_dtype = torch.bfloat16
+            zero_point_domain = ZeroPointDomain.FLOAT
     else:
         target_dtype = torch.uint8
         eps = torch.finfo(torch.float32).eps
diff --git a/torchao/prototype/awq/example.py b/torchao/prototype/awq/example.py
index ba1ba6834c..7ff6092b05 100644
--- a/torchao/prototype/awq/example.py
+++ b/torchao/prototype/awq/example.py
@@ -11,6 +11,7 @@
 from tqdm import tqdm
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
+from torchao.dtypes import Int4XPULayout
 from torchao.prototype.awq import AWQObservedLinear, awq_uintx, insert_awq_observer_
 from torchao.quantization import int4_weight_only, quantize_
 
@@ -71,6 +72,8 @@ def wiki2_eval(
         log_likelihood = model(input_ids, labels=target_ids).loss * trg_len
         if device.startswith("cuda"):
             torch.cuda.synchronize()
+        if device.startswith("xpu"):
+            torch.xpu.synchronize()
         t2 = time.time()
         t.append((t2 - t1))
         lls.append(log_likelihood)
@@ -229,9 +232,14 @@ def wikitext2_ppl(
         use_hqq = "hqq" in quant
         print(f"running {quant_dtype} quantization")
         t0 = time.time()
+        awq_uintx_config = awq_uintx(
+            quant_dtype=quant_dtype, group_size=group_size, use_hqq=use_hqq
+        )
+        if "xpu" in device:
+            awq_uintx_config.layout = Int4XPULayout()
         quantize_(
             model,
-            awq_uintx(quant_dtype=quant_dtype, group_size=group_size, use_hqq=use_hqq),
+            awq_uintx_config,
             is_observed_linear,
         )
         print(f"time for quantization: {time.time() - t0:.02f} seconds")
@@ -242,7 +250,12 @@ def wikitext2_ppl(
         group_size = int(quant.split("-")[1])
         use_hqq = "hqq" in quant
         print(f"running {quant} quantization with group size {group_size}")
-        quantize_(model, int4_weight_only(group_size=group_size, use_hqq=use_hqq))
+        int4_weight_only_config = int4_weight_only(
+            group_size=group_size, use_hqq=use_hqq
+        )
+        if "xpu" in device:
+            int4_weight_only_config.layout = Int4XPULayout()
+        quantize_(model, int4_weight_only_config)
     if compile:
         model = torch.compile(model)
 
diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py
index abaad317eb..be0533510f 100644
--- a/torchao/quantization/subclass.py
+++ b/torchao/quantization/subclass.py
@@ -697,13 +697,6 @@ def to_qtensor_components(
             int_data = aten._convert_weight_to_int4pack_for_cpu(
                 input_int4x8, inner_k_tiles
             )
-        if check_xpu_version(input_float.device):
-            from torchao.quantization.utils import convert_weight_to_int4pack_xpu
-
-            int_data = convert_weight_to_int4pack_xpu(
-                input_int4x8,
-                zero_point_domain_is_int=zero_point_domain == ZeroPointDomain.INT,
-            )
         else:
             int_data = aten._convert_weight_to_int4pack(input_int4x8, inner_k_tiles)
         return int_data, scales_and_zeros, False, groupsize, inner_k_tiles
diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py
index 30b9980878..8f2554849c 100644
--- a/torchao/quantization/utils.py
+++ b/torchao/quantization/utils.py
@@ -127,6 +127,11 @@ def cuda(self):
             val.cuda() if isinstance(val, torch.Tensor) else val for val in self.values
         ]
 
+    def xpu(self):
+        self.values = [
+            val.xpu() if isinstance(val, torch.Tensor) else val for val in self.values
+        ]
+
 
 def guard_dtype_size(tensor_arg, arg_name, dtype=None, size=None):
     if dtype is not None and tensor_arg.dtype != dtype:
@@ -415,25 +420,6 @@ def unpack_tinygemm_scales_and_zeros(scales_and_zeros):
     return torch.split(scales_and_zeros.transpose(-3, -2), 1, -1)
 
 
-def convert_weight_to_int4pack_xpu(weight, zero_point_domain_is_int=False):
-    assert weight.device.type == "xpu"
-
-    if zero_point_domain_is_int:
-        # int_data = weight.to(dtype=torch.uint8)
-        int_data = (weight[::, 1::2] << 4 | weight[::, ::2]).to(torch.uint8)
-        packed_weight = torch.ops.aten._convert_weight_to_int4pack(
-            int_data,
-            8,  # TODO:remove
-        )
-    else:
-        out = weight.to(dtype=torch.uint8)
-        out = (out[::, 1::2] << 4 | out[::, ::2]).to(torch.uint8)
-        packed_weight = out.view(torch.int32)
-
-    # Second, N * K/2 uint8 -> N * K/8 int32
-    return packed_weight
-
-
 def groupwise_affine_quantize_tensor_from_qparams(
     w, scales, zeros, n_bit=4, groupsize=128, zero_point_domain=ZeroPointDomain.FLOAT
 ):
@@ -473,6 +459,8 @@ def groupwise_affine_quantize_tensor_from_qparams(
             not (check_xpu_version(int_data.device))
         ):
             int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
+        if check_xpu_version(int_data.device):
+            int_data = (int_data[::, 1::2] << 4 | int_data[::, ::2]).to(torch.uint8)
     return int_data
 
 
@@ -491,7 +479,6 @@ def groupwise_affine_dequantize_tensor_from_qparams(
         TORCH_VERSION_AT_LEAST_2_5
        and (w_int4x8.dtype == torch.uint8 or w_int4x8.shape[-1] > 1)
         and not (check_cpu_version(w_int4x8.device))
-        and not (check_xpu_version(w_int4x8.device))
     ):
         data = w_int4x8.to(torch.int32)
         high_bits = data >> 4
@@ -501,8 +488,12 @@
             dtype=torch.int32,
             device=w_int4x8.device,
         )
-        w_int32[::, ::2] = high_bits
-        w_int32[::, 1::2] = low_bits
+        if not (check_xpu_version(w_int4x8.device)):
+            w_int32[::, ::2] = high_bits
+            w_int32[::, 1::2] = low_bits
+        else:
+            w_int32[::, ::2] = low_bits
+            w_int32[::, 1::2] = high_bits
     else:
         w_int32 = w_int4x8
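
For reference only (not part of the patch): a minimal sketch of the nibble-packing convention the XPU paths above standardize on, where the odd columns supply the high nibble and the even columns the low nibble (the CPU/CUDA path packs in the reverse order). The tensor name `w` and the shapes are illustrative, not taken from the codebase.

import torch

# 4-bit values stored in int32, as produced by groupwise affine quantization
w = torch.randint(0, 16, (2, 8), dtype=torch.int32)

# XPU-style packing used in the diff: odd columns -> high nibble, even columns -> low nibble
packed = (w[::, 1::2] << 4 | w[::, ::2]).to(torch.uint8)

# Unpacking mirrors the groupwise_affine_dequantize_tensor_from_qparams change:
# high bits return to the odd columns, low bits to the even columns
unpacked = torch.zeros_like(w)
unpacked[::, 1::2] = (packed >> 4).to(torch.int32)
unpacked[::, ::2] = (packed & 0xF).to(torch.int32)
assert torch.equal(unpacked, w)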