Skip to content

[ET-VK][Ops] quantize_per_tensor.tensor variant #12208

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jul 14, 2025
7 changes: 4 additions & 3 deletions backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,10 @@ ${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "buffer")}
${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "buffer")}

$if MODE == "per_tensor":
${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")}
${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")}

layout(push_constant) uniform restrict Block {
float scale;
int zero_point;
int quant_min;
int quant_max;
};
Expand Down Expand Up @@ -142,7 +143,7 @@ void quantize_per_tensor() {
const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides);

IN_T value = t_in[in_bufi];
OUT_T qvalue = quantize_val(value, scale, zero_point);
OUT_T qvalue = quantize_val(value, t_scale[0], t_zero_point[0]);

t_out[out_bufi] = qvalue;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,10 @@ ${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "texture3d")}
${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "texture3d")}

$if MODE == "per_tensor":
${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")}
${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")}

layout(push_constant) uniform restrict Block {
float scale;
int zero_point;
int quant_min;
int quant_max;
};
Expand Down Expand Up @@ -146,7 +147,7 @@ void quantize_per_tensor() {

[[unroll]] for (int i = 0; i < 4; ++i) {
IN_T value = IN_T(intex[i]);
OUT_T qvalue = quantize_val(value, scale, zero_point);
OUT_T qvalue = quantize_val(value, t_scale[0], t_zero_point[0]);
outtex[i] = qvalue;
}
write_texel(t_out, pos, outtex);
Expand Down
25 changes: 9 additions & 16 deletions backends/vulkan/runtime/graph/ops/impl/Quantize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,6 @@ void add_quantize_per_tensor_node(
add_dtype_suffix(kernel_name, graph.dtype_of(input));
add_dtype_suffix(kernel_name, graph.dtype_of(output));

float scale_val = static_cast<float>(graph.get_double(scale));
int zero_point_val = static_cast<int>(graph.get_int(zero_point));
int quant_min_val = static_cast<int>(graph.get_int(quant_min));
int quant_max_val = static_cast<int>(graph.get_int(quant_max));

Expand All @@ -102,23 +100,16 @@ void add_quantize_per_tensor_node(
graph.strides_ubo(input),
graph.sizes_ubo(output),
graph.strides_ubo(output)};
push_constants = {
PushConstantDataInfo(&scale_val, sizeof(float)),
PushConstantDataInfo(&zero_point_val, sizeof(int)),
PushConstantDataInfo(&quant_min_val, sizeof(int)),
PushConstantDataInfo(&quant_max_val, sizeof(int)),
};
} else {
param_ubos = {
graph.logical_limits_ubo(input), graph.logical_limits_ubo(output)};
push_constants = {
PushConstantDataInfo(&scale_val, sizeof(float)),
PushConstantDataInfo(&zero_point_val, sizeof(int)),
PushConstantDataInfo(&quant_min_val, sizeof(int)),
PushConstantDataInfo(&quant_max_val, sizeof(int)),
};
}

push_constants = {
PushConstantDataInfo(&quant_min_val, sizeof(int)),
PushConstantDataInfo(&quant_max_val, sizeof(int)),
};

vkapi::SpecVarList spec_vars = {
graph.hashed_layout_of(output),
graph.hashed_layout_of(input),
Expand All @@ -130,7 +121,9 @@ void add_quantize_per_tensor_node(
default_pick_global_wg_size,
default_pick_local_wg_size,
// Inputs and Outputs
{{output, vkapi::kWrite}, {input, vkapi::kRead}},
{{output, vkapi::kWrite},
{input, vkapi::kRead},
{{scale, zero_point}, vkapi::kRead}},
// Shader param buffers
param_ubos,
// Push Constants
Expand Down Expand Up @@ -489,7 +482,7 @@ void quantize_per_channel_impl(

REGISTER_OPERATORS {
VK_REGISTER_OP(
quantized_decomposed.quantize_per_tensor.default,
quantized_decomposed.quantize_per_tensor.tensor,
quantize_per_tensor_impl);
VK_REGISTER_OP(
quantized_decomposed.quantize_per_token.default, quantize_per_token_impl);
Expand Down
Loading
Loading