[ET-VK][Ops] affine quantization operators registration #12369

Merged

merged 7 commits into from Jul 14, 2025
54 changes: 54 additions & 0 deletions backends/vulkan/op_registry.py
@@ -272,6 +272,60 @@ def register_quantization_op(features: OpFeatures):
return features


@update_features(
[
exir_ops.edge.torchao.quantize_affine.default,
exir_ops.edge.torchao.dequantize_affine.default,
exir_ops.edge.torchao.choose_qparams_affine.default,
]
)
def register_torchao_quantization_op(features: OpFeatures):
# TorchAO quantization operators - default to per-tensor behavior
# Same features as standard quantization ops
features.texture_impl = TextureImplFeatures(
uses_axis_map=True,
valid_packed_dims={
PackedDim.WIDTH,
},
)
features.buffer_impl = True
features.resize_fn = True
features.optimal_storage = VkStorageType.BUFFER

def check_torchao_quantization_node(node: torch.fx.Node) -> bool:
# Only per-tensor quantization is supported by the Vulkan backend.
if len(node.args) < 2:
return False

# quantize_affine and dequantize_affine carry block_size at args[1];
# choose_qparams_affine is (input, mapping_type, block_size, ...), so
# fall back to args[2] for that op.
block_size = node.args[1]
if not isinstance(block_size, (list, tuple)) and len(node.args) > 2:
block_size = node.args[2]

if not isinstance(block_size, (list, tuple)):
return False

input_arg = node.args[0]
if not isinstance(input_arg, torch.fx.Node):
return False

input_tensor = input_arg.meta.get("val", None)
if not isinstance(input_tensor, FakeTensor):
return False

input_shape = list(input_tensor.shape)

if len(block_size) != len(input_shape):
return False

# Check if block_size matches input_shape exactly (per-tensor quantization)
for i in range(len(block_size)):
if block_size[i] != input_shape[i]:
return False

return True

features.check_node_fn = check_torchao_quantization_node
return features
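
For reference, the per-tensor test above boils down to a shape comparison: a block_size that spans the entire input means a single scale/zero_point for the whole tensor. A minimal standalone sketch (illustrative only, not part of the diff):

```python
def is_per_tensor(block_size, input_shape) -> bool:
    # A whole-tensor block means one scale/zero_point for all elements.
    if len(block_size) != len(input_shape):
        return False
    return all(b == s for b, s in zip(block_size, input_shape))

assert is_per_tensor([2, 3, 8], [2, 3, 8])      # whole-tensor block: accepted
assert not is_per_tensor([1, 3, 8], [2, 3, 8])  # blocked along dim 0: rejected
```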


@update_features(
[
exir_ops.edge.aten.add.Tensor,
89 changes: 83 additions & 6 deletions backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp
@@ -306,10 +306,12 @@ void choose_qparams_tensor_impl(
graph.dtype_of(input) == vkapi::kHalf ||
graph.dtype_of(input) == vkapi::kDouble);

// Verify output types - only accept Vulkan-supported types
// The Vulkan backend only supports float32 and int32, not float64/int64
// Verify output types - accept both int32 and float32 for zero_point
// TorchAO may use float32 for zero_point in some cases
VK_CHECK_COND(graph.dtype_of(scale_out) == vkapi::kFloat);
VK_CHECK_COND(graph.dtype_of(zero_point_out) == vkapi::kInt);
VK_CHECK_COND(
graph.dtype_of(zero_point_out) == vkapi::kInt ||
graph.dtype_of(zero_point_out) == vkapi::kFloat);

// Check that texture storage is width packed
if (!graph.is_buffer_storage(input)) {
@@ -352,21 +354,96 @@ void choose_qparams_per_token_asymmetric_impl(
graph.dtype_of(input) == vkapi::kHalf ||
graph.dtype_of(input) == vkapi::kDouble);

// Verify output types - only accept Vulkan-supported types
// The Vulkan backend only supports float32 and int32, not float64/int64
// Verify output types - accept both int32 and float32 for zero_point
// TorchAO may use float32 for zero_point in some cases
VK_CHECK_COND(graph.dtype_of(scale_out) == vkapi::kFloat);
VK_CHECK_COND(graph.dtype_of(zero_point_out) == vkapi::kInt);
VK_CHECK_COND(
graph.dtype_of(zero_point_out) == vkapi::kInt ||
graph.dtype_of(zero_point_out) == vkapi::kFloat);

add_choose_qparams_per_token_asymmetric_node(
graph, input, scale_out, zero_point_out);
}

void choose_qparams_affine_impl(
ComputeGraph& graph,
const std::vector<ValueRef>& args) {
int arg_idx = 0;
const ValueRef input = args[arg_idx++];
const ValueRef mapping_type = args[arg_idx++]; // str - ignored for per-tensor
const ValueRef block_size =
args[arg_idx++]; // SymInt[] - ignored for per-tensor
const ValueRef target_dtype = args[arg_idx++];
const ValueRef quant_min = args[arg_idx++];
const ValueRef quant_max = args[arg_idx++];
const ValueRef eps = args[arg_idx++];
const ValueRef scale_dtype = args[arg_idx++];
const ValueRef zero_point_dtype = args[arg_idx++];
const ValueRef out_tuple_ref = args[arg_idx++];

// Suppress unused variable warnings
(void)mapping_type;
(void)target_dtype;
(void)scale_dtype;
(void)zero_point_dtype;

ValueRef scale_out = kDummyValueRef;
ValueRef zero_point_out = kDummyValueRef;

{
const ValueListPtr out_tuple = graph.get_value_list(out_tuple_ref);
scale_out = out_tuple->at(0);
zero_point_out = out_tuple->at(1);
}

// Check tensor types
VK_CHECK_COND(graph.val_is_tensor(input));
VK_CHECK_COND(graph.val_is_tensor(scale_out));
VK_CHECK_COND(graph.val_is_tensor(zero_point_out));

// Verify input is a floating point type
VK_CHECK_COND(
graph.dtype_of(input) == vkapi::kFloat ||
graph.dtype_of(input) == vkapi::kHalf ||
graph.dtype_of(input) == vkapi::kDouble);

// Verify output types - accept both int32 and float32 for zero_point
// TorchAO may use float32 for zero_point in some cases
VK_CHECK_COND(graph.dtype_of(scale_out) == vkapi::kFloat);
VK_CHECK_COND(
graph.dtype_of(zero_point_out) == vkapi::kInt ||
graph.dtype_of(zero_point_out) == vkapi::kFloat);

// Check if this is per-tensor quantization (only supported granularity)
// block_size should equal input tensor dimensions for per-tensor quantization
const auto input_sizes = graph.sizes_of(input);
const auto block_size_list = graph.get_int_list(block_size);
VK_CHECK_COND(block_size_list->size() == input_sizes.size());
for (size_t i = 0; i < input_sizes.size(); i++) {
VK_CHECK_COND((*block_size_list)[i] == input_sizes[i]);
}

// Check that texture storage is width packed
if (!graph.is_buffer_storage(input)) {
VK_CHECK_COND(graph.packed_dim_of(input) == WHCN::kWidthDim);
}

// Default to per-tensor quantization parameter calculation for TorchAO affine
// ops
add_choose_qparams_tensor_node(
graph, input, quant_min, quant_max, eps, scale_out, zero_point_out);
}
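
Because the affine op is routed to the existing per-tensor node, the parameters come out of the standard asymmetric min/max recipe. A rough Python sketch of that math (assumed behavior for illustration, not the actual shader):

```python
import torch

def choose_qparams_per_tensor(x: torch.Tensor, quant_min: int, quant_max: int, eps: float):
    # The range is widened to include zero so zero stays exactly representable.
    min_val = min(x.min().item(), 0.0)
    max_val = max(x.max().item(), 0.0)
    scale = max((max_val - min_val) / (quant_max - quant_min), eps)
    zero_point = int(round(quant_min - min_val / scale))
    return scale, max(quant_min, min(quant_max, zero_point))

scale, zp = choose_qparams_per_tensor(torch.randn(2, 3, 8), -128, 127, 1e-5)
```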

REGISTER_OPERATORS {
VK_REGISTER_OP(
quantized_decomposed.choose_qparams.tensor, choose_qparams_tensor_impl);
VK_REGISTER_OP(
quantized_decomposed.choose_qparams_per_token_asymmetric.default,
choose_qparams_per_token_asymmetric_impl);

// TorchAO affine choose_qparams operators
VK_REGISTER_OP(
torchao.choose_qparams_affine.default, choose_qparams_affine_impl);
}

} // namespace vkcompute
53 changes: 53 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp
@@ -508,6 +508,56 @@ void dequantize_per_channel_impl(
graph, input, scale, zero_point, axis, quant_min, quant_max, output);
}

void dequantize_affine_impl(
ComputeGraph& graph,
const std::vector<ValueRef>& args) {
int arg_idx = 0;
const ValueRef input = args[arg_idx++];
const ValueRef block_size =
args[arg_idx++]; // SymInt[] - ignored for per-tensor
const ValueRef scale = args[arg_idx++];
const ValueRef zero_point = args[arg_idx++];
const ValueRef input_dtype = args[arg_idx++];
const ValueRef quant_min = args[arg_idx++];
const ValueRef quant_max = args[arg_idx++];
const ValueRef output_dtype = args[arg_idx++];
const ValueRef output = args[arg_idx++];

// Suppress unused variable warnings
(void)input_dtype;
(void)output_dtype;

// Check tensor types
VK_CHECK_COND(graph.val_is_tensor(input));
VK_CHECK_COND(graph.val_is_tensor(output));

// Verify input is an integer type
VK_CHECK_COND(
graph.dtype_of(input) == vkapi::kByte ||
graph.dtype_of(input) == vkapi::kChar ||
graph.dtype_of(input) == vkapi::kShort ||
graph.dtype_of(input) == vkapi::kInt);

// Verify output is a floating point type
VK_CHECK_COND(
graph.dtype_of(output) == vkapi::kHalf ||
graph.dtype_of(output) == vkapi::kFloat ||
graph.dtype_of(output) == vkapi::kDouble);

// Check if this is per-tensor quantization (only supported granularity)
// block_size should equal input tensor dimensions for per-tensor quantization
const auto input_sizes = graph.sizes_of(input);
const auto block_size_list = graph.get_int_list(block_size);
VK_CHECK_COND(block_size_list->size() == input_sizes.size());
for (size_t i = 0; i < input_sizes.size(); i++) {
VK_CHECK_COND((*block_size_list)[i] == input_sizes[i]);
}

// Default to per-tensor dequantization for TorchAO affine ops
add_dequantize_per_tensor_node(
graph, input, scale, zero_point, quant_min, quant_max, output);
}
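
For intuition, per-tensor dequantization applies the same scale/zero_point to every element: x_hat = (q - zero_point) * scale. An elementwise sketch (illustrative, not the GPU kernel):

```python
import torch

def dequantize_per_tensor(q: torch.Tensor, scale: float, zero_point: int) -> torch.Tensor:
    # Widen to int32 before the subtraction to avoid int8 overflow.
    return (q.to(torch.int32) - zero_point).to(torch.float32) * scale

x_hat = dequantize_per_tensor(
    torch.randint(-128, 128, (2, 3), dtype=torch.int8), scale=0.05, zero_point=3
)
```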

REGISTER_OPERATORS {
VK_REGISTER_OP(
quantized_decomposed.dequantize_per_tensor.tensor,
@@ -518,6 +568,9 @@ REGISTER_OPERATORS {
VK_REGISTER_OP(
quantized_decomposed.dequantize_per_channel.default,
dequantize_per_channel_impl);

// TorchAO affine dequantization operators
VK_REGISTER_OP(torchao.dequantize_affine.default, dequantize_affine_impl);
}

} // namespace vkcompute
6 changes: 5 additions & 1 deletion backends/vulkan/runtime/graph/ops/impl/Linear.cpp
@@ -351,7 +351,11 @@ void linear(ComputeGraph& graph, const std::vector<ValueRef>& args) {
ValueRef bias = args.at(2);
ValueRef out = args.at(3);
ValueRef weight = prepack_standard(
graph, weight_data, graph.storage_type_of(out), utils::kWidthPacked);
graph,
weight_data,
graph.storage_type_of(out),
utils::kWidthPacked,
/*passthrough = */ true);
ValueRef mat2_is_transposed = graph.add_scalar(true);

if (graph.val_is_none(bias)) {
44 changes: 44 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/Quantize.cpp
@@ -480,6 +480,47 @@ void quantize_per_channel_impl(
graph, input, scale, zero_point, axis, quant_min, quant_max, output);
}

void quantize_affine_impl(
ComputeGraph& graph,
const std::vector<ValueRef>& args) {
int arg_idx = 0;
const ValueRef input = args[arg_idx++];
const ValueRef block_size =
args[arg_idx++]; // SymInt[] - ignored for per-tensor
const ValueRef scale = args[arg_idx++];
const ValueRef zero_point = args[arg_idx++];
const ValueRef output_dtype = args[arg_idx++];
const ValueRef quant_min = args[arg_idx++];
const ValueRef quant_max = args[arg_idx++];
const ValueRef output = args[arg_idx++];

// Suppress unused variable warnings
(void)output_dtype;

// Check tensor types
VK_CHECK_COND(graph.val_is_tensor(input));
VK_CHECK_COND(graph.val_is_tensor(output));

// Verify input is a floating point type
VK_CHECK_COND(
graph.dtype_of(input) == vkapi::kDouble ||
graph.dtype_of(input) == vkapi::kFloat ||
graph.dtype_of(input) == vkapi::kHalf);

// Check if this is per-tensor quantization (only supported granularity)
// block_size should equal input tensor dimensions for per-tensor quantization
const auto input_sizes = graph.sizes_of(input);
const auto block_size_list = graph.get_int_list(block_size);
VK_CHECK_COND(block_size_list->size() == input_sizes.size());
for (size_t i = 0; i < input_sizes.size(); i++) {
VK_CHECK_COND((*block_size_list)[i] == input_sizes[i]);
}

// Default to per-tensor quantization for TorchAO affine ops
add_quantize_per_tensor_node(
graph, input, scale, zero_point, quant_min, quant_max, output);
}
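
The forward direction mirrors dequantize: divide by scale, shift by zero_point, round, clamp. A compact sketch of the per-tensor math this node lowers to (illustrative only, assuming int8 output):

```python
import torch

def quantize_per_tensor(x: torch.Tensor, scale: float, zero_point: int,
                        quant_min: int, quant_max: int) -> torch.Tensor:
    # q = clamp(round(x / scale) + zero_point, quant_min, quant_max)
    q = torch.round(x / scale) + zero_point
    return torch.clamp(q, quant_min, quant_max).to(torch.int8)

q = quantize_per_tensor(torch.randn(2, 3), scale=0.05, zero_point=3,
                        quant_min=-128, quant_max=127)
```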

REGISTER_OPERATORS {
VK_REGISTER_OP(
quantized_decomposed.quantize_per_tensor.tensor,
@@ -489,6 +530,9 @@ REGISTER_OPERATORS {
VK_REGISTER_OP(
quantized_decomposed.quantize_per_channel.default,
quantize_per_channel_impl);

// TorchAO affine quantization operators
VK_REGISTER_OP(torchao.quantize_affine.default, quantize_affine_impl);
}

} // namespace vkcompute