Skip to content

Commit b2375ce

Browse files
authored
[ET-VK][Ops] aligning Q/DQ/CQP op inputs with ATen impl
Differential Revision: D77746144
Pull Request resolved: #12199
1 parent 9a362f7 commit b2375ce

File tree

9 files changed

+169
-56
lines changed

backends/vulkan/runtime/graph/ops/glsl/choose_qparams.glslh

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,13 @@
99
#ifndef CHOOSE_QPARAMS_GLSLH
1010
#define CHOOSE_QPARAMS_GLSLH
1111

12-
// equivalent of the eps defined in the cpu implementation
13-
#define SMALL_SCALE_THRESHOLD 6.1e-5
14-
1512
// Calculate scale and zero point from min and max values
1613
void calculate_scale_and_zero_point(
1714
float min_val,
1815
float max_val,
1916
int qmin,
2017
int qmax,
18+
float eps_threshold,
2119
out float scale_val,
2220
out int zero_point_val) {
2321
// ensure we have zero included in our range
@@ -31,18 +29,18 @@ void calculate_scale_and_zero_point(
3129
scale_val = 0.1;
3230
}
3331

34-
// Cut off small scale
35-
if (scale_val < SMALL_SCALE_THRESHOLD) {
32+
// Cut off small scale using the provided eps threshold
33+
if (scale_val < eps_threshold) {
3634
float org_scale = scale_val;
37-
scale_val = SMALL_SCALE_THRESHOLD;
35+
scale_val = eps_threshold;
3836

3937
// Adjust min and max based on new scale
4038
if (min_val == 0.0) {
41-
max_val = SMALL_SCALE_THRESHOLD * float(qmax - qmin);
39+
max_val = eps_threshold * float(qmax - qmin);
4240
} else if (max_val == 0.0) {
43-
min_val = -SMALL_SCALE_THRESHOLD * float(qmax - qmin);
41+
min_val = -eps_threshold * float(qmax - qmin);
4442
} else {
45-
float amplifier = SMALL_SCALE_THRESHOLD / org_scale;
43+
float amplifier = eps_threshold / org_scale;
4644
min_val *= amplifier;
4745
max_val *= amplifier;
4846
}

backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.glsl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ $if MODE == "per_tensor":
2929
layout(push_constant) uniform restrict Block {
3030
int quant_min;
3131
int quant_max;
32+
float eps;
3233
};
3334
$else:
3435
layout(push_constant) uniform restrict Block {
@@ -175,7 +176,7 @@ void choose_qparams_per_tensor() {
175176

176177
float scale_val;
177178
int zero_point_val;
178-
calculate_scale_and_zero_point(global_min, global_max, quant_min, quant_max, scale_val, zero_point_val);
179+
calculate_scale_and_zero_point(global_min, global_max, quant_min, quant_max, eps, scale_val, zero_point_val);
179180

180181
t_scale[0] = scale_val;
181182
t_zero_point[0] = zero_point_val;
@@ -260,7 +261,7 @@ void choose_qparams_per_token() {
260261

261262
float scale_val;
262263
int zero_point_val;
263-
calculate_scale_and_zero_point(token_min, token_max, quant_min, quant_max, scale_val, zero_point_val);
264+
calculate_scale_and_zero_point(token_min, token_max, quant_min, quant_max, 1e-5, scale_val, zero_point_val);
264265

265266
t_scale[token_id] = scale_val;
266267
t_zero_point[token_id] = zero_point_val;

backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.glsl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ $if MODE == "per_tensor":
3030
layout(push_constant) uniform restrict Block {
3131
int quant_min;
3232
int quant_max;
33+
float eps;
3334
};
3435
$else:
3536
layout(push_constant) uniform restrict Block {
@@ -234,7 +235,7 @@ void choose_qparams_per_tensor() {
234235

235236
float scale_val;
236237
int zero_point_val;
237-
calculate_scale_and_zero_point(global_min, global_max, quant_min, quant_max, scale_val, zero_point_val);
238+
calculate_scale_and_zero_point(global_min, global_max, quant_min, quant_max, eps, scale_val, zero_point_val);
238239

239240
write_texel(t_scale, ivec3(0, 0, 0), vec4(scale_val, 0.0, 0.0, 0.0));
240241
write_texel(t_zero_point, ivec3(0, 0, 0), ivec4(zero_point_val, 0, 0, 0));
@@ -372,7 +373,7 @@ void choose_qparams_per_token() {
372373

373374
float scale_val;
374375
int zero_point_val;
375-
calculate_scale_and_zero_point(token_min, token_max, quant_min, quant_max, scale_val, zero_point_val);
376+
calculate_scale_and_zero_point(token_min, token_max, quant_min, quant_max, 1e-5, scale_val, zero_point_val);
376377

377378
// Convert token_id to 3D coordinates for output texture
378379
// Assuming output tensors have the same layout as input but with different dimensions

backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp

Lines changed: 46 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ void add_choose_qparams_tensor_node(
150150
const ValueRef& input,
151151
const ValueRef& quant_min,
152152
const ValueRef& quant_max,
153+
const ValueRef& eps,
153154
const ValueRef& scale_out,
154155
const ValueRef& zero_point_out) {
155156
std::string kernel_name("choose_qparams_tensor");
@@ -158,6 +159,7 @@ void add_choose_qparams_tensor_node(
158159

159160
int quant_min_val = static_cast<int>(graph.get_int(quant_min));
160161
int quant_max_val = static_cast<int>(graph.get_int(quant_max));
162+
float eps_val = static_cast<float>(graph.get_double(eps));
161163

162164
vkapi::ParamsBindList param_ubos;
163165

@@ -180,6 +182,7 @@ void add_choose_qparams_tensor_node(
180182
push_constants = {
181183
PushConstantDataInfo(&quant_min_val, sizeof(int)),
182184
PushConstantDataInfo(&quant_max_val, sizeof(int)),
185+
PushConstantDataInfo(&eps_val, sizeof(float)),
183186
};
184187

185188
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
@@ -275,8 +278,22 @@ void choose_qparams_tensor_impl(
275278
const ValueRef input = args[arg_idx++];
276279
const ValueRef quant_min = args[arg_idx++];
277280
const ValueRef quant_max = args[arg_idx++];
278-
const ValueRef scale_out = args[arg_idx++];
279-
const ValueRef zero_point_out = args[arg_idx++];
281+
const ValueRef eps = args[arg_idx++]; // eps threshold; forwarded to the shader as a push constant
282+
const ValueRef dtype =
283+
args[arg_idx++]; // Added dtype parameter (will be voided)
284+
const ValueRef out_tuple_ref = args[arg_idx++];
285+
286+
ValueRef scale_out = kDummyValueRef;
287+
ValueRef zero_point_out = kDummyValueRef;
288+
289+
{
290+
const ValueListPtr out_tuple = graph.get_value_list(out_tuple_ref);
291+
scale_out = out_tuple->at(0);
292+
zero_point_out = out_tuple->at(1);
293+
}
294+
295+
// Void the unused dtype parameter to match ATen signature
296+
(void)dtype;
280297

281298
// Check tensor types
282299
VK_CHECK_COND(graph.val_is_tensor(input));
@@ -289,30 +306,40 @@ void choose_qparams_tensor_impl(
289306
graph.dtype_of(input) == vkapi::kHalf ||
290307
graph.dtype_of(input) == vkapi::kDouble);
291308

292-
// Verify output types - accept CPU types but convert to GPU types
293-
VK_CHECK_COND(
294-
graph.dtype_of(scale_out) == vkapi::kFloat ||
295-
graph.dtype_of(scale_out) == vkapi::kDouble);
296-
VK_CHECK_COND(
297-
graph.dtype_of(zero_point_out) == vkapi::kInt ||
298-
graph.dtype_of(zero_point_out) == vkapi::kLong);
309+
// Verify output types - only accept Vulkan-supported types
310+
// The Vulkan backend only supports float32 and int32, not float64/int64
311+
VK_CHECK_COND(graph.dtype_of(scale_out) == vkapi::kFloat);
312+
VK_CHECK_COND(graph.dtype_of(zero_point_out) == vkapi::kInt);
299313

300314
// Check that texture storage is width packed
301315
if (!graph.is_buffer_storage(input)) {
302316
VK_CHECK_COND(graph.packed_dim_of(input) == WHCN::kWidthDim);
303317
}
304318

305319
add_choose_qparams_tensor_node(
306-
graph, input, quant_min, quant_max, scale_out, zero_point_out);
320+
graph, input, quant_min, quant_max, eps, scale_out, zero_point_out);
307321
}
308322

309323
void choose_qparams_per_token_asymmetric_impl(
310324
ComputeGraph& graph,
311325
const std::vector<ValueRef>& args) {
312326
int arg_idx = 0;
313327
const ValueRef input = args[arg_idx++];
314-
const ValueRef scale_out = args[arg_idx++];
315-
const ValueRef zero_point_out = args[arg_idx++];
328+
const ValueRef dtype =
329+
args[arg_idx++]; // Added dtype parameter (will be voided)
330+
const ValueRef out_tuple_ref = args[arg_idx++];
331+
332+
ValueRef scale_out = kDummyValueRef;
333+
ValueRef zero_point_out = kDummyValueRef;
334+
335+
{
336+
const ValueListPtr out_tuple = graph.get_value_list(out_tuple_ref);
337+
scale_out = out_tuple->at(0);
338+
zero_point_out = out_tuple->at(1);
339+
}
340+
341+
// Void the unused parameter to match ATen signature
342+
(void)dtype;
316343

317344
// Check tensor types
318345
VK_CHECK_COND(graph.val_is_tensor(input));
@@ -325,22 +352,20 @@ void choose_qparams_per_token_asymmetric_impl(
325352
graph.dtype_of(input) == vkapi::kHalf ||
326353
graph.dtype_of(input) == vkapi::kDouble);
327354

328-
// Verify output types - accept CPU types but convert to GPU types
329-
VK_CHECK_COND(
330-
graph.dtype_of(scale_out) == vkapi::kFloat ||
331-
graph.dtype_of(scale_out) == vkapi::kDouble);
332-
VK_CHECK_COND(
333-
graph.dtype_of(zero_point_out) == vkapi::kInt ||
334-
graph.dtype_of(zero_point_out) == vkapi::kLong);
355+
// Verify output types - only accept Vulkan-supported types
356+
// The Vulkan backend only supports float32 and int32, not float64/int64
357+
VK_CHECK_COND(graph.dtype_of(scale_out) == vkapi::kFloat);
358+
VK_CHECK_COND(graph.dtype_of(zero_point_out) == vkapi::kInt);
335359

336360
add_choose_qparams_per_token_asymmetric_node(
337361
graph, input, scale_out, zero_point_out);
338362
}
339363

340364
REGISTER_OPERATORS {
341-
VK_REGISTER_OP(choose_qparams.tensor, choose_qparams_tensor_impl);
342365
VK_REGISTER_OP(
343-
choose_qparams_per_token_asymmetric.default,
366+
quantized_decomposed.choose_qparams.tensor, choose_qparams_tensor_impl);
367+
VK_REGISTER_OP(
368+
quantized_decomposed.choose_qparams_per_token_asymmetric.default,
344369
choose_qparams_per_token_asymmetric_impl);
345370
}
346371

backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -180,8 +180,15 @@ void dequantize_per_tensor_impl(
180180
const ValueRef zero_point = args[arg_idx++];
181181
const ValueRef quant_min = args[arg_idx++];
182182
const ValueRef quant_max = args[arg_idx++];
183+
const ValueRef dtype = args[arg_idx++]; // Added dtype parameter
184+
const ValueRef output_dtype = args[arg_idx++]; // Added output_dtype parameter
183185
const ValueRef output = args[arg_idx++];
184186

187+
// Suppress unused variable warnings - dtype and output_dtype are inferred
188+
// from output
189+
(void)dtype;
190+
(void)output_dtype;
191+
185192
// Check tensor types
186193
VK_CHECK_COND(graph.val_is_tensor(input));
187194
VK_CHECK_COND(graph.val_is_tensor(output));
@@ -212,8 +219,15 @@ void dequantize_per_token_impl(
212219
const ValueRef zero_point = args[arg_idx++];
213220
const ValueRef quant_min = args[arg_idx++];
214221
const ValueRef quant_max = args[arg_idx++];
222+
const ValueRef dtype = args[arg_idx++]; // Added dtype parameter
223+
const ValueRef output_dtype = args[arg_idx++]; // Added output_dtype parameter
215224
const ValueRef output = args[arg_idx++];
216225

226+
// Suppress unused variable warnings - dtype and output_dtype are inferred
227+
// from output
228+
(void)dtype;
229+
(void)output_dtype;
230+
217231
// Check tensor types
218232
VK_CHECK_COND(graph.val_is_tensor(input));
219233
VK_CHECK_COND(graph.val_is_tensor(scale));
@@ -257,18 +271,34 @@ void dequantize_per_token_impl(
257271
const auto scale_sizes = graph.sizes_of(scale);
258272
const auto zero_point_sizes = graph.sizes_of(zero_point);
259273

260-
VK_CHECK_COND(scale_sizes.size() == 1);
261-
VK_CHECK_COND(zero_point_sizes.size() == 1);
262-
VK_CHECK_COND(scale_sizes[0] == num_tokens);
263-
VK_CHECK_COND(zero_point_sizes[0] == num_tokens);
274+
// Calculate total number of elements in scale and zero_point tensors
275+
int64_t scale_numel = 1;
276+
for (size_t i = 0; i < scale_sizes.size(); i++) {
277+
scale_numel *= scale_sizes[i];
278+
}
279+
280+
int64_t zero_point_numel = 1;
281+
for (size_t i = 0; i < zero_point_sizes.size(); i++) {
282+
zero_point_numel *= zero_point_sizes[i];
283+
}
284+
285+
// Check that the total number of elements matches num_tokens
286+
// This allows for both 1D tensors (size [num_tokens]) and reshaped tensors
287+
// (size [num_tokens, 1])
288+
VK_CHECK_COND(scale_numel == num_tokens);
289+
VK_CHECK_COND(zero_point_numel == num_tokens);
264290

265291
add_dequantize_per_token_node(
266292
graph, input, scale, zero_point, quant_min, quant_max, output);
267293
}
268294

269295
REGISTER_OPERATORS {
270-
VK_REGISTER_OP(dequantize_per_tensor.default, dequantize_per_tensor_impl);
271-
VK_REGISTER_OP(dequantize_per_token.default, dequantize_per_token_impl);
296+
VK_REGISTER_OP(
297+
quantized_decomposed.dequantize_per_tensor.default,
298+
dequantize_per_tensor_impl);
299+
VK_REGISTER_OP(
300+
quantized_decomposed.dequantize_per_token.default,
301+
dequantize_per_token_impl);
272302
}
273303

274304
} // namespace vkcompute

backends/vulkan/runtime/graph/ops/impl/Quantize.cpp

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -180,8 +180,12 @@ void quantize_per_tensor_impl(
180180
const ValueRef zero_point = args[arg_idx++];
181181
const ValueRef quant_min = args[arg_idx++];
182182
const ValueRef quant_max = args[arg_idx++];
183+
const ValueRef dtype = args[arg_idx++]; // Added dtype parameter
183184
const ValueRef output = args[arg_idx++];
184185

186+
// Suppress unused variable warning - dtype is inferred from output
187+
(void)dtype;
188+
185189
// Check tensor types
186190
VK_CHECK_COND(graph.val_is_tensor(input));
187191
VK_CHECK_COND(graph.val_is_tensor(output));
@@ -205,8 +209,12 @@ void quantize_per_token_impl(
205209
const ValueRef zero_point = args[arg_idx++];
206210
const ValueRef quant_min = args[arg_idx++];
207211
const ValueRef quant_max = args[arg_idx++];
212+
const ValueRef dtype = args[arg_idx++]; // Added dtype parameter
208213
const ValueRef output = args[arg_idx++];
209214

215+
// Suppress unused variable warning - dtype is inferred from output
216+
(void)dtype;
217+
210218
// Check tensor types
211219
VK_CHECK_COND(graph.val_is_tensor(input));
212220
VK_CHECK_COND(graph.val_is_tensor(scale));
@@ -243,18 +251,33 @@ void quantize_per_token_impl(
243251
const auto scale_sizes = graph.sizes_of(scale);
244252
const auto zero_point_sizes = graph.sizes_of(zero_point);
245253

246-
VK_CHECK_COND(scale_sizes.size() == 1);
247-
VK_CHECK_COND(zero_point_sizes.size() == 1);
248-
VK_CHECK_COND(scale_sizes[0] == num_tokens);
249-
VK_CHECK_COND(zero_point_sizes[0] == num_tokens);
254+
// Calculate total number of elements in scale and zero_point tensors
255+
int64_t scale_numel = 1;
256+
for (size_t i = 0; i < scale_sizes.size(); i++) {
257+
scale_numel *= scale_sizes[i];
258+
}
259+
260+
int64_t zero_point_numel = 1;
261+
for (size_t i = 0; i < zero_point_sizes.size(); i++) {
262+
zero_point_numel *= zero_point_sizes[i];
263+
}
264+
265+
// Check that the total number of elements matches num_tokens
266+
// This allows for both 1D tensors (size [num_tokens]) and reshaped tensors
267+
// (size [num_tokens, 1])
268+
VK_CHECK_COND(scale_numel == num_tokens);
269+
VK_CHECK_COND(zero_point_numel == num_tokens);
250270

251271
add_quantize_per_token_node(
252272
graph, input, scale, zero_point, quant_min, quant_max, output);
253273
}
254274

255275
REGISTER_OPERATORS {
256-
VK_REGISTER_OP(quantize_per_tensor.default, quantize_per_tensor_impl);
257-
VK_REGISTER_OP(quantize_per_token.default, quantize_per_token_impl);
276+
VK_REGISTER_OP(
277+
quantized_decomposed.quantize_per_tensor.default,
278+
quantize_per_tensor_impl);
279+
VK_REGISTER_OP(
280+
quantized_decomposed.quantize_per_token.default, quantize_per_token_impl);
258281
}
259282

260283
} // namespace vkcompute

0 commit comments

Comments
 (0)