Skip to content

Commit cc83ebc

Browse files
author
morelos
committed
[ET-VK][Ops] affine quantization operators registration
Pull Request resolved: #12369 # Context In order to enable dynamic quantization, especially for the source transform method using `Int8DynActInt4WeightQuantizer`, we need Vulkan versions of `quantize_affine`, `dequantize_affine`, and `choose_qparams_affine`. Currently we do not have a shader that performs block-based quantization as expected by these operators, so we delegate to the per_tensor variant just to get unblocked. At a later stage this will likely be developed further to ensure we don't incur too much accuracy loss. A full implementation of the affine operators will be done at a later time, since they will be required for broader usage. However, if you want to use the per_tensor default, you must remove the checks made in `op_registry` and in the Vulkan implementation so that the per_tensor version can be used; without that, the operators will not be registered. # Changes This creates a schema reference in the TorchAO library for out variants of these respective operators. Then a VK_REGISTER_OP is applied to them to ensure that we can properly register them when lowering the ET model with Vulkan. We also changed `Linear.cpp`, particularly to allow a passthrough for weight_data, since during dynamic quantization it may be a tensor_data rather than a tensor_ref. ghstack-source-id: 295746674 @exported-using-ghexport Differential Revision: [D78035354](https://our.internmc.facebook.com/intern/diff/D78035354/)
1 parent e189b7a commit cc83ebc

File tree

5 files changed

+239
-7
lines changed

5 files changed

+239
-7
lines changed

backends/vulkan/op_registry.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,60 @@ def register_quantization_op(features: OpFeatures):
272272
return features
273273

274274

275+
@update_features(
    [
        exir_ops.edge.torchao.quantize_affine.default,
        exir_ops.edge.torchao.dequantize_affine.default,
        exir_ops.edge.torchao.choose_qparams_affine.default,
    ]
)
def register_torchao_quantization_op(features: OpFeatures):
    # TorchAO quantization operators - default to per-tensor behavior.
    # Same features as the standard quantization ops.
    features.texture_impl = TextureImplFeatures(
        uses_axis_map=True,
        valid_packed_dims={
            PackedDim.WIDTH,
        },
    )
    features.buffer_impl = True
    features.resize_fn = True
    features.optimal_storage = VkStorageType.BUFFER

    def check_torchao_quantization_node(node: torch.fx.Node) -> bool:
        """Accept the node only when its block_size spans the whole input
        tensor, i.e. only per-tensor quantization is supported by the
        Vulkan backend."""
        # BUGFIX: block_size is NOT at the same position for all three ops.
        # quantize_affine / dequantize_affine:  (input, block_size, ...)
        # choose_qparams_affine:                (input, mapping_type, block_size, ...)
        # The Vulkan C++ implementation unpacks the arguments in exactly this
        # order, so reading args[1] for choose_qparams_affine would pick up the
        # mapping_type string and unconditionally reject the node.
        if "choose_qparams" in str(node.target):
            block_size_idx = 2
        else:
            block_size_idx = 1

        if len(node.args) <= block_size_idx:
            return False

        block_size = node.args[block_size_idx]
        if not isinstance(block_size, (list, tuple)):
            return False

        input_arg = node.args[0]
        if not isinstance(input_arg, torch.fx.Node):
            return False

        input_tensor = input_arg.meta.get("val", None)
        if not isinstance(input_tensor, FakeTensor):
            return False

        input_shape = list(input_tensor.shape)

        if len(block_size) != len(input_shape):
            return False

        # block_size must match input_shape exactly (per-tensor quantization).
        return all(b == s for b, s in zip(block_size, input_shape))

    features.check_node_fn = check_torchao_quantization_node
    return features
327+
328+
275329
@update_features(
276330
[
277331
exir_ops.edge.aten.add.Tensor,

backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp

Lines changed: 83 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -306,10 +306,12 @@ void choose_qparams_tensor_impl(
306306
graph.dtype_of(input) == vkapi::kHalf ||
307307
graph.dtype_of(input) == vkapi::kDouble);
308308

309-
// Verify output types - only accept Vulkan-supported types
310-
// The Vulkan backend only supports float32 and int32, not float64/int64
309+
// Verify output types - accept both int32 and float32 for zero_point
310+
// TorchAO may use float32 for zero_point in some cases
311311
VK_CHECK_COND(graph.dtype_of(scale_out) == vkapi::kFloat);
312-
VK_CHECK_COND(graph.dtype_of(zero_point_out) == vkapi::kInt);
312+
VK_CHECK_COND(
313+
graph.dtype_of(zero_point_out) == vkapi::kInt ||
314+
graph.dtype_of(zero_point_out) == vkapi::kFloat);
313315

314316
// Check that texture storage is width packed
315317
if (!graph.is_buffer_storage(input)) {
@@ -352,21 +354,96 @@ void choose_qparams_per_token_asymmetric_impl(
352354
graph.dtype_of(input) == vkapi::kHalf ||
353355
graph.dtype_of(input) == vkapi::kDouble);
354356

355-
// Verify output types - only accept Vulkan-supported types
356-
// The Vulkan backend only supports float32 and int32, not float64/int64
357+
// Verify output types - accept both int32 and float32 for zero_point
358+
// TorchAO may use float32 for zero_point in some cases
357359
VK_CHECK_COND(graph.dtype_of(scale_out) == vkapi::kFloat);
358-
VK_CHECK_COND(graph.dtype_of(zero_point_out) == vkapi::kInt);
360+
VK_CHECK_COND(
361+
graph.dtype_of(zero_point_out) == vkapi::kInt ||
362+
graph.dtype_of(zero_point_out) == vkapi::kFloat);
359363

360364
add_choose_qparams_per_token_asymmetric_node(
361365
graph, input, scale_out, zero_point_out);
362366
}
363367

368+
void choose_qparams_affine_impl(
    ComputeGraph& graph,
    const std::vector<ValueRef>& args) {
  // Unpack arguments following the torchao.choose_qparams_affine schema.
  int idx = 0;
  const ValueRef input = args[idx++];
  const ValueRef mapping_type = args[idx++]; // str - unused (per-tensor only)
  const ValueRef block_size = args[idx++]; // SymInt[] - validated below
  const ValueRef target_dtype = args[idx++];
  const ValueRef quant_min = args[idx++];
  const ValueRef quant_max = args[idx++];
  const ValueRef eps = args[idx++];
  const ValueRef scale_dtype = args[idx++];
  const ValueRef zero_point_dtype = args[idx++];
  const ValueRef out_tuple_ref = args[idx++];

  // dtype hints are not consulted; actual dtypes are read from the graph.
  (void)mapping_type;
  (void)target_dtype;
  (void)scale_dtype;
  (void)zero_point_dtype;

  // The op returns a (scale, zero_point) tuple; unpack both output refs.
  ValueRef scale_out = kDummyValueRef;
  ValueRef zero_point_out = kDummyValueRef;
  {
    const ValueListPtr out_tuple = graph.get_value_list(out_tuple_ref);
    scale_out = out_tuple->at(0);
    zero_point_out = out_tuple->at(1);
  }

  // Input and both outputs must be tensors.
  VK_CHECK_COND(graph.val_is_tensor(input));
  VK_CHECK_COND(graph.val_is_tensor(scale_out));
  VK_CHECK_COND(graph.val_is_tensor(zero_point_out));

  // Input must be floating point.
  const auto in_dtype = graph.dtype_of(input);
  VK_CHECK_COND(
      in_dtype == vkapi::kFloat || in_dtype == vkapi::kHalf ||
      in_dtype == vkapi::kDouble);

  // scale is always float32; zero_point may be int32 or float32
  // (TorchAO can emit a float32 zero_point in some configurations).
  VK_CHECK_COND(graph.dtype_of(scale_out) == vkapi::kFloat);
  const auto zp_dtype = graph.dtype_of(zero_point_out);
  VK_CHECK_COND(zp_dtype == vkapi::kInt || zp_dtype == vkapi::kFloat);

  // Only per-tensor granularity is supported: block_size must equal the
  // full input shape, dimension by dimension.
  const auto input_sizes = graph.sizes_of(input);
  const auto block_size_list = graph.get_int_list(block_size);
  VK_CHECK_COND(block_size_list->size() == input_sizes.size());
  for (size_t d = 0; d < input_sizes.size(); ++d) {
    VK_CHECK_COND((*block_size_list)[d] == input_sizes[d]);
  }

  // Texture-backed inputs must be width packed.
  if (!graph.is_buffer_storage(input)) {
    VK_CHECK_COND(graph.packed_dim_of(input) == WHCN::kWidthDim);
  }

  // Delegate to the per-tensor qparams computation for TorchAO affine ops.
  add_choose_qparams_tensor_node(
      graph, input, quant_min, quant_max, eps, scale_out, zero_point_out);
}
436+
364437
// Register all choose_qparams variants with the Vulkan compute graph runtime.
REGISTER_OPERATORS {
  VK_REGISTER_OP(
      quantized_decomposed.choose_qparams.tensor, choose_qparams_tensor_impl);
  VK_REGISTER_OP(
      quantized_decomposed.choose_qparams_per_token_asymmetric.default,
      choose_qparams_per_token_asymmetric_impl);

  // TorchAO affine choose_qparams operators
  VK_REGISTER_OP(
      torchao.choose_qparams_affine.default, choose_qparams_affine_impl);
}
371448

372449
} // namespace vkcompute

backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -508,6 +508,56 @@ void dequantize_per_channel_impl(
508508
graph, input, scale, zero_point, axis, quant_min, quant_max, output);
509509
}
510510

511+
void dequantize_affine_impl(
    ComputeGraph& graph,
    const std::vector<ValueRef>& args) {
  // Unpack arguments following the torchao.dequantize_affine schema.
  int idx = 0;
  const ValueRef input = args[idx++];
  const ValueRef block_size = args[idx++]; // SymInt[] - validated below
  const ValueRef scale = args[idx++];
  const ValueRef zero_point = args[idx++];
  const ValueRef input_dtype = args[idx++];
  const ValueRef quant_min = args[idx++];
  const ValueRef quant_max = args[idx++];
  const ValueRef output_dtype = args[idx++];
  const ValueRef output = args[idx++];

  // dtype hints are not consulted; actual dtypes are read from the graph.
  (void)input_dtype;
  (void)output_dtype;

  // Input and output must be tensors.
  VK_CHECK_COND(graph.val_is_tensor(input));
  VK_CHECK_COND(graph.val_is_tensor(output));

  // Input must be a quantized integer tensor.
  const auto in_dtype = graph.dtype_of(input);
  VK_CHECK_COND(
      in_dtype == vkapi::kByte || in_dtype == vkapi::kChar ||
      in_dtype == vkapi::kShort || in_dtype == vkapi::kInt);

  // Output must be floating point.
  const auto out_dtype = graph.dtype_of(output);
  VK_CHECK_COND(
      out_dtype == vkapi::kHalf || out_dtype == vkapi::kFloat ||
      out_dtype == vkapi::kDouble);

  // Only per-tensor granularity is supported: block_size must equal the
  // full input shape, dimension by dimension.
  const auto input_sizes = graph.sizes_of(input);
  const auto block_size_list = graph.get_int_list(block_size);
  VK_CHECK_COND(block_size_list->size() == input_sizes.size());
  for (size_t d = 0; d < input_sizes.size(); ++d) {
    VK_CHECK_COND((*block_size_list)[d] == input_sizes[d]);
  }

  // Delegate to the per-tensor dequantization for TorchAO affine ops.
  add_dequantize_per_tensor_node(
      graph, input, scale, zero_point, quant_min, quant_max, output);
}
560+
511561
REGISTER_OPERATORS {
512562
VK_REGISTER_OP(
513563
quantized_decomposed.dequantize_per_tensor.tensor,
@@ -518,6 +568,9 @@ REGISTER_OPERATORS {
518568
VK_REGISTER_OP(
519569
quantized_decomposed.dequantize_per_channel.default,
520570
dequantize_per_channel_impl);
571+
572+
// TorchAO affine dequantization operators
573+
VK_REGISTER_OP(torchao.dequantize_affine.default, dequantize_affine_impl);
521574
}
522575

523576
} // namespace vkcompute

backends/vulkan/runtime/graph/ops/impl/Linear.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -351,7 +351,11 @@ void linear(ComputeGraph& graph, const std::vector<ValueRef>& args) {
351351
ValueRef bias = args.at(2);
352352
ValueRef out = args.at(3);
353353
ValueRef weight = prepack_standard(
354-
graph, weight_data, graph.storage_type_of(out), utils::kWidthPacked);
354+
graph,
355+
weight_data,
356+
graph.storage_type_of(out),
357+
utils::kWidthPacked,
358+
/*passthrough = */ true);
355359
ValueRef mat2_is_transposed = graph.add_scalar(true);
356360

357361
if (graph.val_is_none(bias)) {

backends/vulkan/runtime/graph/ops/impl/Quantize.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -480,6 +480,47 @@ void quantize_per_channel_impl(
480480
graph, input, scale, zero_point, axis, quant_min, quant_max, output);
481481
}
482482

483+
void quantize_affine_impl(
    ComputeGraph& graph,
    const std::vector<ValueRef>& args) {
  // Unpack arguments following the torchao.quantize_affine schema.
  int idx = 0;
  const ValueRef input = args[idx++];
  const ValueRef block_size = args[idx++]; // SymInt[] - validated below
  const ValueRef scale = args[idx++];
  const ValueRef zero_point = args[idx++];
  const ValueRef output_dtype = args[idx++];
  const ValueRef quant_min = args[idx++];
  const ValueRef quant_max = args[idx++];
  const ValueRef output = args[idx++];

  // dtype hint is not consulted; the actual dtype is read from the graph.
  (void)output_dtype;

  // Input and output must be tensors.
  VK_CHECK_COND(graph.val_is_tensor(input));
  VK_CHECK_COND(graph.val_is_tensor(output));

  // Input must be floating point.
  const auto in_dtype = graph.dtype_of(input);
  VK_CHECK_COND(
      in_dtype == vkapi::kDouble || in_dtype == vkapi::kFloat ||
      in_dtype == vkapi::kHalf);

  // Only per-tensor granularity is supported: block_size must equal the
  // full input shape, dimension by dimension.
  const auto input_sizes = graph.sizes_of(input);
  const auto block_size_list = graph.get_int_list(block_size);
  VK_CHECK_COND(block_size_list->size() == input_sizes.size());
  for (size_t d = 0; d < input_sizes.size(); ++d) {
    VK_CHECK_COND((*block_size_list)[d] == input_sizes[d]);
  }

  // Delegate to the per-tensor quantization for TorchAO affine ops.
  add_quantize_per_tensor_node(
      graph, input, scale, zero_point, quant_min, quant_max, output);
}
523+
483524
REGISTER_OPERATORS {
484525
VK_REGISTER_OP(
485526
quantized_decomposed.quantize_per_tensor.tensor,
@@ -489,6 +530,9 @@ REGISTER_OPERATORS {
489530
VK_REGISTER_OP(
490531
quantized_decomposed.quantize_per_channel.default,
491532
quantize_per_channel_impl);
533+
534+
// TorchAO affine quantization operators
535+
VK_REGISTER_OP(torchao.quantize_affine.default, quantize_affine_impl);
492536
}
493537

494538
} // namespace vkcompute

0 commit comments

Comments
 (0)