@@ -85,8 +85,6 @@ void add_dequantize_per_tensor_node(
85
85
add_dtype_suffix (kernel_name, graph.dtype_of (input));
86
86
add_dtype_suffix (kernel_name, graph.dtype_of (output));
87
87
88
- float scale_val = static_cast <float >(graph.get_double (scale));
89
- int zero_point_val = static_cast <int >(graph.get_int (zero_point));
90
88
int quant_min_val = static_cast <int >(graph.get_int (quant_min));
91
89
int quant_max_val = static_cast <int >(graph.get_int (quant_max));
92
90
@@ -100,23 +98,16 @@ void add_dequantize_per_tensor_node(
100
98
graph.strides_ubo (input),
101
99
graph.sizes_ubo (output),
102
100
graph.strides_ubo (output)};
103
- push_constants = {
104
- PushConstantDataInfo (&scale_val, sizeof (float )),
105
- PushConstantDataInfo (&zero_point_val, sizeof (int )),
106
- PushConstantDataInfo (&quant_min_val, sizeof (int )),
107
- PushConstantDataInfo (&quant_max_val, sizeof (int )),
108
- };
109
101
} else {
110
102
param_ubos = {
111
103
graph.logical_limits_ubo (input), graph.logical_limits_ubo (output)};
112
- push_constants = {
113
- PushConstantDataInfo (&scale_val, sizeof (float )),
114
- PushConstantDataInfo (&zero_point_val, sizeof (int )),
115
- PushConstantDataInfo (&quant_min_val, sizeof (int )),
116
- PushConstantDataInfo (&quant_max_val, sizeof (int )),
117
- };
118
104
}
119
105
106
+ push_constants = {
107
+ PushConstantDataInfo (&quant_min_val, sizeof (int )),
108
+ PushConstantDataInfo (&quant_max_val, sizeof (int )),
109
+ };
110
+
120
111
vkapi::SpecVarList spec_vars = {
121
112
graph.hashed_layout_of (output),
122
113
graph.hashed_layout_of (input),
@@ -128,7 +119,9 @@ void add_dequantize_per_tensor_node(
128
119
default_pick_global_wg_size,
129
120
default_pick_local_wg_size,
130
121
// Inputs and Outputs
131
- {{output, vkapi::kWrite }, {input, vkapi::kRead }},
122
+ {{output, vkapi::kWrite },
123
+ {input, vkapi::kRead },
124
+ {{scale, zero_point}, vkapi::kRead }},
132
125
// Shader param buffers
133
126
param_ubos,
134
127
// Push Constants
@@ -517,7 +510,7 @@ void dequantize_per_channel_impl(
517
510
518
511
REGISTER_OPERATORS {
519
512
VK_REGISTER_OP (
520
- quantized_decomposed.dequantize_per_tensor .default ,
513
+ quantized_decomposed.dequantize_per_tensor .tensor ,
521
514
dequantize_per_tensor_impl);
522
515
VK_REGISTER_OP (
523
516
quantized_decomposed.dequantize_per_token .default ,
0 commit comments