@@ -87,8 +87,6 @@ void add_quantize_per_tensor_node(
87
87
add_dtype_suffix (kernel_name, graph.dtype_of (input));
88
88
add_dtype_suffix (kernel_name, graph.dtype_of (output));
89
89
90
- float scale_val = static_cast <float >(graph.get_double (scale));
91
- int zero_point_val = static_cast <int >(graph.get_int (zero_point));
92
90
int quant_min_val = static_cast <int >(graph.get_int (quant_min));
93
91
int quant_max_val = static_cast <int >(graph.get_int (quant_max));
94
92
@@ -102,23 +100,16 @@ void add_quantize_per_tensor_node(
102
100
graph.strides_ubo (input),
103
101
graph.sizes_ubo (output),
104
102
graph.strides_ubo (output)};
105
- push_constants = {
106
- PushConstantDataInfo (&scale_val, sizeof (float )),
107
- PushConstantDataInfo (&zero_point_val, sizeof (int )),
108
- PushConstantDataInfo (&quant_min_val, sizeof (int )),
109
- PushConstantDataInfo (&quant_max_val, sizeof (int )),
110
- };
111
103
} else {
112
104
param_ubos = {
113
105
graph.logical_limits_ubo (input), graph.logical_limits_ubo (output)};
114
- push_constants = {
115
- PushConstantDataInfo (&scale_val, sizeof (float )),
116
- PushConstantDataInfo (&zero_point_val, sizeof (int )),
117
- PushConstantDataInfo (&quant_min_val, sizeof (int )),
118
- PushConstantDataInfo (&quant_max_val, sizeof (int )),
119
- };
120
106
}
121
107
108
+ push_constants = {
109
+ PushConstantDataInfo (&quant_min_val, sizeof (int )),
110
+ PushConstantDataInfo (&quant_max_val, sizeof (int )),
111
+ };
112
+
122
113
vkapi::SpecVarList spec_vars = {
123
114
graph.hashed_layout_of (output),
124
115
graph.hashed_layout_of (input),
@@ -130,7 +121,9 @@ void add_quantize_per_tensor_node(
130
121
default_pick_global_wg_size,
131
122
default_pick_local_wg_size,
132
123
// Inputs and Outputs
133
- {{output, vkapi::kWrite }, {input, vkapi::kRead }},
124
+ {{output, vkapi::kWrite },
125
+ {input, vkapi::kRead },
126
+ {{scale, zero_point}, vkapi::kRead }},
134
127
// Shader param buffers
135
128
param_ubos,
136
129
// Push Constants
@@ -489,7 +482,7 @@ void quantize_per_channel_impl(
489
482
490
483
REGISTER_OPERATORS {
491
484
VK_REGISTER_OP (
492
- quantized_decomposed.quantize_per_tensor .default ,
485
+ quantized_decomposed.quantize_per_tensor .tensor ,
493
486
quantize_per_tensor_impl);
494
487
VK_REGISTER_OP (
495
488
quantized_decomposed.quantize_per_token .default , quantize_per_token_impl);
0 commit comments