Enable Int4WeightOnlyGPTQQuantizer on Intel GPU. #2200

Open · wants to merge 6 commits into base: main

Changes from 1 commit
190 changes: 154 additions & 36 deletions torchao/quantization/GPTQ.py
@@ -27,6 +27,7 @@
from .quant_primitives import (
MappingType,
dequantize_affine,
ZeroPointDomain,
)
from .unified import Quantizer
from .utils import (
@@ -38,6 +39,7 @@
groupwise_affine_quantize_tensor,
groupwise_affine_quantize_tensor_from_qparams,
pack_tinygemm_scales_and_zeros,
align_tinygemm_scales_and_zeros,
per_token_dynamic_quant,
)

@@ -75,18 +77,19 @@ def __init__(
percdamp=0.01,
groupsize=128,
):
self.device = self.get_device(model, inputs)
self.id_to_name = {
id(value): name for name, value in dict(model.named_parameters()).items()
}

# trace model for one input
one_input = [multi.values[0].cpu() for multi in inputs] # pyre-ignore[16]
one_input = [multi.values[0] for multi in inputs] # pyre-ignore[16]
# needed for GPTQ on the torchao llama model
import torchao

torchao._models.llama.model.use_index_put_for_kv_cache = True
exported_model = torch._dynamo.export(
model.cpu(), aten_graph=True, pre_dispatch=True, tracing_mode="fake"
model, aten_graph=True, pre_dispatch=True, tracing_mode="fake"
)(*one_input)
super().__init__(exported_model.graph_module)

@@ -100,6 +103,19 @@ def __init__(
self.inputs = inputs
self.gptq_done = False
self.debug = False


def get_device(self, model, inputs: _MultiInput):
for name, param in model.named_parameters():
if isinstance(param, torch.Tensor):
return param.device

for multi in inputs:
if isinstance(multi.values[0], torch.Tensor):
return multi.values[0].device

return torch.device("cpu")


def configure_quantization_mode(
self,
@@ -163,16 +179,16 @@ def get_quantized_state_dict(self):
return quantized_state_dict

def call_function(self, target, args, kwargs, already_quantized=False): # noqa: C901
def tensors_to_cuda(args):
def tensors_to_device(args):
new_args = []
for x in args:
new_args.append(x.cuda() if isinstance(x, torch.Tensor) else x)
new_args.append(x.to(self.device) if isinstance(x, torch.Tensor) else x)
return new_args

# flatten args and kwargs together
flat_args, spec = tree_flatten((args, kwargs))
# move all single tensors to the target device, will move _MultiInputs to the device one at a time
flat_args = tensors_to_cuda(flat_args)
flat_args = tensors_to_device(flat_args)

has_multi_input = _MultiInput in [type(x) for x in flat_args]
if has_multi_input:
@@ -212,7 +228,7 @@ def tensors_to_cuda(args):
total_batches = 0

for inp in transposed_args:
inp = tensors_to_cuda(inp)
inp = tensors_to_device(inp)
cur_args, cur_kwargs = tree_unflatten(inp, spec)

if quantize_linear: # calculate H instead of output (will run the linear eventually with updated weight)
@@ -283,7 +299,7 @@ def SQNR(x, y):
"SQNR for QDQ (this should be inf)", SQNR(DQ, DQ_after)
) # matches
print(
"SQNR for weight (can be low)", SQNR(W, DQ.cuda())
"SQNR for weight (can be low)", SQNR(W, DQ.to(self.device))
) # fine to not match
print(
"SQNR for output with GPTQ (hopefully 35+)",
@@ -385,7 +401,12 @@ def faster_quant(self, H, W):

W[:, i2:] -= Err1.to(Hinv.dtype).matmul(Hinv[i1:i2, i2:])

torch.cuda.synchronize()
if 'cuda' in self.device.type:
torch.cuda.synchronize()
elif 'xpu' in self.device.type:
torch.xpu.synchronize()
else:
pass

if all_qparams == []:
all_qparams.append(cur_qparams)
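
Note: the per-device synchronize dispatch added above could equivalently live in a small helper. A minimal sketch, not part of this PR, assuming a PyTorch build that exposes torch.xpu:

import torch

def synchronize(device: torch.device) -> None:
    # Block until all kernels queued on `device` have finished.
    if device.type == "cuda":
        torch.cuda.synchronize(device)
    elif device.type == "xpu":
        torch.xpu.synchronize(device)
    # CPU and other backends without an async queue need no explicit sync.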
@@ -561,6 +582,30 @@ def linear_forward_int4(
return c


def linear_forward_int4_zero_domain(
Review comment (Contributor):
Suggested change:
-def linear_forward_int4_zero_domain(
+def linear_forward_int4_zero_point_domain_int(

Reply (Contributor Author): Modified, done.

x: torch.Tensor,
weight_int4pack: torch.Tensor,
scales: torch.Tensor,
zeros: torch.Tensor,
out_features: int,
groupsize: int,
precision: torch.dtype = torch.bfloat16,
scales_precision: torch.dtype = torch.bfloat16,
):
origin_x_size = x.size()
x = x.reshape(-1, origin_x_size[-1])
c = torch.ops.aten._weight_int4pack_mm_with_scales_and_zeros(
x.contiguous().to(precision),
weight_int4pack,
groupsize,
scales.to(scales_precision),
zeros.to(torch.int8),
).to(dtype=x.dtype)
new_shape = origin_x_size[:-1] + (out_features,)
c = c.reshape(new_shape)
return c
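
For context, the new function above targets a kernel that expects integer zero points, whereas the existing linear_forward_int4 path packs floating-point zeros with the scales. A plain-PyTorch reference sketch of the two dequantization conventions (values are made up; the formulas reflect torchao's ZeroPointDomain semantics as I understand them, not code from this PR):

import torch

n_bit = 4
mid_point = 2 ** (n_bit - 1)               # 8 for int4
q = torch.randint(0, 2 ** n_bit, (1, 64))  # one quantized group, values in [0, 15]
scale = torch.full((1, 1), 0.02)

# ZeroPointDomain.FLOAT (tinygemm scales_and_zeros path):
# a floating-point offset is added after scaling around the mid-point.
zero_float = torch.full((1, 1), 0.16)
dq_float = (q - mid_point) * scale + zero_float

# ZeroPointDomain.INT (separate scales/zeros path used here):
# an integer zero point is subtracted before scaling.
zero_int = torch.full((1, 1), 8, dtype=torch.int8)
dq_int = (q - zero_int) * scale

print(dq_float.shape, dq_int.shape)  # torch.Size([1, 64]) torch.Size([1, 64])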


class WeightOnlyInt4Linear(torch.nn.Module):
__constants__ = ["in_features", "out_features"]
in_features: int
@@ -579,6 +624,7 @@ def __init__(
inner_k_tiles: int = 8,
precision: torch.dtype = torch.bfloat16,
scales_precision: torch.dtype = torch.bfloat16,
zero_point_domain: ZeroPointDomain = ZeroPointDomain.FLOAT,
) -> None:
super().__init__()
self.padding = not _check_linear_int4_k(in_features, groupsize, inner_k_tiles)
@@ -594,6 +640,7 @@ def __init__(
self.inner_k_tiles = inner_k_tiles
self.precision = precision
self.scales_precision = scales_precision
self.zero_point_domain = zero_point_domain

if dtype is not None:
raise ValueError("Please specify 'precision' instead of 'dtype'")
@@ -614,6 +661,18 @@ def __init__(
device=device,
),
)
elif is_device(device.type, "xpu"):
self.register_buffer(
"weight",
torch.zeros(
(
out_features,
in_features // 8,
),
dtype=torch.int32,
device=device,
),
)
else:
self.register_buffer(
"weight",
@@ -629,27 +688,59 @@ def __init__(
),
)
self.dtype = dtype
self.register_buffer(
"scales_and_zeros",
torch.zeros(
(in_features // groupsize, out_features, 2),
dtype=self.scales_precision,
device=device,
),
)
if self.zero_point_domain == ZeroPointDomain.INT:
self.register_buffer(
"scales",
torch.zeros(
(in_features // groupsize, out_features),
dtype=self.scales_precision,
device=device,
),
)

self.register_buffer(
"zeros",
torch.zeros(
(in_features // groupsize, out_features),
dtype=torch.int8,
device=device,
),
)
else:
self.register_buffer(
"scales_and_zeros",
torch.zeros(
(in_features // groupsize, out_features, 2),
dtype=self.scales_precision,
device=device,
),
)

def forward(self, input: torch.Tensor) -> torch.Tensor:
if self.padding:
input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
return linear_forward_int4(
input,
self.weight,
self.scales_and_zeros,
self.out_features,
self.groupsize,
self.precision,
self.scales_precision,
)

if self.zero_point_domain != ZeroPointDomain.INT:
return linear_forward_int4(
input,
self.weight,
self.scales_and_zeros,
self.out_features,
self.groupsize,
self.precision,
self.scales_precision,
)
else:
return linear_forward_int4_zero_domain(
input,
self.weight,
self.scales,
self.zeros,
self.out_features,
self.groupsize,
self.precision,
self.scales_precision,
)
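
With this change the module's checkpoint layout depends on the zero-point domain: the FLOAT path keeps the packed scales_and_zeros buffer, while the INT path stores scales and int8 zeros separately. A small construction sketch, assuming the constructor as shown in this diff:

import torch
from torchao.quantization.GPTQ import WeightOnlyInt4Linear
from torchao.quantization.quant_primitives import ZeroPointDomain

lin = WeightOnlyInt4Linear(
    in_features=4096,
    out_features=4096,
    device=torch.device("cpu"),
    groupsize=128,
    inner_k_tiles=8,
    zero_point_domain=ZeroPointDomain.INT,
)
# INT-domain modules expose separate buffers instead of scales_and_zeros.
print(sorted(name for name, _ in lin.named_buffers()))  # ['scales', 'weight', 'zeros']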


def _replace_linear_int4(
@@ -662,6 +753,7 @@ def _replace_linear_int4(
scales_precision: torch.dtype = torch.bfloat16,
linear_class: Type[torch.nn.Module] = WeightOnlyInt4Linear,
copy_weights: bool = False,
zero_point_domain: ZeroPointDomain = ZeroPointDomain.FLOAT,
):
for name, child in module.named_children():
# TODO: support linear bias
@@ -683,6 +775,7 @@
inner_k_tiles=inner_k_tiles,
precision=precision,
scales_precision=scales_precision,
zero_point_domain=zero_point_domain,
)
# TODO: merge with 8da4w?
# In distributed training, the model may be instantiated
@@ -702,11 +795,17 @@ def _replace_linear_int4(
scales_precision,
linear_class,
copy_weights,
zero_point_domain=zero_point_domain,
)


def replace_linear_int4(
module, groupsize, inner_k_tiles, padding_allowed, skip_layer_func=None
module,
groupsize,
inner_k_tiles,
padding_allowed,
skip_layer_func=None,
zero_point_domain: ZeroPointDomain = ZeroPointDomain.FLOAT,
):
_replace_linear_int4(
module,
Expand All @@ -715,6 +814,7 @@ def replace_linear_int4(
padding_allowed,
skip_layer_func,
linear_class=WeightOnlyInt4Linear,
zero_point_domain=zero_point_domain,
)


@@ -830,22 +930,24 @@ def __init__(
groupsize=64,
inner_k_tiles=8,
padding_allowed=True,
zero_point_domain: ZeroPointDomain = ZeroPointDomain.FLOAT,
device: torch.device = torch.device("cuda"),
):
self.blocksize = blocksize
self.percdamp = percdamp
self.groupsize = groupsize
self.inner_k_tiles = inner_k_tiles
self.padding_allowed = padding_allowed
self.zero_point_domain = zero_point_domain
self.device = device
self.act_fake_quant_func = None
n_bit = 4
self.get_qparams_func = lambda w: get_groupwise_affine_qparams(
w, n_bit, groupsize
w, n_bit, groupsize, zero_point_domain=self.zero_point_domain,
)
self.quantize_func = (
lambda w, qparams: groupwise_affine_quantize_tensor_from_qparams(
w, qparams[0], qparams[1], n_bit, groupsize
w, qparams[0], qparams[1], n_bit, groupsize, zero_point_domain=self.zero_point_domain,
)
)
self.dequantize_func = (
@@ -855,6 +957,7 @@ def __init__(
qparams[1],
n_bit,
groupsize,
zero_point_domain=self.zero_point_domain,
)
)
self.combine_qparams_list_func = lambda qparams_list: [
@@ -886,14 +989,28 @@ def make_names_and_values_dict_func(q, qparams):
F.pad(q, pad=(0, delta_k)), inner_k_tiles
)
scales = qparams[0].to(torch.bfloat16).to(self.device)
zeros = qparams[1].to(torch.bfloat16).to(self.device)
scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)
# how many new groups we need for padded weight
delta_groups = new_k // groupsize - scales_and_zeros.shape[0]
final_s_and_z = F.pad(
scales_and_zeros, pad=(0, 0, 0, 0, 0, delta_groups), value=1
)
return {"weight": final_q, "scales_and_zeros": final_s_and_z}
if zero_point_domain == ZeroPointDomain.FLOAT:
zeros = qparams[1].to(torch.bfloat16).to(self.device)
scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)
# how many new groups we need for padded weight
delta_groups = new_k // groupsize - scales_and_zeros.shape[0]
final_s_and_z = F.pad(
scales_and_zeros, pad=(0, 0, 0, 0, 0, delta_groups), value=1
)
return {"weight": final_q, "scales_and_zeros": final_s_and_z}
if zero_point_domain == ZeroPointDomain.INT:
zeros = qparams[1].to(torch.int8).to(self.device)
scales, zeros = align_tinygemm_scales_and_zeros(scales, zeros)
# how many new groups we need for padded weight
delta_groups = new_k // groupsize - scales.shape[0]
final_s = F.pad(
scales, pad=(0, 0, 0, delta_groups), value=1
)
final_z = F.pad(
zeros, pad=(0, 0, 0, delta_groups), value=1
)
return {"weight": final_q, "scales": final_s, "zeros": final_z}


self.make_names_and_values_dict_func = make_names_and_values_dict_func
super().__init__()
Expand All @@ -905,6 +1022,7 @@ def _convert_for_runtime(self, model):
self.inner_k_tiles,
self.padding_allowed,
skip_layer_func=self.skip_layer_func,
zero_point_domain=self.zero_point_domain,
)
return model
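
Putting it together, the quantizer can now target an Intel GPU directly. A rough usage sketch, not part of this diff, assuming an XPU-enabled PyTorch build and the quantize(model, inputs) entry point this file already provides; collecting the _MultiInput calibration inputs is elided:

import torch
from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer
from torchao.quantization.quant_primitives import ZeroPointDomain

model = ...                # float model to quantize
calibration_inputs = ...   # _MultiInput of recorded example inputs

quantizer = Int4WeightOnlyGPTQQuantizer(
    blocksize=128,
    percdamp=0.01,
    groupsize=64,
    inner_k_tiles=8,
    padding_allowed=True,
    zero_point_domain=ZeroPointDomain.INT,  # integer zero points for the XPU int4 kernel
    device=torch.device("xpu"),
)
model = quantizer.quantize(model, calibration_inputs)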
