Enable AWQ on Intel GPU. #2248
base: main
Changes from 2 commits
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 import types
 from dataclasses import dataclass
+from typing import Any, Callable, Dict, Optional, Tuple, Union

 import torch
@@ -13,6 +14,7 @@
 from torchao.dtypes import (
     TensorCoreTiledLayout,
     to_affine_quantized_intx,
+    Int4XPULayout,
 )
 from torchao.dtypes.uintx.uintx_layout import _DTYPE_TO_BIT_WIDTH, UintxLayout
 from torchao.quantization import to_weight_tensor_with_linear_activation_scale_metadata
@@ -114,6 +116,7 @@ class AWQUIntXConfig(AOBaseConfig):
     group_size: int = 64
     use_hqq: bool = False
     set_inductor_config: bool = True
+    zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.FLOAT


 # for bc
Reviewer: Can this be removed if we have layout?

Author: Yes, I agree with you. Following the logic of #2149, preserve_zero and zero_point_domain are too complex to expose in the user-facing UX; it is better to use the layout to decide the zero_point_domain information.

Author: Yes, modified; done.
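A minimal sketch of the direction the thread agrees on: the config keeps only a layout field, and the transform derives the zero-point handling from it. `_domain_for_layout` is a hypothetical helper, not code from this PR, and the mapping for Int4XPULayout is illustrative rather than fixed by the source.

import torch
from torchao.dtypes import Int4XPULayout, TensorCoreTiledLayout
from torchao.quantization.quant_primitives import ZeroPointDomain

def _domain_for_layout(layout):
    # Hypothetical helper: each layout already implies the zero-point
    # convention its kernel expects, so no separate config knob is needed.
    if isinstance(layout, TensorCoreTiledLayout):
        return ZeroPointDomain.FLOAT  # tinygemm kernels use float zero points
    if isinstance(layout, Int4XPULayout):
        return ZeroPointDomain.INT    # illustrative choice for the XPU path
    return ZeroPointDomain.INT        # common default path per #2149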
@@ -135,16 +138,21 @@ def _awq_uintx_transform(
     assert quant_dtype in _DTYPE_TO_BIT_WIDTH or quant_dtype == torch.uint8, (
         "Invalid quant_dtype. Please use torch.uint1 .. torch.uint8"
     )

     device = observed_linear.weight.device
     equalization_scale = observed_linear.act_obs.calculate_qparams()
     # AQT config
     if quant_dtype == torch.uint4:
         target_dtype = torch.int32
         eps = 1e-6
         preserve_zero = False
-        zero_point_dtype = torch.bfloat16
-        zero_point_domain = ZeroPointDomain.FLOAT
-        _layout = TensorCoreTiledLayout(inner_k_tiles=8)
+        zero_point_dtype = torch.bfloat16 if config.zero_point_domain != ZeroPointDomain.INT else torch.int8
+        zero_point_domain = config.zero_point_domain
+
+        if "xpu" in device.type:
+            _layout = Int4XPULayout()
+        else:
+            _layout = TensorCoreTiledLayout(inner_k_tiles=8)
     else:
         target_dtype = torch.uint8
         eps = torch.finfo(torch.float32).eps
Reviewer: Can the layout be explicitly passed in instead of inferred from the device?

Author: I think that should be OK. We should follow Int4WeightOnlyConfig and let the user specify the layout information.

Author: Yes, modified; done.
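A hedged sketch of what explicit layout selection could look like at the call site, mirroring the Int4WeightOnlyConfig pattern the author mentions. The `layout=` keyword on awq_uintx is an assumption about the follow-up change, not something in this revision of the PR.

import torch
from torchao.dtypes import Int4XPULayout, TensorCoreTiledLayout
from torchao.prototype.awq import awq_uintx

# The caller chooses the layout up front instead of the transform
# sniffing the device type inside _awq_uintx_transform.
layout = Int4XPULayout() if torch.xpu.is_available() else TensorCoreTiledLayout(inner_k_tiles=8)
config = awq_uintx(quant_dtype=torch.uint4, group_size=64, layout=layout)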
@@ -13,6 +13,13 @@
 from torchao.prototype.awq import AWQObservedLinear, awq_uintx, insert_awq_observer_
 from torchao.quantization import int4_weight_only, quantize_
+from torchao.quantization.quant_primitives import (
+    ZeroPointDomain,
+)
+from torchao.dtypes import Int4XPULayout
+
+
+zero_point_domain_dict = {"float": ZeroPointDomain.FLOAT, "int": ZeroPointDomain.INT, "none": ZeroPointDomain.NONE}


 # adapted from: https://github.com/mit-han-lab/llm-awq/blob/main/awq/entry.py#L255
Reviewer: FYI, we used to use this to distinguish between different types of kernels, but now we keep the default path (integer zero point, preserve_zero) for the common case and split out the other q/dq ops for specific kernels like tinygemm: #2149. It's just different ways to implement things, and we don't necessarily need categorizations like zero_point_domain and preserve_zero, since they might complicate the UX.

Author: Yes, modified; done.
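For context, the two domains correspond to different dequantization formulas. A rough sketch of the arithmetic, simplified from torchao's quant primitives; treat the exact mid-point handling as an assumption rather than a spec.

import torch

q = torch.tensor([3, 12], dtype=torch.uint8)  # 4-bit codes in [0, 15]
scale = torch.tensor(0.1)

# ZeroPointDomain.INT: the zero point lives in integer space.
zp_int = torch.tensor(8)
w_int = (q.to(torch.int32) - zp_int) * scale

# ZeroPointDomain.FLOAT (tinygemm-style): codes are shifted by the range
# mid-point and the zero point is added back in floating point.
zp_float = torch.tensor(0.05)
mid_point = 2 ** (4 - 1)  # 8 for 4-bit
w_float = (q.to(torch.float32) - mid_point) * scale + zp_float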
@@ -71,6 +78,8 @@ def wiki2_eval(
             log_likelihood = model(input_ids, labels=target_ids).loss * trg_len
             if device.startswith("cuda"):
                 torch.cuda.synchronize()
+            if device.startswith("xpu"):
+                torch.xpu.synchronize()
             t2 = time.time()
             t.append((t2 - t1))
             lls.append(log_likelihood)
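The cuda/xpu branches could also be folded into one small helper; a sketch, not part of the PR:

import torch

def sync(device: str) -> None:
    # Block until pending kernels on the device finish, so the timing
    # loop above measures real execution time rather than launch time.
    if device.startswith("cuda"):
        torch.cuda.synchronize()
    elif device.startswith("xpu"):
        torch.xpu.synchronize()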
@@ -190,6 +199,7 @@ def wikitext2_ppl(
     precision: torch.dtype,
     sequence_length: int,
     compile: bool,
+    zero_point_domin: str,
     model_save_path: str,
 ):
     print(f"Loading model on {device}...")
@@ -231,8 +241,9 @@ def wikitext2_ppl(
         t0 = time.time()
         quantize_(
             model,
-            awq_uintx(quant_dtype=quant_dtype, group_size=group_size, use_hqq=use_hqq),
+            awq_uintx(quant_dtype=quant_dtype, group_size=group_size, use_hqq=use_hqq, zero_point_domain=zero_point_domain_dict[zero_point_domin]),
             is_observed_linear,
+            torch.device(device),
         )
         print(f"time for quantization: {time.time() - t0:.02f} seconds")
         if model_save_path is not None:
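A condensed sketch of the resulting call path on an XPU device, assuming a `model` whose AWQ observers were already inserted and calibrated (the model itself and the chosen domain are illustrative):

import torch
from torchao.quantization import quantize_
from torchao.quantization.quant_primitives import ZeroPointDomain
from torchao.prototype.awq import AWQObservedLinear, awq_uintx

is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)

quantize_(
    model,  # assumed: calibrated model containing AWQObservedLinear modules
    awq_uintx(
        quant_dtype=torch.uint4,
        group_size=64,
        zero_point_domain=ZeroPointDomain.INT,  # via the new config field
    ),
    is_observed_linear,
    torch.device("xpu"),
)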
@@ -242,10 +253,15 @@ def wikitext2_ppl(
         group_size = int(quant.split("-")[1])
         use_hqq = "hqq" in quant
         print(f"running {quant} quantization with group size {group_size}")
-        quantize_(model, int4_weight_only(group_size=group_size, use_hqq=use_hqq))
+        int4_weight_only_config = int4_weight_only(group_size=group_size, use_hqq=use_hqq)
+        if "xpu" in device:
+            int4_weight_only_config.layout = Int4XPULayout()
+            int4_weight_only_config.layout.zero_point_domin = zero_point_domain_dict[zero_point_domin]
+        quantize_(model, int4_weight_only_config)
     if compile:
         model = torch.compile(model)

+    print("model:", model)
     return benchmark(model, tokenizer, sequence_length, tasks=tasks, device=device)
@@ -299,6 +315,13 @@ def wikitext2_ppl(
         action="store_true",
         help="Flag to indicate if compilation is required.",
     )
+    parser.add_argument(
+        "--zero_point_domin",
+        type=str,
+        default="float",
+        choices=['float', 'int', 'none'],
+        help="Zero point type. Default is 'float'.",
+    )
     parser.add_argument(
         "--model_save_path",
         type=str,
@@ -320,6 +343,7 @@ def wikitext2_ppl(
         args.precision,
         args.seq_len,
         args.compile,
+        args.zero_point_domin,
         args.model_save_path,
     )
@@ -127,6 +127,11 @@ def cuda(self):
         self.values = [
             val.cuda() if isinstance(val, torch.Tensor) else val for val in self.values
         ]

+    def xpu(self):
+        self.values = [
+            val.xpu() if isinstance(val, torch.Tensor) else val for val in self.values
+        ]
+

 def guard_dtype_size(tensor_arg, arg_name, dtype=None, size=None):
     if dtype is not None and tensor_arg.dtype != dtype:
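The new method mirrors cuda() on the same container. A hedged generalization that would cover both backends (and future ones) might look like the sketch below; the class name is a stand-in for the container the diff extends, with `values` matching the attribute shown.

import torch

class _Values:
    # Hypothetical stand-in for the container class in the diff.
    def __init__(self, values):
        self.values = list(values)

    def to_device(self, device):
        # One method instead of a cuda()/xpu() pair: tensors move,
        # non-tensor entries pass through unchanged.
        self.values = [
            v.to(device) if isinstance(v, torch.Tensor) else v
            for v in self.values
        ]
        return self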
@@ -415,25 +420,6 @@ def unpack_tinygemm_scales_and_zeros(scales_and_zeros):
     return torch.split(scales_and_zeros.transpose(-3, -2), 1, -1)


-def convert_weight_to_int4pack_xpu(weight, zero_point_domain_is_int=False):
-    assert weight.device.type == "xpu"
-
-    if zero_point_domain_is_int:
-        # int_data = weight.to(dtype=torch.uint8)
-        int_data = (weight[::, 1::2] << 4 | weight[::, ::2]).to(torch.uint8)
-        packed_weight = torch.ops.aten._convert_weight_to_int4pack(
-            int_data,
-            8,  # TODO: remove
-        )
-    else:
-        out = weight.to(dtype=torch.uint8)
-        out = (out[::, 1::2] << 4 | out[::, ::2]).to(torch.uint8)
-        packed_weight = out.view(torch.int32)
-
-    # Second, N * K/2 uint8 -> N * K/8 int32
-    return packed_weight


 def groupwise_affine_quantize_tensor_from_qparams(
     w, scales, zeros, n_bit=4, groupsize=128, zero_point_domain=ZeroPointDomain.FLOAT
 ):
@@ -473,6 +459,8 @@ def groupwise_affine_quantize_tensor_from_qparams(
         not (check_xpu_version(int_data.device))
     ):
         int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
+    if check_xpu_version(int_data.device):
+        int_data = (int_data[::, 1::2] << 4 | int_data[::, ::2]).to(torch.uint8)
     return int_data
Reviewer: We should probably encapsulate these better when we have a better design for layout conversions: #2249
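The two branches differ only in which column lands in the high nibble. A small self-contained sketch of the round trip, using plain tensors with no torchao dependency:

import torch

w = torch.randint(0, 16, (2, 8), dtype=torch.uint8)  # 4-bit codes

# CPU/CUDA-style pack: even columns go in the high nibble.
packed_cuda = w[:, ::2] << 4 | w[:, 1::2]
# XPU-style pack: odd columns go in the high nibble.
packed_xpu = w[:, 1::2] << 4 | w[:, ::2]

# Unpacking must mirror the pack order to round-trip correctly.
lo, hi = packed_xpu & 0xF, packed_xpu >> 4
restored = torch.stack([lo, hi], dim=-1).reshape(w.shape)
assert torch.equal(restored, w)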
@@ -491,7 +479,6 @@ def groupwise_affine_dequantize_tensor_from_qparams(
         TORCH_VERSION_AT_LEAST_2_5
         and (w_int4x8.dtype == torch.uint8 or w_int4x8.shape[-1] > 1)
         and not (check_cpu_version(w_int4x8.device))
-        and not (check_xpu_version(w_int4x8.device))
     ):
         data = w_int4x8.to(torch.int32)
         high_bits = data >> 4
@@ -501,8 +488,13 @@ def groupwise_affine_dequantize_tensor_from_qparams(
             dtype=torch.int32,
             device=w_int4x8.device,
         )
-        w_int32[::, ::2] = high_bits
-        w_int32[::, 1::2] = low_bits
+        if not (check_xpu_version(w_int4x8.device)):
+            w_int32[::, ::2] = high_bits
+            w_int32[::, 1::2] = low_bits
+        else:
+            w_int32[::, ::2] = low_bits
+            w_int32[::, 1::2] = high_bits
     else:
         w_int32 = w_int4x8
Author: done