Skip to content

Commit 8f1121b

Browse files
authored
[ET-VK][Ops] dequantize_per_tensor.tensor variant
Differential Revision: D77746135 Pull Request resolved: #12209
1 parent 8d4e471 commit 8f1121b

File tree

4 files changed

+386
-319
lines changed

4 files changed

+386
-319
lines changed

backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.glsl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,10 @@ ${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "buffer")}
2727
${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "buffer")}
2828

2929
$if MODE == "per_tensor":
30+
${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")}
31+
${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")}
32+
3033
layout(push_constant) uniform restrict Block {
31-
float scale;
32-
int zero_point;
3334
int quant_min;
3435
int quant_max;
3536
};
@@ -146,7 +147,7 @@ void dequantize_per_tensor() {
146147
const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides);
147148

148149
IN_T qvalue = t_in[in_bufi];
149-
OUT_T value = dequantize_val(qvalue, scale, zero_point);
150+
OUT_T value = dequantize_val(qvalue, t_scale[0], t_zero_point[0]);
150151

151152
t_out[out_bufi] = value;
152153
}

backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,10 @@ ${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "texture3d")}
3030
${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "texture3d")}
3131

3232
$if MODE == "per_tensor":
33+
${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")}
34+
${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")}
35+
3336
layout(push_constant) uniform restrict Block {
34-
float scale;
35-
int zero_point;
3637
int quant_min;
3738
int quant_max;
3839
};
@@ -148,7 +149,8 @@ void dequantize_per_tensor() {
148149

149150
[[unroll]] for (int i = 0; i < 4; ++i) {
150151
IN_T qvalue = IN_T(intex[i]);
151-
OUT_T value = dequantize_val(qvalue, scale, zero_point);
152+
OUT_T value = dequantize_val(qvalue, t_scale[0], t_zero_point[0]);
153+
152154
$if OUT_DTYPE == "double":
153155
outtex[i] = float(value);
154156
$else:

backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,6 @@ void add_dequantize_per_tensor_node(
8585
add_dtype_suffix(kernel_name, graph.dtype_of(input));
8686
add_dtype_suffix(kernel_name, graph.dtype_of(output));
8787

88-
float scale_val = static_cast<float>(graph.get_double(scale));
89-
int zero_point_val = static_cast<int>(graph.get_int(zero_point));
9088
int quant_min_val = static_cast<int>(graph.get_int(quant_min));
9189
int quant_max_val = static_cast<int>(graph.get_int(quant_max));
9290

@@ -100,23 +98,16 @@ void add_dequantize_per_tensor_node(
10098
graph.strides_ubo(input),
10199
graph.sizes_ubo(output),
102100
graph.strides_ubo(output)};
103-
push_constants = {
104-
PushConstantDataInfo(&scale_val, sizeof(float)),
105-
PushConstantDataInfo(&zero_point_val, sizeof(int)),
106-
PushConstantDataInfo(&quant_min_val, sizeof(int)),
107-
PushConstantDataInfo(&quant_max_val, sizeof(int)),
108-
};
109101
} else {
110102
param_ubos = {
111103
graph.logical_limits_ubo(input), graph.logical_limits_ubo(output)};
112-
push_constants = {
113-
PushConstantDataInfo(&scale_val, sizeof(float)),
114-
PushConstantDataInfo(&zero_point_val, sizeof(int)),
115-
PushConstantDataInfo(&quant_min_val, sizeof(int)),
116-
PushConstantDataInfo(&quant_max_val, sizeof(int)),
117-
};
118104
}
119105

106+
push_constants = {
107+
PushConstantDataInfo(&quant_min_val, sizeof(int)),
108+
PushConstantDataInfo(&quant_max_val, sizeof(int)),
109+
};
110+
120111
vkapi::SpecVarList spec_vars = {
121112
graph.hashed_layout_of(output),
122113
graph.hashed_layout_of(input),
@@ -128,7 +119,9 @@ void add_dequantize_per_tensor_node(
128119
default_pick_global_wg_size,
129120
default_pick_local_wg_size,
130121
// Inputs and Outputs
131-
{{output, vkapi::kWrite}, {input, vkapi::kRead}},
122+
{{output, vkapi::kWrite},
123+
{input, vkapi::kRead},
124+
{{scale, zero_point}, vkapi::kRead}},
132125
// Shader param buffers
133126
param_ubos,
134127
// Push Constants
@@ -517,7 +510,7 @@ void dequantize_per_channel_impl(
517510

518511
REGISTER_OPERATORS {
519512
VK_REGISTER_OP(
520-
quantized_decomposed.dequantize_per_tensor.default,
513+
quantized_decomposed.dequantize_per_tensor.tensor,
521514
dequantize_per_tensor_impl);
522515
VK_REGISTER_OP(
523516
quantized_decomposed.dequantize_per_token.default,

0 commit comments

Comments
 (0)