Enable AWQ on Intel GPU. #2248

Open · wants to merge 7 commits into base: main
Changes from 2 commits

4 changes: 4 additions & 0 deletions test/quantization/test_quant_primitives.py
@@ -135,6 +135,8 @@ def _groupwise_affine_quantize_tensor_from_qparams(
if TORCH_VERSION_AT_LEAST_2_5:
if (not (check_cpu_version(w.device))) and (not (check_xpu_version(w.device))):
w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8)
if (check_xpu_version(w.device)):
w_int4x8 = (w_int4x8[::, 1::2] << 4 | w_int4x8[::, ::2]).to(torch.uint8)

return w_int4x8

@@ -730,6 +732,8 @@ def test_groupwise_affine_dequantize_tensor_from_qparams(self):
not (check_xpu_version(input.device))
):
input_tmp = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8)
if (check_xpu_version(input.device)):
input_tmp = (input[::, 1::2] << 4 | input[::, ::2]).to(torch.uint8)
w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(
input_tmp, scales, zeros, n_bit, groupsize, zero_point_domain
)
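Note: the two hunks above encode the nibble order each backend expects: the CUDA/CPU packing puts the even column of every int4 pair into the high nibble, while the XPU packing puts the odd column there. A self-contained round trip of both orders, for illustration only (not part of the test suite):

import torch

w = torch.randint(0, 16, (2, 8), dtype=torch.int32)  # toy int4 values stored as int32

# CUDA/CPU order: even columns become the high nibble.
packed_cuda = (w[::, ::2] << 4 | w[::, 1::2]).to(torch.uint8)
# XPU order: odd columns become the high nibble.
packed_xpu = (w[::, 1::2] << 4 | w[::, ::2]).to(torch.uint8)

# Unpacking must mirror the packing order to recover the original columns.
assert torch.equal(packed_cuda >> 4, w[::, ::2].to(torch.uint8))
assert torch.equal(packed_cuda & 0xF, w[::, 1::2].to(torch.uint8))
assert torch.equal(packed_xpu >> 4, w[::, 1::2].to(torch.uint8))
assert torch.equal(packed_xpu & 0xF, w[::, ::2].to(torch.uint8))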
9 changes: 5 additions & 4 deletions torchao/dtypes/uintx/int4_xpu_layout.py
@@ -246,14 +246,15 @@ def from_plain(
):
assert isinstance(_layout, Int4XPULayout)

- from torchao.quantization.utils import convert_weight_to_int4pack_xpu
-
if TORCH_VERSION_AT_LEAST_2_8:
assert int_data.dtype == torch.int32, (
"torch.ops.aten._convert_weight_to_int4pack_for_cpu expects `int32` dtype"
)
- packed_weight = convert_weight_to_int4pack_xpu(
- int_data, zero_point.dtype != scale.dtype
+ packed_weight = (int_data[::, 1::2] << 4 | int_data[::, ::2]).to(
+ torch.uint8
)
+ packed_weight = torch.ops.aten._convert_weight_to_int4pack(
+ packed_weight.contiguous(), 8
+ )
else:
assert False, "INT4 not supported on XPU until 2.8"
16 changes: 12 additions & 4 deletions torchao/prototype/awq/api.py
@@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.
import types
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional, Tuple, Union
Contributor (suggested change):
- from typing import Any, Callable, Dict, Optional, Tuple, Union
+ from typing import Optional

Author: done


import torch

@@ -13,6 +14,7 @@
from torchao.dtypes import (
TensorCoreTiledLayout,
to_affine_quantized_intx,
Int4XPULayout,
)
from torchao.dtypes.uintx.uintx_layout import _DTYPE_TO_BIT_WIDTH, UintxLayout
from torchao.quantization import to_weight_tensor_with_linear_activation_scale_metadata
@@ -114,6 +116,7 @@ class AWQUIntXConfig(AOBaseConfig):
group_size: int = 64
use_hqq: bool = False
set_inductor_config: bool = True
zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.FLOAT
Contributor: Can this be removed if we have layout?

Contributor: Yes, I agree. Following the logic of #2149, preserve_zero and zero_point_domain are too complex to expose in the user-facing UX; it is better to let the layout decide the zero_point_domain information.

Author: Yes, modified. Done.
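For illustration, one way the layout-driven approach described above could look. The helper name and the domain chosen for each layout are assumptions for this sketch, not the code adopted in the PR:

from torchao.dtypes import Int4XPULayout, TensorCoreTiledLayout
from torchao.quantization.quant_primitives import ZeroPointDomain


def _zero_point_domain_for_layout(layout) -> ZeroPointDomain:
    # Hypothetical helper: derive the zero-point domain from the layout so the
    # user-facing config does not need a separate zero_point_domain knob.
    if isinstance(layout, Int4XPULayout):
        return ZeroPointDomain.INT  # assumed default for the XPU int4 kernel
    if isinstance(layout, TensorCoreTiledLayout):
        return ZeroPointDomain.FLOAT  # tinygemm-style kernels use float zero points
    return ZeroPointDomain.INT  # common default elsewhere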



# for bc
@@ -135,16 +138,21 @@ def _awq_uintx_transform(
assert quant_dtype in _DTYPE_TO_BIT_WIDTH or quant_dtype == torch.uint8, (
"Invalid quant_dtype. Please use torch.uint1 .. torch.uint8"
)


device = observed_linear.weight.device
equalization_scale = observed_linear.act_obs.calculate_qparams()
# AQT config
if quant_dtype == torch.uint4:
target_dtype = torch.int32
eps = 1e-6
preserve_zero = False
- zero_point_dtype = torch.bfloat16
- zero_point_domain = ZeroPointDomain.FLOAT
- _layout = TensorCoreTiledLayout(inner_k_tiles=8)
+ zero_point_dtype = torch.bfloat16 if config.zero_point_domain != ZeroPointDomain.INT else torch.int8
+ zero_point_domain = config.zero_point_domain
+
+ if "xpu" in device.type:
+ _layout = Int4XPULayout()
+ else:
+ _layout = TensorCoreTiledLayout(inner_k_tiles=8)
Contributor: Can the layout be explicitly passed in instead of being inferred from the device?

Contributor: I think that should be OK. We should follow Int4WeightOnlyConfig and let the user specify the layout information.

Author: Yes, modified. Done.
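A sketch of what an explicitly passed layout could look like on the config, mirroring Int4WeightOnlyConfig. The class name, field types, and defaults below are illustrative assumptions rather than the final API:

from dataclasses import dataclass, field

import torch

from torchao.dtypes import Int4XPULayout, TensorCoreTiledLayout


@dataclass
class AWQUIntXConfigSketch:
    quant_dtype: torch.dtype = torch.uint4
    group_size: int = 64
    use_hqq: bool = False
    set_inductor_config: bool = True
    # The caller picks the packing layout directly instead of having it
    # inferred from the device inside the transform.
    layout: object = field(default_factory=lambda: TensorCoreTiledLayout(inner_k_tiles=8))


# On an Intel GPU the caller would pass the XPU layout explicitly:
xpu_config = AWQUIntXConfigSketch(layout=Int4XPULayout())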

else:
target_dtype = torch.uint8
eps = torch.finfo(torch.float32).eps
28 changes: 26 additions & 2 deletions torchao/prototype/awq/example.py
@@ -13,6 +13,13 @@

from torchao.prototype.awq import AWQObservedLinear, awq_uintx, insert_awq_observer_
from torchao.quantization import int4_weight_only, quantize_
from torchao.quantization.quant_primitives import (
ZeroPointDomain,
)
from torchao.dtypes import Int4XPULayout


zero_point_domain_dict = {"float":ZeroPointDomain.FLOAT, "int":ZeroPointDomain.INT, "none":ZeroPointDomain.NONE}
Contributor: FYI, we used to use this to distinguish between different types of kernels, but now we keep the integer zero point and preserve_zero as the common default path and split out the other q/dq ops for specific kernels like tinygemm: #2149

I think these are just different ways to implement things, and we don't necessarily need categorizations like zero_point_domain and preserve_zero, since they can complicate the UX.

Author: Yes, modified. Done.
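For reference, a compact sketch of how the CLI choice is wired through this mapping into the AWQ config. It mirrors the quantize_ call in the hunk below, assumes the zero_point_domain argument that awq_uintx accepts at this commit of the PR, and hard-codes example values:

import torch

from torchao.prototype.awq import awq_uintx
from torchao.quantization.quant_primitives import ZeroPointDomain

zero_point_domain_dict = {
    "float": ZeroPointDomain.FLOAT,
    "int": ZeroPointDomain.INT,
    "none": ZeroPointDomain.NONE,
}

# e.g. the user ran the example with --zero_point_domin int on an XPU machine
awq_config = awq_uintx(
    quant_dtype=torch.uint4,
    group_size=64,
    use_hqq=False,
    zero_point_domain=zero_point_domain_dict["int"],
)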



# adapted from: https://github.com/mit-han-lab/llm-awq/blob/main/awq/entry.py#L255
@@ -71,6 +78,8 @@ def wiki2_eval(
log_likelihood = model(input_ids, labels=target_ids).loss * trg_len
if device.startswith("cuda"):
torch.cuda.synchronize()
if device.startswith("xpu"):
torch.xpu.synchronize()
t2 = time.time()
t.append((t2 - t1))
lls.append(log_likelihood)
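Because CUDA and XPU kernels execute asynchronously, the added synchronize call is what makes the wall-clock measurement meaningful. A minimal version of the same timing pattern, with an illustrative helper name (torch.xpu.synchronize requires a PyTorch build with XPU support):

import time

import torch


def timed_forward(fn, device: str):
    t1 = time.time()
    out = fn()
    # Block until queued device work finishes so t2 - t1 reflects real latency.
    if device.startswith("cuda"):
        torch.cuda.synchronize()
    elif device.startswith("xpu"):
        torch.xpu.synchronize()
    t2 = time.time()
    return out, t2 - t1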
@@ -190,6 +199,7 @@ def wikitext2_ppl(
precision: torch.dtype,
sequence_length: int,
compile: bool,
zero_point_domin: str,
model_save_path: str,
):
print(f"Loading model on {device}...")
@@ -231,8 +241,9 @@
t0 = time.time()
quantize_(
model,
- awq_uintx(quant_dtype=quant_dtype, group_size=group_size, use_hqq=use_hqq),
+ awq_uintx(quant_dtype=quant_dtype, group_size=group_size, use_hqq=use_hqq, zero_point_domain=zero_point_domain_dict[zero_point_domin]),
is_observed_linear,
+ torch.device(device),
)
print(f"time for quantization: {time.time() - t0:.02f} seconds")
if model_save_path is not None:
@@ -242,10 +253,15 @@
group_size = int(quant.split("-")[1])
use_hqq = "hqq" in quant
print(f"running {quant} quantization with group size {group_size}")
- quantize_(model, int4_weight_only(group_size=group_size, use_hqq=use_hqq))
+ int4_weight_only_config = int4_weight_only(group_size=group_size, use_hqq=use_hqq)
+ if "xpu" in device:
+ int4_weight_only_config.layout = Int4XPULayout()
+ int4_weight_only_config.zero_point_domain = zero_point_domain_dict[zero_point_domin]
+ quantize_(model, int4_weight_only_config)
if compile:
model = torch.compile(model)

+ print("model:", model)
return benchmark(model, tokenizer, sequence_length, tasks=tasks, device=device)


@@ -299,6 +315,13 @@ def wikitext2_ppl(
action="store_true",
help="Flag to indicate if compilation is required.",
)
parser.add_argument(
"--zero_point_domin",
type=str,
default="float",
choices=['float', 'int', 'none'],
help="Zero point type. Default is 'float'.",
)
parser.add_argument(
"--model_save_path",
type=str,
@@ -320,6 +343,7 @@
args.precision,
args.seq_len,
args.compile,
args.zero_point_domin,
args.model_save_path,
)

7 changes: 0 additions & 7 deletions torchao/quantization/subclass.py
@@ -697,13 +697,6 @@ def to_qtensor_components(
int_data = aten._convert_weight_to_int4pack_for_cpu(
input_int4x8, inner_k_tiles
)
- if check_xpu_version(input_float.device):
- from torchao.quantization.utils import convert_weight_to_int4pack_xpu
-
- int_data = convert_weight_to_int4pack_xpu(
- input_int4x8,
- zero_point_domain_is_int=zero_point_domain == ZeroPointDomain.INT,
- )
else:
int_data = aten._convert_weight_to_int4pack(input_int4x8, inner_k_tiles)
return int_data, scales_and_zeros, False, groupsize, inner_k_tiles
36 changes: 14 additions & 22 deletions torchao/quantization/utils.py
@@ -127,6 +127,11 @@ def cuda(self):
val.cuda() if isinstance(val, torch.Tensor) else val for val in self.values
]

def xpu(self):
self.values = [
val.xpu() if isinstance(val, torch.Tensor) else val for val in self.values
]


def guard_dtype_size(tensor_arg, arg_name, dtype=None, size=None):
if dtype is not None and tensor_arg.dtype != dtype:
@@ -415,25 +420,6 @@ def unpack_tinygemm_scales_and_zeros(scales_and_zeros):
return torch.split(scales_and_zeros.transpose(-3, -2), 1, -1)


- def convert_weight_to_int4pack_xpu(weight, zero_point_domain_is_int=False):
- assert weight.device.type == "xpu"
-
- if zero_point_domain_is_int:
- # int_data = weight.to(dtype=torch.uint8)
- int_data = (weight[::, 1::2] << 4 | weight[::, ::2]).to(torch.uint8)
- packed_weight = torch.ops.aten._convert_weight_to_int4pack(
- int_data,
- 8, # TODO:remove
- )
- else:
- out = weight.to(dtype=torch.uint8)
- out = (out[::, 1::2] << 4 | out[::, ::2]).to(torch.uint8)
- packed_weight = out.view(torch.int32)
-
- # Second, N * K/2 uint8 -> N * K/8 int32
- return packed_weight


def groupwise_affine_quantize_tensor_from_qparams(
w, scales, zeros, n_bit=4, groupsize=128, zero_point_domain=ZeroPointDomain.FLOAT
):
@@ -473,6 +459,8 @@ def groupwise_affine_quantize_tensor_from_qparams(
not (check_xpu_version(int_data.device))
):
int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
if check_xpu_version(int_data.device):
Contributor: We should probably encapsulate these better when we have a better design for layout conversions: #2249 (see the sketch after this hunk).

int_data = (int_data[::, 1::2] << 4 | int_data[::, ::2]).to(torch.uint8)
return int_data
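Responding to the encapsulation comment above: a hypothetical helper (not part of this PR or #2249) that keeps the device-dependent nibble order in one place:

import torch


def pack_int4_rows(int_data: torch.Tensor, device_type: str) -> torch.Tensor:
    # Hypothetical helper: a single place that knows which column of each int4
    # pair goes into the high nibble, instead of repeating the bit-twiddling
    # at every call site.
    if device_type == "xpu":
        return (int_data[::, 1::2] << 4 | int_data[::, ::2]).to(torch.uint8)
    return (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)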


@@ -491,7 +479,6 @@ def groupwise_affine_dequantize_tensor_from_qparams(
TORCH_VERSION_AT_LEAST_2_5
and (w_int4x8.dtype == torch.uint8 or w_int4x8.shape[-1] > 1)
and not (check_cpu_version(w_int4x8.device))
- and not (check_xpu_version(w_int4x8.device))
):
data = w_int4x8.to(torch.int32)
high_bits = data >> 4
@@ -501,8 +488,13 @@
dtype=torch.int32,
device=w_int4x8.device,
)
- w_int32[::, ::2] = high_bits
- w_int32[::, 1::2] = low_bits
+ if (not (check_xpu_version(w_int4x8.device))
+ ):
+ w_int32[::, ::2] = high_bits
+ w_int32[::, 1::2] = low_bits
+ else:
+ w_int32[::, ::2] = low_bits
+ w_int32[::, 1::2] = high_bits
else:
w_int32 = w_int4x8
