Skip to content

Commit 04c1a39

Browse files
author
morelos
committed
[ET-VK][Ops] affine quantization operators registration
# Context In order to enable dynamic quantization, especially for the source transform method using `Int8DynActInt4WeightQuantizer`, we need to have vulkan versions for `quantize_affine`, `dequantize_affine`, and `choose_qparams_affine`. Currently we do not have a shader that performs block-based quantization as expected from these shaders, so we delegate to the per_tensor variant just to get unblocked. At a later stage, this will likely be developed further to ensure we don't lose too much accuracy. # Changes This creates a schema reference in the TorchAO library for out variants of these respective operators. Then there is a VK_REGISTER_OP done on them to ensure that we can properly register them when lowering the ET model with vulkan. Also the vulkan_quantizer is changed a bit here to enable a dynamic quantization config so that we aren't purely working with just static quantization anymore. Furthermore, we have `_annotate_for_static_quantization_config` for parity/legacy reasons, and we simply create an equivalent dynamic quantization config method. We also changed `Linear.cpp`, particularly to allow a passthrough for weight_data since during dynamic quantization it's possible that it'll be a tensor_data rather than a tensor_ref. Differential Revision: [D78035354](https://our.internmc.facebook.com/intern/diff/D78035354/) ghstack-source-id: 295489089 Pull Request resolved: #12369
1 parent a50dc92 commit 04c1a39

File tree

6 files changed

+226
-12
lines changed

6 files changed

+226
-12
lines changed

backends/vulkan/op_registry.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,28 @@ def register_quantization_op(features: OpFeatures):
272272
return features
273273

274274

275+
@update_features(
276+
[
277+
exir_ops.edge.torchao.quantize_affine.default,
278+
exir_ops.edge.torchao.dequantize_affine.default,
279+
exir_ops.edge.torchao.choose_qparams_affine.default,
280+
]
281+
)
282+
def register_torchao_quantization_op(features: OpFeatures):
283+
# TorchAO quantization operators - default to per-tensor behavior
284+
# Same features as standard quantization ops
285+
features.texture_impl = TextureImplFeatures(
286+
uses_axis_map=True,
287+
valid_packed_dims={
288+
PackedDim.WIDTH,
289+
},
290+
)
291+
features.buffer_impl = True
292+
features.resize_fn = True
293+
features.optimal_storage = VkStorageType.BUFFER
294+
return features
295+
296+
275297
@update_features(
276298
[
277299
exir_ops.edge.aten.add.Tensor,

backends/vulkan/quantizer/vulkan_quantizer.py

Lines changed: 52 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@
1818
propagate_annotation,
1919
)
2020
from torch.fx import Node
21-
from torchao.quantization.pt2e import PerChannelMinMaxObserver
21+
from torchao.quantization.pt2e import (
22+
PerChannelMinMaxObserver,
23+
PlaceholderObserver,
24+
)
2225
from torchao.quantization.pt2e.quantizer import (
2326
QuantizationConfig,
2427
QuantizationSpec,
@@ -77,6 +80,38 @@ def get_linear_weight_only_qcs_xnn_qconfig(quant_bits: int) -> QuantizationConfi
7780
)
7881

7982

83+
@functools.lru_cache
84+
def get_dynamic_activation_qconfig(
85+
weight_bits: int = 4,
86+
act_qmin: int = -128,
87+
act_qmax: int = 127,
88+
) -> QuantizationConfig:
89+
"""
90+
Return a QuantizationConfig for dynamic activation quantization with 4-bit weights.
91+
This is compatible with Vulkan backend's quantized_decomposed operators.
92+
"""
93+
# Dynamic activation quantization spec
94+
act_quantization_spec = QuantizationSpec(
95+
dtype=torch.int8,
96+
quant_min=act_qmin,
97+
quant_max=act_qmax,
98+
qscheme=torch.per_tensor_affine,
99+
is_dynamic=True,
100+
observer_or_fake_quant_ctr=PlaceholderObserver,
101+
)
102+
103+
# Weight quantization spec (per-channel symmetric)
104+
weight_qspec = get_linear_weight_qcs_qspec(weight_bits)
105+
106+
return QuantizationConfig(
107+
input_activation=act_quantization_spec,
108+
output_activation=None,
109+
weight=weight_qspec,
110+
bias=None,
111+
is_qat=False,
112+
)
113+
114+
80115
_SUPPORTED_OPS = [
81116
"linear",
82117
]
@@ -99,12 +134,15 @@ def transform_for_annotation(
99134
return _convert_scalars_to_attrs(model)
100135

101136
def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
102-
# currently only support static quant on Vulkan
103-
model = self._annotate_for_static_quantization_config(model)
137+
# Support both static and dynamic quantization
138+
if self.global_config and self.global_config.input_activation and self.global_config.input_activation.is_dynamic:
139+
model = self._annotate_for_dynamic_quantization_config(model)
140+
else:
141+
model = self._annotate_for_static_quantization_config(model)
104142
propagate_annotation(model)
105143
return model
106144

107-
def _annotate_all_static_patterns(
145+
def _annotate_all_patterns(
108146
self,
109147
model: torch.fx.GraphModule,
110148
quantization_config: Optional[QuantizationConfig],
@@ -120,7 +158,16 @@ def _annotate_all_static_patterns(
120158
def _annotate_for_static_quantization_config(
121159
self, model: torch.fx.GraphModule
122160
) -> torch.fx.GraphModule:
123-
self._annotate_all_static_patterns(
161+
self._annotate_all_patterns(
162+
model,
163+
self.global_config,
164+
)
165+
return model
166+
167+
def _annotate_for_dynamic_quantization_config(
168+
self, model: torch.fx.GraphModule
169+
) -> torch.fx.GraphModule:
170+
self._annotate_all_patterns(
124171
model,
125172
self.global_config,
126173
)

backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp

Lines changed: 72 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -306,10 +306,12 @@ void choose_qparams_tensor_impl(
306306
graph.dtype_of(input) == vkapi::kHalf ||
307307
graph.dtype_of(input) == vkapi::kDouble);
308308

309-
// Verify output types - only accept Vulkan-supported types
310-
// The Vulkan backend only supports float32 and int32, not float64/int64
309+
// Verify output types - accept both int32 and float32 for zero_point
310+
// TorchAO may use float32 for zero_point in some cases
311311
VK_CHECK_COND(graph.dtype_of(scale_out) == vkapi::kFloat);
312-
VK_CHECK_COND(graph.dtype_of(zero_point_out) == vkapi::kInt);
312+
VK_CHECK_COND(
313+
graph.dtype_of(zero_point_out) == vkapi::kInt ||
314+
graph.dtype_of(zero_point_out) == vkapi::kFloat);
313315

314316
// Check that texture storage is width packed
315317
if (!graph.is_buffer_storage(input)) {
@@ -352,21 +354,85 @@ void choose_qparams_per_token_asymmetric_impl(
352354
graph.dtype_of(input) == vkapi::kHalf ||
353355
graph.dtype_of(input) == vkapi::kDouble);
354356

355-
// Verify output types - only accept Vulkan-supported types
356-
// The Vulkan backend only supports float32 and int32, not float64/int64
357+
// Verify output types - accept both int32 and float32 for zero_point
358+
// TorchAO may use float32 for zero_point in some cases
357359
VK_CHECK_COND(graph.dtype_of(scale_out) == vkapi::kFloat);
358-
VK_CHECK_COND(graph.dtype_of(zero_point_out) == vkapi::kInt);
360+
VK_CHECK_COND(
361+
graph.dtype_of(zero_point_out) == vkapi::kInt ||
362+
graph.dtype_of(zero_point_out) == vkapi::kFloat);
359363

360364
add_choose_qparams_per_token_asymmetric_node(
361365
graph, input, scale_out, zero_point_out);
362366
}
363367

368+
void choose_qparams_affine_impl(
369+
ComputeGraph& graph,
370+
const std::vector<ValueRef>& args) {
371+
int arg_idx = 0;
372+
const ValueRef input = args[arg_idx++];
373+
const ValueRef mapping_type = args[arg_idx++]; // str - ignored for per-tensor
374+
const ValueRef block_size = args[arg_idx++]; // SymInt[] - ignored for per-tensor
375+
const ValueRef target_dtype = args[arg_idx++];
376+
const ValueRef quant_min = args[arg_idx++];
377+
const ValueRef quant_max = args[arg_idx++];
378+
const ValueRef eps = args[arg_idx++];
379+
const ValueRef scale_dtype = args[arg_idx++];
380+
const ValueRef zero_point_dtype = args[arg_idx++];
381+
const ValueRef out_tuple_ref = args[arg_idx++];
382+
383+
// Suppress unused variable warnings
384+
(void)mapping_type;
385+
(void)block_size;
386+
(void)target_dtype;
387+
(void)scale_dtype;
388+
(void)zero_point_dtype;
389+
390+
ValueRef scale_out = kDummyValueRef;
391+
ValueRef zero_point_out = kDummyValueRef;
392+
393+
{
394+
const ValueListPtr out_tuple = graph.get_value_list(out_tuple_ref);
395+
scale_out = out_tuple->at(0);
396+
zero_point_out = out_tuple->at(1);
397+
}
398+
399+
// Check tensor types
400+
VK_CHECK_COND(graph.val_is_tensor(input));
401+
VK_CHECK_COND(graph.val_is_tensor(scale_out));
402+
VK_CHECK_COND(graph.val_is_tensor(zero_point_out));
403+
404+
// Verify input is a floating point type
405+
VK_CHECK_COND(
406+
graph.dtype_of(input) == vkapi::kFloat ||
407+
graph.dtype_of(input) == vkapi::kHalf ||
408+
graph.dtype_of(input) == vkapi::kDouble);
409+
410+
// Verify output types - accept both int32 and float32 for zero_point
411+
// TorchAO may use float32 for zero_point in some cases
412+
VK_CHECK_COND(graph.dtype_of(scale_out) == vkapi::kFloat);
413+
VK_CHECK_COND(
414+
graph.dtype_of(zero_point_out) == vkapi::kInt ||
415+
graph.dtype_of(zero_point_out) == vkapi::kFloat);
416+
417+
// Check that texture storage is width packed
418+
if (!graph.is_buffer_storage(input)) {
419+
VK_CHECK_COND(graph.packed_dim_of(input) == WHCN::kWidthDim);
420+
}
421+
422+
// Default to per-tensor quantization parameter calculation for TorchAO affine ops
423+
add_choose_qparams_tensor_node(
424+
graph, input, quant_min, quant_max, eps, scale_out, zero_point_out);
425+
}
426+
364427
REGISTER_OPERATORS {
365428
VK_REGISTER_OP(
366429
quantized_decomposed.choose_qparams.tensor, choose_qparams_tensor_impl);
367430
VK_REGISTER_OP(
368431
quantized_decomposed.choose_qparams_per_token_asymmetric.default,
369432
choose_qparams_per_token_asymmetric_impl);
433+
434+
// TorchAO affine choose_qparams operators
435+
VK_REGISTER_OP(torchao.choose_qparams_affine.default, choose_qparams_affine_impl);
370436
}
371437

372438
} // namespace vkcompute

backends/vulkan/runtime/graph/ops/impl/Dequantize.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -508,6 +508,47 @@ void dequantize_per_channel_impl(
508508
graph, input, scale, zero_point, axis, quant_min, quant_max, output);
509509
}
510510

511+
void dequantize_affine_impl(
512+
ComputeGraph& graph,
513+
const std::vector<ValueRef>& args) {
514+
int arg_idx = 0;
515+
const ValueRef input = args[arg_idx++];
516+
const ValueRef block_size = args[arg_idx++]; // SymInt[] - ignored for per-tensor
517+
const ValueRef scale = args[arg_idx++];
518+
const ValueRef zero_point = args[arg_idx++];
519+
const ValueRef input_dtype = args[arg_idx++];
520+
const ValueRef quant_min = args[arg_idx++];
521+
const ValueRef quant_max = args[arg_idx++];
522+
const ValueRef output_dtype = args[arg_idx++];
523+
const ValueRef output = args[arg_idx++];
524+
525+
// Suppress unused variable warnings
526+
(void)block_size;
527+
(void)input_dtype;
528+
(void)output_dtype;
529+
530+
// Check tensor types
531+
VK_CHECK_COND(graph.val_is_tensor(input));
532+
VK_CHECK_COND(graph.val_is_tensor(output));
533+
534+
// Verify input is an integer type
535+
VK_CHECK_COND(
536+
graph.dtype_of(input) == vkapi::kByte ||
537+
graph.dtype_of(input) == vkapi::kChar ||
538+
graph.dtype_of(input) == vkapi::kShort ||
539+
graph.dtype_of(input) == vkapi::kInt);
540+
541+
// Verify output is a floating point type
542+
VK_CHECK_COND(
543+
graph.dtype_of(output) == vkapi::kHalf ||
544+
graph.dtype_of(output) == vkapi::kFloat ||
545+
graph.dtype_of(output) == vkapi::kDouble);
546+
547+
// Default to per-tensor dequantization for TorchAO affine ops
548+
add_dequantize_per_tensor_node(
549+
graph, input, scale, zero_point, quant_min, quant_max, output);
550+
}
551+
511552
REGISTER_OPERATORS {
512553
VK_REGISTER_OP(
513554
quantized_decomposed.dequantize_per_tensor.tensor,
@@ -518,6 +559,9 @@ REGISTER_OPERATORS {
518559
VK_REGISTER_OP(
519560
quantized_decomposed.dequantize_per_channel.default,
520561
dequantize_per_channel_impl);
562+
563+
// TorchAO affine dequantization operators
564+
VK_REGISTER_OP(torchao.dequantize_affine.default, dequantize_affine_impl);
521565
}
522566

523567
} // namespace vkcompute

backends/vulkan/runtime/graph/ops/impl/Linear.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -351,7 +351,7 @@ void linear(ComputeGraph& graph, const std::vector<ValueRef>& args) {
351351
ValueRef bias = args.at(2);
352352
ValueRef out = args.at(3);
353353
ValueRef weight = prepack_standard(
354-
graph, weight_data, graph.storage_type_of(out), utils::kWidthPacked);
354+
graph, weight_data, graph.storage_type_of(out), utils::kWidthPacked, /*passthrough = */ true);
355355
ValueRef mat2_is_transposed = graph.add_scalar(true);
356356

357357
if (graph.val_is_none(bias)) {

backends/vulkan/runtime/graph/ops/impl/Quantize.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -480,6 +480,38 @@ void quantize_per_channel_impl(
480480
graph, input, scale, zero_point, axis, quant_min, quant_max, output);
481481
}
482482

483+
void quantize_affine_impl(
484+
ComputeGraph& graph,
485+
const std::vector<ValueRef>& args) {
486+
int arg_idx = 0;
487+
const ValueRef input = args[arg_idx++];
488+
const ValueRef block_size = args[arg_idx++]; // SymInt[] - ignored for per-tensor
489+
const ValueRef scale = args[arg_idx++];
490+
const ValueRef zero_point = args[arg_idx++];
491+
const ValueRef output_dtype = args[arg_idx++];
492+
const ValueRef quant_min = args[arg_idx++];
493+
const ValueRef quant_max = args[arg_idx++];
494+
const ValueRef output = args[arg_idx++];
495+
496+
// Suppress unused variable warnings
497+
(void)block_size;
498+
(void)output_dtype;
499+
500+
// Check tensor types
501+
VK_CHECK_COND(graph.val_is_tensor(input));
502+
VK_CHECK_COND(graph.val_is_tensor(output));
503+
504+
// Verify input is a floating point type
505+
VK_CHECK_COND(
506+
graph.dtype_of(input) == vkapi::kDouble ||
507+
graph.dtype_of(input) == vkapi::kFloat ||
508+
graph.dtype_of(input) == vkapi::kHalf);
509+
510+
// Default to per-tensor quantization for TorchAO affine ops
511+
add_quantize_per_tensor_node(
512+
graph, input, scale, zero_point, quant_min, quant_max, output);
513+
}
514+
483515
REGISTER_OPERATORS {
484516
VK_REGISTER_OP(
485517
quantized_decomposed.quantize_per_tensor.tensor,
@@ -489,6 +521,9 @@ REGISTER_OPERATORS {
489521
VK_REGISTER_OP(
490522
quantized_decomposed.quantize_per_channel.default,
491523
quantize_per_channel_impl);
524+
525+
// TorchAO affine quantization operators
526+
VK_REGISTER_OP(torchao.quantize_affine.default, quantize_affine_impl);
492527
}
493528

494529
} // namespace vkcompute

0 commit comments

Comments
 (0)