diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py
index 1f77b30cda3..619b5f24b82 100644
--- a/backends/vulkan/op_registry.py
+++ b/backends/vulkan/op_registry.py
@@ -272,6 +272,60 @@ def register_quantization_op(features: OpFeatures):
     return features
 
 
+@update_features(
+    [
+        exir_ops.edge.torchao.quantize_affine.default,
+        exir_ops.edge.torchao.dequantize_affine.default,
+        exir_ops.edge.torchao.choose_qparams_affine.default,
+    ]
+)
+def register_torchao_quantization_op(features: OpFeatures):
+    # TorchAO quantization operators - default to per-tensor behavior
+    # Same features as standard quantization ops
+    features.texture_impl = TextureImplFeatures(
+        uses_axis_map=True,
+        valid_packed_dims={
+            PackedDim.WIDTH,
+        },
+    )
+    features.buffer_impl = True
+    features.resize_fn = True
+    features.optimal_storage = VkStorageType.BUFFER
+
+    def check_torchao_quantization_node(node: torch.fx.Node) -> bool:
+        # Only per-tensor quantization is supported by the Vulkan backend.
+        if len(node.args) < 2:
+            return False
+
+        block_size = node.args[1]
+
+        if not isinstance(block_size, (list, tuple)):
+            return False
+
+        input_arg = node.args[0]
+        if not isinstance(input_arg, torch.fx.Node):
+            return False
+
+        input_tensor = input_arg.meta.get("val", None)
+        if not isinstance(input_tensor, FakeTensor):
+            return False
+
+        input_shape = list(input_tensor.shape)
+
+        if len(block_size) != len(input_shape):
+            return False
+
+        # Check if block_size matches input_shape exactly (per-tensor quantization)
+        for i in range(len(block_size)):
+            if block_size[i] != input_shape[i]:
+                return False
+
+        return True
+
+    features.check_node_fn = check_torchao_quantization_node
+    return features
+
+
 @update_features(
     [
         exir_ops.edge.aten.add.Tensor,
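For context (not part of the patch), a minimal standalone sketch of the per-tensor condition that check_torchao_quantization_node enforces: torchao's affine ops express granularity through block_size, and per-tensor quantization corresponds to a block_size that matches the input shape exactly. The function name and shapes below are illustrative only.

```python
# Illustrative sketch of the per-tensor block_size check; shapes are made up.
def is_per_tensor(block_size, input_shape) -> bool:
    # Per-tensor affine quantization uses one block covering the whole tensor,
    # i.e. block_size must match the input shape dimension-for-dimension.
    if not isinstance(block_size, (list, tuple)):
        return False
    if len(block_size) != len(input_shape):
        return False
    return all(b == s for b, s in zip(block_size, input_shape))


assert is_per_tensor([2, 8, 16], [2, 8, 16])      # whole tensor -> accepted
assert not is_per_tensor([1, 8, 16], [2, 8, 16])  # per-row / grouped -> rejected
assert not is_per_tensor([8, 16], [2, 8, 16])     # rank mismatch -> rejected
```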
diff --git a/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp b/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp
index 5e9599b91e6..de269920eea 100644
--- a/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp
@@ -306,10 +306,12 @@ void choose_qparams_tensor_impl(
       graph.dtype_of(input) == vkapi::kHalf ||
       graph.dtype_of(input) == vkapi::kDouble);
 
-  // Verify output types - only accept Vulkan-supported types
-  // The Vulkan backend only supports float32 and int32, not float64/int64
+  // Verify output types - accept both int32 and float32 for zero_point
+  // TorchAO may use float32 for zero_point in some cases
   VK_CHECK_COND(graph.dtype_of(scale_out) == vkapi::kFloat);
-  VK_CHECK_COND(graph.dtype_of(zero_point_out) == vkapi::kInt);
+  VK_CHECK_COND(
+      graph.dtype_of(zero_point_out) == vkapi::kInt ||
+      graph.dtype_of(zero_point_out) == vkapi::kFloat);
 
   // Check that texture storage is width packed
   if (!graph.is_buffer_storage(input)) {
@@ -352,21 +354,96 @@ void choose_qparams_per_token_asymmetric_impl(
       graph.dtype_of(input) == vkapi::kHalf ||
       graph.dtype_of(input) == vkapi::kDouble);
 
-  // Verify output types - only accept Vulkan-supported types
-  // The Vulkan backend only supports float32 and int32, not float64/int64
+  // Verify output types - accept both int32 and float32 for zero_point
+  // TorchAO may use float32 for zero_point in some cases
   VK_CHECK_COND(graph.dtype_of(scale_out) == vkapi::kFloat);
-  VK_CHECK_COND(graph.dtype_of(zero_point_out) == vkapi::kInt);
+  VK_CHECK_COND(
+      graph.dtype_of(zero_point_out) == vkapi::kInt ||
+      graph.dtype_of(zero_point_out) == vkapi::kFloat);
 
   add_choose_qparams_per_token_asymmetric_node(
       graph, input, scale_out, zero_point_out);
 }
 
+void choose_qparams_affine_impl(
+    ComputeGraph& graph,
+    const std::vector<ValueRef>& args) {
+  int arg_idx = 0;
+  const ValueRef input = args[arg_idx++];
+  const ValueRef mapping_type = args[arg_idx++]; // str - ignored for per-tensor
+  const ValueRef block_size =
+      args[arg_idx++]; // SymInt[] - ignored for per-tensor
+  const ValueRef target_dtype = args[arg_idx++];
+  const ValueRef quant_min = args[arg_idx++];
+  const ValueRef quant_max = args[arg_idx++];
+  const ValueRef eps = args[arg_idx++];
+  const ValueRef scale_dtype = args[arg_idx++];
+  const ValueRef zero_point_dtype = args[arg_idx++];
+  const ValueRef out_tuple_ref = args[arg_idx++];
+
+  // Suppress unused variable warnings
+  (void)mapping_type;
+  (void)target_dtype;
+  (void)scale_dtype;
+  (void)zero_point_dtype;
+
+  ValueRef scale_out = kDummyValueRef;
+  ValueRef zero_point_out = kDummyValueRef;
+
+  {
+    const ValueListPtr out_tuple = graph.get_value_list(out_tuple_ref);
+    scale_out = out_tuple->at(0);
+    zero_point_out = out_tuple->at(1);
+  }
+
+  // Check tensor types
+  VK_CHECK_COND(graph.val_is_tensor(input));
+  VK_CHECK_COND(graph.val_is_tensor(scale_out));
+  VK_CHECK_COND(graph.val_is_tensor(zero_point_out));
+
+  // Verify input is a floating point type
+  VK_CHECK_COND(
+      graph.dtype_of(input) == vkapi::kFloat ||
+      graph.dtype_of(input) == vkapi::kHalf ||
+      graph.dtype_of(input) == vkapi::kDouble);
+
+  // Verify output types - accept both int32 and float32 for zero_point
+  // TorchAO may use float32 for zero_point in some cases
+  VK_CHECK_COND(graph.dtype_of(scale_out) == vkapi::kFloat);
+  VK_CHECK_COND(
+      graph.dtype_of(zero_point_out) == vkapi::kInt ||
+      graph.dtype_of(zero_point_out) == vkapi::kFloat);
+
+  // Check if this is per-tensor quantization (only supported granularity)
+  // block_size should equal input tensor dimensions for per-tensor quantization
+  const auto input_sizes = graph.sizes_of(input);
+  const auto block_size_list = graph.get_int_list(block_size);
+  VK_CHECK_COND(block_size_list->size() == input_sizes.size());
+  for (size_t i = 0; i < input_sizes.size(); i++) {
+    VK_CHECK_COND((*block_size_list)[i] == input_sizes[i]);
+  }
+
+  // Check that texture storage is width packed
+  if (!graph.is_buffer_storage(input)) {
+    VK_CHECK_COND(graph.packed_dim_of(input) == WHCN::kWidthDim);
+  }
+
+  // Default to per-tensor quantization parameter calculation for TorchAO affine
+  // ops
+  add_choose_qparams_tensor_node(
+      graph, input, quant_min, quant_max, eps, scale_out, zero_point_out);
+}
+
 REGISTER_OPERATORS {
   VK_REGISTER_OP(
       quantized_decomposed.choose_qparams.tensor, choose_qparams_tensor_impl);
   VK_REGISTER_OP(
       quantized_decomposed.choose_qparams_per_token_asymmetric.default,
       choose_qparams_per_token_asymmetric_impl);
+
+  // TorchAO affine choose_qparams operators
+  VK_REGISTER_OP(
+      torchao.choose_qparams_affine.default, choose_qparams_affine_impl);
 }
 
 } // namespace vkcompute
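For context (not part of the patch), a hedged reference of how asymmetric per-tensor quantization parameters are conventionally derived from the input range, which is the per-tensor path choose_qparams_affine_impl falls back to via add_choose_qparams_tensor_node. The Vulkan shader may differ in eps handling and rounding, so treat this only as a sketch of the math; the helper name and inputs are made up.

```python
import torch


def choose_qparams_per_tensor_reference(x, quant_min, quant_max, eps=1.1920929e-07):
    # Conventional asymmetric per-tensor parameters: include zero in the range,
    # derive scale from the spread, then pick an integer zero_point in range.
    min_val = min(float(x.min()), 0.0)
    max_val = max(float(x.max()), 0.0)
    scale = max((max_val - min_val) / (quant_max - quant_min), eps)
    zero_point = int(round(quant_min - min_val / scale))
    zero_point = max(quant_min, min(quant_max, zero_point))
    return scale, zero_point


x = torch.randn(2, 8, 16)
scale, zp = choose_qparams_per_tensor_reference(x, quant_min=-128, quant_max=127)
```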
diff --git a/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp b/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp
index 1578b515f55..7edb9b2f70d 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp
@@ -508,6 +508,56 @@ void dequantize_per_channel_impl(
       graph, input, scale, zero_point, axis, quant_min, quant_max, output);
 }
 
+void dequantize_affine_impl(
+    ComputeGraph& graph,
+    const std::vector<ValueRef>& args) {
+  int arg_idx = 0;
+  const ValueRef input = args[arg_idx++];
+  const ValueRef block_size =
+      args[arg_idx++]; // SymInt[] - ignored for per-tensor
+  const ValueRef scale = args[arg_idx++];
+  const ValueRef zero_point = args[arg_idx++];
+  const ValueRef input_dtype = args[arg_idx++];
+  const ValueRef quant_min = args[arg_idx++];
+  const ValueRef quant_max = args[arg_idx++];
+  const ValueRef output_dtype = args[arg_idx++];
+  const ValueRef output = args[arg_idx++];
+
+  // Suppress unused variable warnings
+  (void)input_dtype;
+  (void)output_dtype;
+
+  // Check tensor types
+  VK_CHECK_COND(graph.val_is_tensor(input));
+  VK_CHECK_COND(graph.val_is_tensor(output));
+
+  // Verify input is an integer type
+  VK_CHECK_COND(
+      graph.dtype_of(input) == vkapi::kByte ||
+      graph.dtype_of(input) == vkapi::kChar ||
+      graph.dtype_of(input) == vkapi::kShort ||
+      graph.dtype_of(input) == vkapi::kInt);
+
+  // Verify output is a floating point type
+  VK_CHECK_COND(
+      graph.dtype_of(output) == vkapi::kHalf ||
+      graph.dtype_of(output) == vkapi::kFloat ||
+      graph.dtype_of(output) == vkapi::kDouble);
+
+  // Check if this is per-tensor quantization (only supported granularity)
+  // block_size should equal input tensor dimensions for per-tensor quantization
+  const auto input_sizes = graph.sizes_of(input);
+  const auto block_size_list = graph.get_int_list(block_size);
+  VK_CHECK_COND(block_size_list->size() == input_sizes.size());
+  for (size_t i = 0; i < input_sizes.size(); i++) {
+    VK_CHECK_COND((*block_size_list)[i] == input_sizes[i]);
+  }
+
+  // Default to per-tensor dequantization for TorchAO affine ops
+  add_dequantize_per_tensor_node(
+      graph, input, scale, zero_point, quant_min, quant_max, output);
+}
+
 REGISTER_OPERATORS {
   VK_REGISTER_OP(
       quantized_decomposed.dequantize_per_tensor.tensor,
@@ -518,6 +568,9 @@ REGISTER_OPERATORS {
   VK_REGISTER_OP(
       quantized_decomposed.dequantize_per_channel.default,
       dequantize_per_channel_impl);
+
+  // TorchAO affine dequantization operators
+  VK_REGISTER_OP(torchao.dequantize_affine.default, dequantize_affine_impl);
 }
 
 } // namespace vkcompute
diff --git a/backends/vulkan/runtime/graph/ops/impl/Linear.cpp b/backends/vulkan/runtime/graph/ops/impl/Linear.cpp
index 86df735acbe..14ed9c84a32 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Linear.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Linear.cpp
@@ -351,7 +351,11 @@ void linear(ComputeGraph& graph, const std::vector<ValueRef>& args) {
   ValueRef bias = args.at(2);
   ValueRef out = args.at(3);
   ValueRef weight = prepack_standard(
-      graph, weight_data, graph.storage_type_of(out), utils::kWidthPacked);
+      graph,
+      weight_data,
+      graph.storage_type_of(out),
+      utils::kWidthPacked,
+      /*passthrough = */ true);
   ValueRef mat2_is_transposed = graph.add_scalar(true);
   if (graph.val_is_none(bias)) {
diff --git a/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp b/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp
index 0105a384042..d786981e1fc 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Quantize.cpp
@@ -480,6 +480,47 @@ void quantize_per_channel_impl(
       graph, input, scale, zero_point, axis, quant_min, quant_max, output);
 }
 
+void quantize_affine_impl(
+    ComputeGraph& graph,
+    const std::vector<ValueRef>& args) {
+  int arg_idx = 0;
+  const ValueRef input = args[arg_idx++];
+  const ValueRef block_size =
+      args[arg_idx++]; // SymInt[] - ignored for per-tensor
+  const ValueRef scale = args[arg_idx++];
+  const ValueRef zero_point = args[arg_idx++];
+  const ValueRef output_dtype = args[arg_idx++];
+  const ValueRef quant_min = args[arg_idx++];
+  const ValueRef quant_max = args[arg_idx++];
+  const ValueRef output = args[arg_idx++];
+
+  // Suppress unused variable warnings
+  (void)output_dtype;
+
+  // Check tensor types
+  VK_CHECK_COND(graph.val_is_tensor(input));
+  VK_CHECK_COND(graph.val_is_tensor(output));
+
+  // Verify input is a floating point type
+  VK_CHECK_COND(
+      graph.dtype_of(input) == vkapi::kDouble ||
+      graph.dtype_of(input) == vkapi::kFloat ||
+      graph.dtype_of(input) == vkapi::kHalf);
+
+  // Check if this is per-tensor quantization (only supported granularity)
+  // block_size should equal input tensor dimensions for per-tensor quantization
+  const auto input_sizes = graph.sizes_of(input);
+  const auto block_size_list = graph.get_int_list(block_size);
+  VK_CHECK_COND(block_size_list->size() == input_sizes.size());
+  for (size_t i = 0; i < input_sizes.size(); i++) {
+    VK_CHECK_COND((*block_size_list)[i] == input_sizes[i]);
+  }
+
+  // Default to per-tensor quantization for TorchAO affine ops
+  add_quantize_per_tensor_node(
+      graph, input, scale, zero_point, quant_min, quant_max, output);
+}
+
 REGISTER_OPERATORS {
   VK_REGISTER_OP(
       quantized_decomposed.quantize_per_tensor.tensor,
@@ -489,6 +530,9 @@ REGISTER_OPERATORS {
   VK_REGISTER_OP(
       quantized_decomposed.quantize_per_channel.default,
       quantize_per_channel_impl);
+
+  // TorchAO affine quantization operators
+  VK_REGISTER_OP(torchao.quantize_affine.default, quantize_affine_impl);
 }
 
 } // namespace vkcompute
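For context (not part of the patch), a hedged pure-PyTorch sketch of the per-tensor affine round trip that the torchao.quantize_affine / torchao.dequantize_affine registrations lower to. It illustrates the standard q = clamp(round(x / scale) + zero_point) and x_hat = (q - zero_point) * scale formulas, not the exact behavior of the Vulkan kernels; helper names and parameter values are made up.

```python
import torch


def quantize_per_tensor_reference(x, scale, zero_point, quant_min, quant_max):
    # q = clamp(round(x / scale) + zero_point, quant_min, quant_max)
    q = torch.round(x / scale) + zero_point
    return torch.clamp(q, quant_min, quant_max).to(torch.int8)


def dequantize_per_tensor_reference(q, scale, zero_point):
    # x_hat = (q - zero_point) * scale
    return (q.to(torch.float32) - zero_point) * scale


x = torch.randn(2, 8, 16)
scale, zp = 0.05, 3  # made-up per-tensor parameters
q = quantize_per_tensor_reference(x, scale, zp, quant_min=-128, quant_max=127)
x_hat = dequantize_per_tensor_reference(q, scale, zp)
print((x - x_hat).abs().max())  # roughly bounded by scale/2 away from saturation
```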