Skip to content

[ET-VK][Ops] dequantize_per_channel shaders and impl #12435

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jul 14, 2025
51 changes: 50 additions & 1 deletion backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,16 @@ $if MODE == "per_token":
int quant_min;
int quant_max;
};
$if MODE == "per_channel":
${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")}
${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")}

layout(push_constant) uniform restrict Block {
int axis;
int num_channels;
int quant_min;
int quant_max;
};

${layout_declare_ubo(B, "int", "out_numel")}
${layout_declare_ubo(B, "ivec4", "t_in_sizes")}
Expand Down Expand Up @@ -141,7 +151,7 @@ void dequantize_per_tensor() {
t_out[out_bufi] = value;
}

#else
#elif defined(per_token)

void dequantize_per_token() {
const int out_bufi = int(gl_GlobalInvocationID.x);
Expand Down Expand Up @@ -176,6 +186,45 @@ void dequantize_per_token() {
t_out[out_bufi] = value;
}

#else // per_channel

void dequantize_per_channel() {
const int out_bufi = int(gl_GlobalInvocationID.x);

if (out_bufi >= out_numel) {
return;
}

const ivec4 out_tidx = bufi_to_tidx(out_bufi, t_out_strides, out_dim_order);
const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides);

IN_T qvalue = t_in[in_bufi];

// Calculate channel index based on the dequantization axis (already converted to WHCN)
// The axis parameter is now in WHCN coordinate system:
// axis 0 -> W dimension (tidx.x)
// axis 1 -> H dimension (tidx.y)
// axis 2 -> C dimension (tidx.z)
// axis 3 -> N dimension (tidx.w)
int channel_idx = 0;

if (axis == 0) {
channel_idx = out_tidx.x;
} else if (axis == 1) {
channel_idx = out_tidx.y;
} else if (axis == 2) {
channel_idx = out_tidx.z;
} else if (axis == 3) {
channel_idx = out_tidx.w;
}

channel_idx = min(channel_idx, num_channels - 1);

OUT_T value = dequantize_val(qvalue, t_scale[channel_idx], t_zero_point[channel_idx]);

t_out[out_bufi] = value;
}

#endif

void main() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,5 @@ dequantize_buffer:
MODE: per_tensor
- NAME: dequantize_per_token_buffer
MODE: per_token
- NAME: dequantize_per_channel_buffer
MODE: per_channel
103 changes: 102 additions & 1 deletion backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,16 @@ $if MODE == "per_token":
int quant_min;
int quant_max;
};
$if MODE == "per_channel":
${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")}
${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")}

layout(push_constant) uniform restrict Block {
int axis;
int num_channels;
int quant_min;
int quant_max;
};

${layout_declare_ubo(B, "ivec3", "t_in_limits")}
${layout_declare_ubo(B, "ivec3", "t_out_limits")}
Expand Down Expand Up @@ -147,7 +157,7 @@ void dequantize_per_tensor() {
write_texel(t_out, pos, outtex);
}

#else
#elif defined(per_token)

void dequantize_per_token() {
const ivec3 pos = ivec3(gl_GlobalInvocationID);
Expand Down Expand Up @@ -189,6 +199,97 @@ void dequantize_per_token() {
write_texel(t_out, pos, outtex);
}

#else // per_channel

void dequantize_per_channel() {
const ivec3 pos = ivec3(gl_GlobalInvocationID);

if (any(greaterThanEqual(pos, t_in_limits))) {
return;
}

IVEC4_T intex = load_texel(t_in, pos);
FVEC4_T outtex;

// Calculate channel index based on the dequantization axis (already converted to WHCN)
// The axis parameter is now in WHCN coordinate system:
// axis 0 -> W dimension (pos.x)
// axis 1 -> H dimension (pos.y)
// axis 2 -> C dimension (pos.z)
// axis 3 -> N dimension (batch folding in texture storage)

if (axis == 0) {
// Width dimension - each texel component has different channel index
[[unroll]] for (int i = 0; i < 4; ++i) {
IN_T qvalue = IN_T(intex[i]);
int channel_idx = pos.x * 4 + i;
channel_idx = min(channel_idx, num_channels - 1);

float scale_val = t_scale[channel_idx];
int zero_point_val = t_zero_point[channel_idx];
OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val);
$if OUT_DTYPE == "double":
outtex[i] = float(value);
$else:
outtex[i] = value;
}
} else if (axis == 1) {
int channel_idx = pos.y;
channel_idx = min(channel_idx, num_channels - 1);
float scale_val = t_scale[channel_idx];
int zero_point_val = t_zero_point[channel_idx];

[[unroll]] for (int i = 0; i < 4; ++i) {
IN_T qvalue = IN_T(intex[i]);
OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val);
$if OUT_DTYPE == "double":
outtex[i] = float(value);
$else:
outtex[i] = value;
}
} else if (axis == 2) {
// Channel dimension - for 4D tensors, need to account for batch-channel folding
// The Z coordinate contains folded batch*channel information
// We need to extract the actual channel index from the folded dimension
int folded_idx = pos.z;
int channel_idx = folded_idx % num_channels;

float scale_val = t_scale[channel_idx];
int zero_point_val = t_zero_point[channel_idx];

[[unroll]] for (int i = 0; i < 4; ++i) {
IN_T qvalue = IN_T(intex[i]);
OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val);
$if OUT_DTYPE == "double":
outtex[i] = float(value);
$else:
outtex[i] = value;
}
} else if (axis == 3) {
// Batch dimension - for 4D tensors, need to account for batch-channel folding
// The Z coordinate contains folded batch*channel information
// We need to extract the actual channel index from the folded dimension
int folded_idx = pos.z;
// In this case num_channels actually corresponds to the number of channels
// the C dimension N(C)HW
int channel_idx = folded_idx / num_channels;

float scale_val = t_scale[channel_idx];
int zero_point_val = t_zero_point[channel_idx];

[[unroll]] for (int i = 0; i < 4; ++i) {
IN_T qvalue = IN_T(intex[i]);
OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val);
$if OUT_DTYPE == "double":
outtex[i] = float(value);
$else:
outtex[i] = value;
}
}

write_texel(t_out, pos, outtex);
}

#endif

void main() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,5 @@ dequantize_texture:
MODE: per_tensor
- NAME: dequantize_per_token_texture3d
MODE: per_token
- NAME: dequantize_per_channel_texture3d
MODE: per_channel
Loading
Loading