Skip to content

Commit 87b2d64

Browse files
authored
[ET-VK][Ops] quantize_per_tensor.tensor variant
Differential Revision: D77746136
Pull Request resolved: #12208
1 parent 231d5ba commit 87b2d64

File tree

4 files changed

+318
-260
lines changed

4 files changed

+318
-260
lines changed

backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.glsl

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -27,9 +27,10 @@ ${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "buffer")}
2727
${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "buffer")}
2828

2929
$if MODE == "per_tensor":
30+
${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")}
31+
${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")}
32+
3033
layout(push_constant) uniform restrict Block {
31-
float scale;
32-
int zero_point;
3334
int quant_min;
3435
int quant_max;
3536
};
@@ -142,7 +143,7 @@ void quantize_per_tensor() {
142143
const int in_bufi = tidx_to_bufi(out_tidx, t_in_strides);
143144

144145
IN_T value = t_in[in_bufi];
145-
OUT_T qvalue = quantize_val(value, scale, zero_point);
146+
OUT_T qvalue = quantize_val(value, t_scale[0], t_zero_point[0]);
146147

147148
t_out[out_bufi] = qvalue;
148149
}

backends/vulkan/runtime/graph/ops/glsl/quantize_texture.glsl

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -32,9 +32,10 @@ ${layout_declare_tensor(B, "w", "t_out", OUT_DTYPE, "texture3d")}
3232
${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "texture3d")}
3333

3434
$if MODE == "per_tensor":
35+
${layout_declare_tensor(B, "r", "t_scale", "float", "buffer")}
36+
${layout_declare_tensor(B, "r", "t_zero_point", "int", "buffer")}
37+
3538
layout(push_constant) uniform restrict Block {
36-
float scale;
37-
int zero_point;
3839
int quant_min;
3940
int quant_max;
4041
};
@@ -146,7 +147,7 @@ void quantize_per_tensor() {
146147

147148
[[unroll]] for (int i = 0; i < 4; ++i) {
148149
IN_T value = IN_T(intex[i]);
149-
OUT_T qvalue = quantize_val(value, scale, zero_point);
150+
OUT_T qvalue = quantize_val(value, t_scale[0], t_zero_point[0]);
150151
outtex[i] = qvalue;
151152
}
152153
write_texel(t_out, pos, outtex);

backends/vulkan/runtime/graph/ops/impl/Quantize.cpp

Lines changed: 9 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -87,8 +87,6 @@ void add_quantize_per_tensor_node(
8787
add_dtype_suffix(kernel_name, graph.dtype_of(input));
8888
add_dtype_suffix(kernel_name, graph.dtype_of(output));
8989

90-
float scale_val = static_cast<float>(graph.get_double(scale));
91-
int zero_point_val = static_cast<int>(graph.get_int(zero_point));
9290
int quant_min_val = static_cast<int>(graph.get_int(quant_min));
9391
int quant_max_val = static_cast<int>(graph.get_int(quant_max));
9492

@@ -102,23 +100,16 @@ void add_quantize_per_tensor_node(
102100
graph.strides_ubo(input),
103101
graph.sizes_ubo(output),
104102
graph.strides_ubo(output)};
105-
push_constants = {
106-
PushConstantDataInfo(&scale_val, sizeof(float)),
107-
PushConstantDataInfo(&zero_point_val, sizeof(int)),
108-
PushConstantDataInfo(&quant_min_val, sizeof(int)),
109-
PushConstantDataInfo(&quant_max_val, sizeof(int)),
110-
};
111103
} else {
112104
param_ubos = {
113105
graph.logical_limits_ubo(input), graph.logical_limits_ubo(output)};
114-
push_constants = {
115-
PushConstantDataInfo(&scale_val, sizeof(float)),
116-
PushConstantDataInfo(&zero_point_val, sizeof(int)),
117-
PushConstantDataInfo(&quant_min_val, sizeof(int)),
118-
PushConstantDataInfo(&quant_max_val, sizeof(int)),
119-
};
120106
}
121107

108+
push_constants = {
109+
PushConstantDataInfo(&quant_min_val, sizeof(int)),
110+
PushConstantDataInfo(&quant_max_val, sizeof(int)),
111+
};
112+
122113
vkapi::SpecVarList spec_vars = {
123114
graph.hashed_layout_of(output),
124115
graph.hashed_layout_of(input),
@@ -130,7 +121,9 @@ void add_quantize_per_tensor_node(
130121
default_pick_global_wg_size,
131122
default_pick_local_wg_size,
132123
// Inputs and Outputs
133-
{{output, vkapi::kWrite}, {input, vkapi::kRead}},
124+
{{output, vkapi::kWrite},
125+
{input, vkapi::kRead},
126+
{{scale, zero_point}, vkapi::kRead}},
134127
// Shader param buffers
135128
param_ubos,
136129
// Push Constants
@@ -489,7 +482,7 @@ void quantize_per_channel_impl(
489482

490483
REGISTER_OPERATORS {
491484
VK_REGISTER_OP(
492-
quantized_decomposed.quantize_per_tensor.default,
485+
quantized_decomposed.quantize_per_tensor.tensor,
493486
quantize_per_tensor_impl);
494487
VK_REGISTER_OP(
495488
quantized_decomposed.quantize_per_token.default, quantize_per_token_impl);

0 commit comments

Comments
 (0)