Skip to content

Commit 63431bd

Browse files
authored
[ET-VK][Ops] dequantize_per_channel shaders and impl (#12435)
# Context We need to enable the core logic for dequantize_per_channel in the vulkan shader. This implements the shader itself and its cpp header. TODO: add more of a description regarding the operator # Changes This creates an extension of the existing files for dequantize_per_channel. Differential Revision: [D77746141](https://our.internmc.facebook.com/intern/diff/D77746141/)
1 parent 8f3eb3e commit 63431bd

File tree

6 files changed

+1058
-5
lines changed

6 files changed

+1058
-5
lines changed

backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,16 @@ $if MODE == "per_token":
4242
int quant_min;
4343
int quant_max;
4444
};
45+
$if MODE == "per_channel":
46+
${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")}
47+
${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")}
48+
49+
layout(push_constant) uniform restrict Block {
50+
int axis;
51+
int num_channels;
52+
int quant_min;
53+
int quant_max;
54+
};
4555

4656
${layout_declare_ubo(B, "int", "out_numel")}
4757
${layout_declare_ubo(B, "ivec4", "t_in_sizes")}
@@ -141,7 +151,7 @@ void dequantize_per_tensor() {
141151
t_out[out_bufi] = value;
142152
}
143153

144-
#else
154+
#elif defined(per_token)
145155

146156
void dequantize_per_token() {
147157
const int out_bufi = int(gl_GlobalInvocationID.x);
@@ -176,6 +186,45 @@ void dequantize_per_token() {
176186
t_out[out_bufi] = value;
177187
}
178188

189+
#else // per_channel
190+
191+
void dequantize_per_channel() {
192+
const int out_bufi = int(gl_GlobalInvocationID.x);
193+
194+
if (out_bufi >= out_numel) {
195+
return;
196+
}
197+
198+
const ivec4 out_tidx = bufi_to_tidx(out_bufi, t_out_strides, out_dim_order);
199+
const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides);
200+
201+
IN_T qvalue = t_in[in_bufi];
202+
203+
// Calculate channel index based on the dequantization axis (already converted to WHCN)
204+
// The axis parameter is now in WHCN coordinate system:
205+
// axis 0 -> W dimension (tidx.x)
206+
// axis 1 -> H dimension (tidx.y)
207+
// axis 2 -> C dimension (tidx.z)
208+
// axis 3 -> N dimension (tidx.w)
209+
int channel_idx = 0;
210+
211+
if (axis == 0) {
212+
channel_idx = out_tidx.x;
213+
} else if (axis == 1) {
214+
channel_idx = out_tidx.y;
215+
} else if (axis == 2) {
216+
channel_idx = out_tidx.z;
217+
} else if (axis == 3) {
218+
channel_idx = out_tidx.w;
219+
}
220+
221+
channel_idx = min(channel_idx, num_channels - 1);
222+
223+
OUT_T value = dequantize_val(qvalue, t_scale[channel_idx], t_zero_point[channel_idx]);
224+
225+
t_out[out_bufi] = value;
226+
}
227+
179228
#endif
180229

181230
void main() {

backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,5 @@ dequantize_buffer:
1717
MODE: per_tensor
1818
- NAME: dequantize_per_token_buffer
1919
MODE: per_token
20+
- NAME: dequantize_per_channel_buffer
21+
MODE: per_channel

backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl

Lines changed: 102 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,16 @@ $if MODE == "per_token":
4545
int quant_min;
4646
int quant_max;
4747
};
48+
$if MODE == "per_channel":
49+
${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")}
50+
${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")}
51+
52+
layout(push_constant) uniform restrict Block {
53+
int axis;
54+
int num_channels;
55+
int quant_min;
56+
int quant_max;
57+
};
4858

4959
${layout_declare_ubo(B, "ivec3", "t_in_limits")}
5060
${layout_declare_ubo(B, "ivec3", "t_out_limits")}
@@ -147,7 +157,7 @@ void dequantize_per_tensor() {
147157
write_texel(t_out, pos, outtex);
148158
}
149159

150-
#else
160+
#elif defined(per_token)
151161

152162
void dequantize_per_token() {
153163
const ivec3 pos = ivec3(gl_GlobalInvocationID);
@@ -189,6 +199,97 @@ void dequantize_per_token() {
189199
write_texel(t_out, pos, outtex);
190200
}
191201

202+
#else // per_channel
203+
204+
void dequantize_per_channel() {
205+
const ivec3 pos = ivec3(gl_GlobalInvocationID);
206+
207+
if (any(greaterThanEqual(pos, t_in_limits))) {
208+
return;
209+
}
210+
211+
IVEC4_T intex = load_texel(t_in, pos);
212+
FVEC4_T outtex;
213+
214+
// Calculate channel index based on the dequantization axis (already converted to WHCN)
215+
// The axis parameter is now in WHCN coordinate system:
216+
// axis 0 -> W dimension (pos.x)
217+
// axis 1 -> H dimension (pos.y)
218+
// axis 2 -> C dimension (pos.z)
219+
// axis 3 -> N dimension (batch folding in texture storage)
220+
221+
if (axis == 0) {
222+
// Width dimension - each texel component has different channel index
223+
[[unroll]] for (int i = 0; i < 4; ++i) {
224+
IN_T qvalue = IN_T(intex[i]);
225+
int channel_idx = pos.x * 4 + i;
226+
channel_idx = min(channel_idx, num_channels - 1);
227+
228+
float scale_val = t_scale[channel_idx];
229+
int zero_point_val = t_zero_point[channel_idx];
230+
OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val);
231+
$if OUT_DTYPE == "double":
232+
outtex[i] = float(value);
233+
$else:
234+
outtex[i] = value;
235+
}
236+
} else if (axis == 1) {
237+
int channel_idx = pos.y;
238+
channel_idx = min(channel_idx, num_channels - 1);
239+
float scale_val = t_scale[channel_idx];
240+
int zero_point_val = t_zero_point[channel_idx];
241+
242+
[[unroll]] for (int i = 0; i < 4; ++i) {
243+
IN_T qvalue = IN_T(intex[i]);
244+
OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val);
245+
$if OUT_DTYPE == "double":
246+
outtex[i] = float(value);
247+
$else:
248+
outtex[i] = value;
249+
}
250+
} else if (axis == 2) {
251+
// Channel dimension - for 4D tensors, need to account for batch-channel folding
252+
// The Z coordinate contains folded batch*channel information
253+
// We need to extract the actual channel index from the folded dimension
254+
int folded_idx = pos.z;
255+
int channel_idx = folded_idx % num_channels;
256+
257+
float scale_val = t_scale[channel_idx];
258+
int zero_point_val = t_zero_point[channel_idx];
259+
260+
[[unroll]] for (int i = 0; i < 4; ++i) {
261+
IN_T qvalue = IN_T(intex[i]);
262+
OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val);
263+
$if OUT_DTYPE == "double":
264+
outtex[i] = float(value);
265+
$else:
266+
outtex[i] = value;
267+
}
268+
} else if (axis == 3) {
269+
// Batch dimension - for 4D tensors, need to account for batch-channel folding
270+
// The Z coordinate contains folded batch*channel information
271+
// We need to extract the actual channel index from the folded dimension
272+
int folded_idx = pos.z;
273+
// In this case num_channels actually corresponds to the number of channels
274+
// the C dimension N(C)HW
275+
int channel_idx = folded_idx / num_channels;
276+
277+
float scale_val = t_scale[channel_idx];
278+
int zero_point_val = t_zero_point[channel_idx];
279+
280+
[[unroll]] for (int i = 0; i < 4; ++i) {
281+
IN_T qvalue = IN_T(intex[i]);
282+
OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val);
283+
$if OUT_DTYPE == "double":
284+
outtex[i] = float(value);
285+
$else:
286+
outtex[i] = value;
287+
}
288+
}
289+
290+
write_texel(t_out, pos, outtex);
291+
}
292+
192293
#endif
193294

194295
void main() {

backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,5 @@ dequantize_texture:
1717
MODE: per_tensor
1818
- NAME: dequantize_per_token_texture3d
1919
MODE: per_token
20+
- NAME: dequantize_per_channel_texture3d
21+
MODE: per_channel

0 commit comments

Comments
 (0)