@@ -45,6 +45,16 @@ $if MODE == "per_token":
45
45
int quant_min;
46
46
int quant_max;
47
47
};
48
+ $if MODE == "per_channel":
49
+ ${layout_declare_tensor(B, "r", "t_scale", "float ", "buffer ")}
50
+ ${layout_declare_tensor(B, "r", "t_zero_point", "int ", "buffer ")}
51
+
52
+ layout (push_constant) uniform restrict Block {
53
+ int axis;
54
+ int num_channels;
55
+ int quant_min;
56
+ int quant_max;
57
+ };
48
58
49
59
${layout_declare_ubo(B, "ivec3 ", "t_in_limits")}
50
60
${layout_declare_ubo(B, "ivec3 ", "t_out_limits")}
@@ -147,7 +157,7 @@ void dequantize_per_tensor() {
147
157
write_texel(t_out, pos, outtex);
148
158
}
149
159
150
- #else
160
+ #elif defined(per_token)
151
161
152
162
void dequantize_per_token() {
153
163
const ivec3 pos = ivec3 (gl_GlobalInvocationID);
@@ -189,6 +199,97 @@ void dequantize_per_token() {
189
199
write_texel(t_out, pos, outtex);
190
200
}
191
201
202
+ #else // per_channel
203
+
204
+ void dequantize_per_channel() {
205
+ const ivec3 pos = ivec3 (gl_GlobalInvocationID);
206
+
207
+ if (any (greaterThanEqual (pos, t_in_limits))) {
208
+ return ;
209
+ }
210
+
211
+ IVEC4_T intex = load_texel(t_in, pos);
212
+ FVEC4_T outtex;
213
+
214
+ // Calculate channel index based on the dequantization axis (already converted to WHCN)
215
+ // The axis parameter is now in WHCN coordinate system:
216
+ // axis 0 -> W dimension (pos.x)
217
+ // axis 1 -> H dimension (pos.y)
218
+ // axis 2 -> C dimension (pos.z)
219
+ // axis 3 -> N dimension (batch folding in texture storage)
220
+
221
+ if (axis == 0 ) {
222
+ // Width dimension - each texel component has different channel index
223
+ [[unroll]] for (int i = 0 ; i < 4 ; ++ i) {
224
+ IN_T qvalue = IN_T(intex[i]);
225
+ int channel_idx = pos.x * 4 + i;
226
+ channel_idx = min (channel_idx, num_channels - 1 );
227
+
228
+ float scale_val = t_scale[channel_idx];
229
+ int zero_point_val = t_zero_point[channel_idx];
230
+ OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val);
231
+ $if OUT_DTYPE == "double ":
232
+ outtex[i] = float (value);
233
+ $else :
234
+ outtex[i] = value;
235
+ }
236
+ } else if (axis == 1 ) {
237
+ int channel_idx = pos.y;
238
+ channel_idx = min (channel_idx, num_channels - 1 );
239
+ float scale_val = t_scale[channel_idx];
240
+ int zero_point_val = t_zero_point[channel_idx];
241
+
242
+ [[unroll]] for (int i = 0 ; i < 4 ; ++ i) {
243
+ IN_T qvalue = IN_T(intex[i]);
244
+ OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val);
245
+ $if OUT_DTYPE == "double ":
246
+ outtex[i] = float (value);
247
+ $else :
248
+ outtex[i] = value;
249
+ }
250
+ } else if (axis == 2 ) {
251
+ // Channel dimension - for 4D tensors, need to account for batch-channel folding
252
+ // The Z coordinate contains folded batch*channel information
253
+ // We need to extract the actual channel index from the folded dimension
254
+ int folded_idx = pos.z;
255
+ int channel_idx = folded_idx % num_channels;
256
+
257
+ float scale_val = t_scale[channel_idx];
258
+ int zero_point_val = t_zero_point[channel_idx];
259
+
260
+ [[unroll]] for (int i = 0 ; i < 4 ; ++ i) {
261
+ IN_T qvalue = IN_T(intex[i]);
262
+ OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val);
263
+ $if OUT_DTYPE == "double ":
264
+ outtex[i] = float (value);
265
+ $else :
266
+ outtex[i] = value;
267
+ }
268
+ } else if (axis == 3 ) {
269
+ // Batch dimension - for 4D tensors, need to account for batch-channel folding
270
+ // The Z coordinate contains folded batch*channel information
271
+ // We need to extract the actual channel index from the folded dimension
272
+ int folded_idx = pos.z;
273
+ // In this case num_channels actually corresponds to the number of channels
274
+ // the C dimension N(C)HW
275
+ int channel_idx = folded_idx / num_channels;
276
+
277
+ float scale_val = t_scale[channel_idx];
278
+ int zero_point_val = t_zero_point[channel_idx];
279
+
280
+ [[unroll]] for (int i = 0 ; i < 4 ; ++ i) {
281
+ IN_T qvalue = IN_T(intex[i]);
282
+ OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val);
283
+ $if OUT_DTYPE == "double ":
284
+ outtex[i] = float (value);
285
+ $else :
286
+ outtex[i] = value;
287
+ }
288
+ }
289
+
290
+ write_texel(t_out, pos, outtex);
291
+ }
292
+
192
293
#endif
193
294
194
295
void main() {
0 commit comments