We'd like to add GPTQ support to the sparse-marlin kernels so we can get (accurate) perplexity numbers from the neuralmagic 2:4 sparse checkpoint.
@HDCharles wrote GPTQ to be easy to generalize to other techniques beyond the int4 quantization mode he started with. So you have a base class, GPTQQuantizer, and a quantization-technique-specific subclass which needs to have the specifics defined for your use case, i.e. Int4WeightOnlyGPTQQuantizer
https://github.com/pytorch/ao/blob/main/torchao/quantization/GPTQ_MT.py#L561
which inherits from the base class. The base class basically has all the GPTQ logic, whereas the specific class has the individual functions for the use case. These are
get_qparams_func, quantize_func, dequantize_func, combine_qparams_list_func, skip_layer_func, and make_qtensor. These have to be specified by you for your particular use case.
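In other words, a new technique plugs in by defining those six functions on a subclass. A rough skeleton (the attribute-style wiring and the SparseMarlinGPTQQuantizer name are assumptions for illustration, not existing torchao code):

```python
from torchao.quantization.GPTQ_MT import GPTQQuantizer  # module path taken from the link above

class SparseMarlinGPTQQuantizer(GPTQQuantizer):  # hypothetical name for the 2:4 sparse variant
    def __init__(self):
        super().__init__()
        # the six use-case-specific hooks the shared GPTQ logic calls into
        self.get_qparams_func = lambda w: ...                 # qparams for a group of columns
        self.quantize_func = lambda w, qparams: ...           # full precision -> quantized values
        self.dequantize_func = lambda wq, qparams: ...        # quantized -> full precision values
        self.combine_qparams_list_func = lambda qp_list: ...  # merge the per-group qparams
        self.skip_layer_func = lambda w: False                # which layers GPTQ should leave alone
        self.make_qtensor = lambda wdq, qparams: ...          # pack the final quantized tensor
```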
The way GPTQ works is that it looks at a group of columns and uses get_qparams_func to calculate the qparams for them. Then it goes column by column and applies quantize_func and dequantize_func to each column; the result has to be a full precision value. It then does the magical part of GPTQ: 1) it updates the rest of the columns of the weight tensor to maintain the hessian and then moves to the next column; 2) the next column gets quantized and dequantized, taking into account the altered values, to maintain the hessian... etc.
After it goes through all the columns it gets to a new group and finds new qparams (each set of qparams gets appended to a list). At the end, it applies make_qtensor using the final dequantized weight and ALL the qparams, after they've been processed by combine_qparams_list_func.
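To make that flow concrete, here's a heavily simplified sketch of the per-group loop. This is only an illustration, not the actual GPTQ_MT.py code; Hinv is assumed to be a factor of the inverse Hessian, and the *_func arguments are the hooks listed above:

```python
import torch

def gptq_group_sketch(W, Hinv, get_qparams_func, quantize_func, dequantize_func,
                      combine_qparams_list_func, make_qtensor, groupsize=128):
    """Simplified illustration of the GPTQ flow described above (not the real implementation)."""
    DQ = torch.zeros_like(W)  # dequantized weight, filled in column by column
    qparams_list = []
    for g in range(0, W.shape[1], groupsize):
        # new group -> new qparams, appended to the running list
        qparams = get_qparams_func(W[:, g:g + groupsize])
        qparams_list.append(qparams)
        for i in range(g, min(g + groupsize, W.shape[1])):
            col = W[:, i]
            # quant + dequant of one column; the result must be full precision
            dq = dequantize_func(quantize_func(col, qparams), qparams)
            DQ[:, i] = dq
            # the "magical part": push the quantization error of this column onto the
            # remaining columns so the hessian-weighted reconstruction error stays small
            err = (col - dq) / Hinv[i, i]
            W[:, i + 1:] -= err.unsqueeze(1) * Hinv[i, i + 1:].unsqueeze(0)
    all_qparams = combine_qparams_list_func(qparams_list)
    return make_qtensor(DQ, all_qparams)
```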
So for int4 quantization you can see what is being used for each of these steps. Sparsity is a little different: for quantization, when you pick qparams for a group, those qparams get applied uniformly to all values; but for sparsity, where you have a sparse mask, you have to know which column you're on in order to know how to sparsify your tensor. To get around this, <@1213148470664495114> and I were theorizing something like the following:
The three functions that mostly define the process are:
```python
def get_qparam_func(x):
    sparse_mask = func_to_get_sparse_mask(x)
    # remove the pruned values so they don't affect the calculation of the quantization parameters
    x_pruned = apply_sparse_mask(x, sparse_mask)
    scale, zero_point = call_the_same_get_qparam_func_as_int4(x_pruned)
    # we need to keep track of which column we're on, so we smuggle a counter in as a
    # "qparam" that gets continuously updated, rather than editing the algorithm itself
    count = [0]  # mutable so dequant_func can bump it in place
    return (scale, zero_point, sparse_mask, count)

def quant_func(x, qparams):
    scale, zp, sparse_mask, count = qparams
    xq = call_the_same_quant_func_as_int4(x, (scale, zp))
    # technically we want to set the pruned values to the q value that maps to 0 rather than
    # to 0 itself, but since the actual thing we care about is quant+dequant numerics, it's
    # fine to set the sparsified values to anything since we zero them out in the dequant func
    xq_s = xq * sparse_mask[:, count[0]]
    return xq_s

def dequant_func(xq_s, qparams):
    scale, zp, sparse_mask, count = qparams
    x_dq = use_current_dequant_func(xq_s, (scale, zp))
    x_dq_s = x_dq * sparse_mask[:, count[0]]  # sparsify the final result so it is correct
    count[0] += 1  # move on to the next column
    return x_dq_s
```
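For completeness, here's one hypothetical way the assumed func_to_get_sparse_mask helper could be written for the 2:4 pattern (keep the two largest-magnitude values in every contiguous group of four along the input dimension). This is an assumption for illustration, not existing torchao API:

```python
import torch

def func_to_get_sparse_mask(x):
    # hypothetical 2:4 mask: keep the 2 largest-magnitude values in every
    # contiguous group of 4 along the input (column) dimension
    out_features, in_features = x.shape
    groups = x.abs().reshape(out_features, in_features // 4, 4)
    keep = groups.topk(2, dim=-1).indices
    mask = torch.zeros_like(groups, dtype=torch.bool)
    mask.scatter_(-1, keep, True)
    return mask.reshape(out_features, in_features)
```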
To test this, you can check out my branch, https://github.com/pytorch/ao/tree/jcaip/sparse-benchmarking-updates, which updates the benchmarking script for sparsity. cc @Diogo-V