Skip to content

[ET-VK][Ops] dequantize_per_tensor.tensor variant #12209

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
10 commits merged on Jul 14, 2025
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,10 @@ ${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "buffer")}
${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "buffer")}

$if MODE == "per_tensor":
${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")}
${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")}

layout(push_constant) uniform restrict Block {
float scale;
int zero_point;
int quant_min;
int quant_max;
};
Expand Down Expand Up @@ -146,7 +147,7 @@ void dequantize_per_tensor() {
const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides);

IN_T qvalue = t_in[in_bufi];
- OUT_T value = dequantize_val(qvalue, scale, zero_point);
+ OUT_T value = dequantize_val(qvalue, t_scale[0], t_zero_point[0]);

t_out[out_bufi] = value;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,10 @@ ${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "texture3d")}
${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "texture3d")}

$if MODE == "per_tensor":
${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")}
${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")}

layout(push_constant) uniform restrict Block {
float scale;
int zero_point;
int quant_min;
int quant_max;
};
Expand Down Expand Up @@ -148,7 +149,8 @@ void dequantize_per_tensor() {

[[unroll]] for (int i = 0; i < 4; ++i) {
IN_T qvalue = IN_T(intex[i]);
- OUT_T value = dequantize_val(qvalue, scale, zero_point);
+ OUT_T value = dequantize_val(qvalue, t_scale[0], t_zero_point[0]);

$if OUT_DTYPE == "double":
outtex[i] = float(value);
$else:
Expand Down
25 changes: 9 additions & 16 deletions backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,6 @@ void add_dequantize_per_tensor_node(
add_dtype_suffix(kernel_name, graph.dtype_of(input));
add_dtype_suffix(kernel_name, graph.dtype_of(output));

float scale_val = static_cast<float>(graph.get_double(scale));
int zero_point_val = static_cast<int>(graph.get_int(zero_point));
int quant_min_val = static_cast<int>(graph.get_int(quant_min));
int quant_max_val = static_cast<int>(graph.get_int(quant_max));

Expand All @@ -100,23 +98,16 @@ void add_dequantize_per_tensor_node(
graph.strides_ubo(input),
graph.sizes_ubo(output),
graph.strides_ubo(output)};
push_constants = {
PushConstantDataInfo(&scale_val, sizeof(float)),
PushConstantDataInfo(&zero_point_val, sizeof(int)),
PushConstantDataInfo(&quant_min_val, sizeof(int)),
PushConstantDataInfo(&quant_max_val, sizeof(int)),
};
} else {
param_ubos = {
graph.logical_limits_ubo(input), graph.logical_limits_ubo(output)};
push_constants = {
PushConstantDataInfo(&scale_val, sizeof(float)),
PushConstantDataInfo(&zero_point_val, sizeof(int)),
PushConstantDataInfo(&quant_min_val, sizeof(int)),
PushConstantDataInfo(&quant_max_val, sizeof(int)),
};
}

push_constants = {
PushConstantDataInfo(&quant_min_val, sizeof(int)),
PushConstantDataInfo(&quant_max_val, sizeof(int)),
};

vkapi::SpecVarList spec_vars = {
graph.hashed_layout_of(output),
graph.hashed_layout_of(input),
Expand All @@ -128,7 +119,9 @@ void add_dequantize_per_tensor_node(
default_pick_global_wg_size,
default_pick_local_wg_size,
// Inputs and Outputs
- {{output, vkapi::kWrite}, {input, vkapi::kRead}},
+ {{output, vkapi::kWrite},
+ {input, vkapi::kRead},
+ {{scale, zero_point}, vkapi::kRead}},
// Shader param buffers
param_ubos,
// Push Constants
Expand Down Expand Up @@ -517,7 +510,7 @@ void dequantize_per_channel_impl(

REGISTER_OPERATORS {
VK_REGISTER_OP(
- quantized_decomposed.dequantize_per_tensor.default,
+ quantized_decomposed.dequantize_per_tensor.tensor,
dequantize_per_tensor_impl);
VK_REGISTER_OP(
quantized_decomposed.dequantize_per_token.default,
Expand Down
Loading
Loading