Skip to content

Commit 9b21e73

Browse files
RahulC7 and facebook-github-bot
authored and committed
Using generic implementation for 16-bit activations and 8-bit weights for Conv2D in Backends (pytorch#16007)
Summary: # Context We continue from D84284794 to add support for 16-bit activations. Note that right now, although they support 16-bit activations already, it's only if the weights are also 16 bits. To do this, we need to change the way we template some functions. # Current Behavior Right now, we're composing two macros together, the `ET_FORALL_JARVIS_QUANTIZED_TYPES_WITH_INT16` macro: https://www.internalfb.com/code/fbsource/[9e8c6d8466107f58aa3de1b9e4ec71c49d670a8f]/fbcode/on_device_ai/Assistant/Jarvis/min_runtime/operators/generic/operators.h?lines=22-25 and the function macro (`quantized_linear` chosen for example): https://www.internalfb.com/code/fbsource/[9e8c6d8466107f58aa3de1b9e4ec71c49d670a8f]/fbcode/on_device_ai/Assistant/Jarvis/min_runtime/operators/generic/quantized_linear_out.cpp?lines=30-41 so together, it just becomes a switch statement, calling the `quantized_linear` function with the correct template parameter. However, note that it assumes that both the input activations and weights are the same dtype, which is not the case. # This Diff We finish by using the generic implementation for all the backends and adding e2e tests as well as unit tests. Differential Revision: D87993325
1 parent e0e957b commit 9b21e73

File tree

5 files changed

+340
-3
lines changed

5 files changed

+340
-3
lines changed

backends/cadence/aot/quantizer/quantizer.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,3 +372,17 @@ def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
372372
# Add 16-bit quantizers for LinearPattern
373373
quantizers.append(CadenceAtenQuantizer(LinearPattern(), qconfig_A16))
374374
super().__init__(quantizers)
375+
376+
377+
class CadenceWith16BitConvActivationsQuantizer(CadenceQuantizer):
378+
"""
379+
Quantizer including A16 conv
380+
"""
381+
382+
def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
383+
if quantizers is None:
384+
quantizers = []
385+
# Add 16-bit quantizers for Conv patterns
386+
quantizers.append(CadenceAtenQuantizer(Conv1dPattern(), qconfig_A16))
387+
quantizers.append(CadenceAtenQuantizer(Conv2dPattern(), qconfig_A16))
388+
super().__init__(quantizers)

backends/cadence/hifi/operators/op_quantized_conv2d_nchw_out.cpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <executorch/backends/cadence/hifi/kernels/kernels.h>
1010
#include <executorch/backends/cadence/hifi/operators/operators.h>
1111
#include <executorch/runtime/kernel/kernel_includes.h>
12+
#include <on_device_ai/Assistant/Jarvis/min_runtime/operators/generic/op_quantized_conv2d.h>
1213

1314
#define ALIGN_PTR(x, bytes) ((((unsigned)(x)) + (bytes - 1)) & (~(bytes - 1)))
1415

@@ -532,6 +533,30 @@ void quantized_conv2d_nchw_out(
532533
__ET_UNUSED const Tensor& out_multiplier,
533534
__ET_UNUSED const Tensor& out_shift,
534535
Tensor& out) {
536+
// Handle W8A16 heterogeneous type (int16_t activations, int8_t weights)
537+
if (out.scalar_type() == ::executorch::aten::ScalarType::Short &&
538+
input.scalar_type() == ::executorch::aten::ScalarType::Short &&
539+
weight.scalar_type() == ::executorch::aten::ScalarType::Char) {
540+
::impl::generic::native::quantized_conv2d_nchw_out(
541+
ctx,
542+
input,
543+
weight,
544+
bias,
545+
stride,
546+
padding,
547+
dilation,
548+
groups,
549+
in_zero_point,
550+
weight_zero_point,
551+
bias_scale,
552+
output_scale,
553+
output_zero_point,
554+
out_multiplier,
555+
out_shift,
556+
out);
557+
return;
558+
}
559+
535560
const float bias_scale_float = bias_scale.const_data_ptr<float>()[0];
536561
const int32_t weight_zero_point_int =
537562
weight_zero_point.const_data_ptr<int32_t>()[0];
@@ -596,6 +621,30 @@ void quantized_conv2d_nchw_per_tensor_out(
596621
__ET_UNUSED int64_t out_multiplier,
597622
__ET_UNUSED int64_t out_shift,
598623
Tensor& out) {
624+
// Handle W8A16 heterogeneous type (int16_t activations, int8_t weights)
625+
if (out.scalar_type() == ::executorch::aten::ScalarType::Short &&
626+
input.scalar_type() == ::executorch::aten::ScalarType::Short &&
627+
weight.scalar_type() == ::executorch::aten::ScalarType::Char) {
628+
::impl::generic::native::quantized_conv2d_nchw_per_tensor_out(
629+
ctx,
630+
input,
631+
weight,
632+
bias,
633+
stride,
634+
padding,
635+
dilation,
636+
groups,
637+
in_zero_point,
638+
weight_zero_point,
639+
bias_scale,
640+
output_scale,
641+
output_zero_point,
642+
out_multiplier,
643+
out_shift,
644+
out);
645+
return;
646+
}
647+
599648
bool optimized = 0;
600649

601650
if ((input.scalar_type() == ScalarType::Char) ||

backends/cadence/hifi/operators/op_quantized_conv2d_nhwc_out.cpp

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <executorch/backends/cadence/hifi/kernels/kernels.h>
1010
#include <executorch/backends/cadence/hifi/operators/operators.h>
1111
#include <executorch/runtime/kernel/kernel_includes.h>
12+
#include <on_device_ai/Assistant/Jarvis/min_runtime/operators/generic/op_quantized_conv2d.h>
1213

1314
#define ALIGN_PTR(x, bytes) ((((unsigned)(x)) + (bytes - 1)) & (~(bytes - 1)))
1415

@@ -438,6 +439,29 @@ void quantized_conv2d_nhwc_out(
438439
__ET_UNUSED const Tensor& out_multiplier,
439440
__ET_UNUSED const Tensor& out_shift,
440441
Tensor& out) {
442+
// Handle W8A16 heterogeneous type (int16_t activations, int8_t weights)
443+
if (out.scalar_type() == ::executorch::aten::ScalarType::Short &&
444+
input.scalar_type() == ::executorch::aten::ScalarType::Short &&
445+
weight.scalar_type() == ::executorch::aten::ScalarType::Char) {
446+
::impl::generic::native::quantized_conv2d_nhwc_out(
447+
ctx,
448+
input,
449+
weight,
450+
bias,
451+
stride,
452+
padding,
453+
dilation,
454+
groups,
455+
in_zero_point,
456+
weight_zero_point,
457+
bias_scale,
458+
output_scale,
459+
output_zero_point,
460+
out_multiplier,
461+
out_shift,
462+
out);
463+
return;
464+
}
441465
const float bias_scale_float = bias_scale.const_data_ptr<float>()[0];
442466
const int32_t weight_zero_point_int =
443467
weight_zero_point.const_data_ptr<int32_t>()[0];
@@ -502,8 +526,31 @@ void quantized_conv2d_nhwc_per_tensor_out(
502526
__ET_UNUSED int64_t out_multiplier,
503527
__ET_UNUSED int64_t out_shift,
504528
Tensor& out) {
505-
bool optimized = 0;
529+
// Handle W8A16 heterogeneous type (int16_t activations, int8_t weights)
530+
if (out.scalar_type() == ::executorch::aten::ScalarType::Short &&
531+
input.scalar_type() == ::executorch::aten::ScalarType::Short &&
532+
weight.scalar_type() == ::executorch::aten::ScalarType::Char) {
533+
::impl::generic::native::quantized_conv2d_nhwc_per_tensor_out(
534+
ctx,
535+
input,
536+
weight,
537+
bias,
538+
stride,
539+
padding,
540+
dilation,
541+
groups,
542+
in_zero_point,
543+
weight_zero_point,
544+
bias_scale,
545+
output_scale,
546+
output_zero_point,
547+
out_multiplier,
548+
out_shift,
549+
out);
550+
return;
551+
}
506552

553+
bool optimized = 0;
507554
if ((input.scalar_type() == ScalarType::Char) ||
508555
(input.scalar_type() == ScalarType::Byte))
509556
optimized = 1;

backends/cadence/hifi/operators/targets.bzl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,6 @@ OPERATORS = [
6565
"ne",
6666
"permute_copy",
6767
"pow",
68-
"quantized_conv2d_nchw_out",
6968
"quantized_conv2d_nchw_asym8sxsym8s_asym8s_per_tensor_out",
7069
"quantized_conv2d_nchw_asym8uxsym8u_asym8u_per_tensor_out",
7170
"quantized_conv1d_ncl_asym8sxsym8s_asym8s_per_tensor_out",
@@ -74,7 +73,6 @@ OPERATORS = [
7473
"quantized_conv2d_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor_out",
7574
"quantized_conv2d_nchw_dilated_asym8sxsym8s_asym8s_per_tensor_out",
7675
"quantized_conv2d_nchw_dilated_asym8uxsym8u_asym8u_per_tensor_out",
77-
"quantized_conv2d_nhwc_out",
7876
"quantized_conv2d_nhwc_asym8sxsym8s_asym8s_per_tensor_out",
7977
"quantized_conv2d_nhwc_asym8uxsym8u_asym8u_per_tensor_out",
8078
"quantized_conv1d_nlc_asym8sxsym8s_asym8s_per_tensor_out",
@@ -125,3 +123,7 @@ def define_common_targets():
125123
# quantized_linear_out and quantized_linear_per_tensor_out needs additional dependency for int16 support
126124
define_operator("quantized_linear_out", deps=["fbcode//on_device_ai/Assistant/Jarvis/min_runtime/operators/generic:op_quantized_linear"])
127125
define_operator("quantized_linear_per_tensor_out", deps=["fbcode//on_device_ai/Assistant/Jarvis/min_runtime/operators/generic:op_quantized_linear"])
126+
127+
# quantized_conv2d_nchw_out and quantized_conv2d_nhwc_out need additional dependency for int16 support
128+
define_operator("quantized_conv2d_nchw_out", deps=["fbcode//on_device_ai/Assistant/Jarvis/min_runtime/operators/generic:op_quantized_conv2d"])
129+
define_operator("quantized_conv2d_nhwc_out", deps=["fbcode//on_device_ai/Assistant/Jarvis/min_runtime/operators/generic:op_quantized_conv2d"])

0 commit comments

Comments
 (0)