[ET] correcting cpu ref quantize_per_channel logic to align with ATen #12203

Merged
merged 8 commits on Jul 14, 2025
68 changes: 23 additions & 45 deletions kernels/quantized/cpu/op_quantize.cpp
@@ -6,7 +6,6 @@
  * LICENSE file in the root directory of this source tree.
  */

-#include <executorch/kernels/portable/cpu/util/reduce_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 #include <algorithm>
 #include <cinttypes>
@@ -282,55 +281,34 @@ Tensor& quantize_per_channel_out(

   check_quantize_per_tensor_args(input, quant_min, quant_max, dtype, out);

-  // a list contains all dimensions except axis
-  int64_t dims[kTensorDimensionLimit];
-  for (int64_t i = 0; i < input.dim() - 1; i++) {
-    if (i < axis) {
-      dims[i] = i;
-    } else {
-      dims[i] = i - 1;
-    }
-  }
   const double* scale_data = scale.const_data_ptr<double>();
   const int64_t* zero_point_data = zero_point.const_data_ptr<int64_t>();

-  std::optional<executorch::aten::ArrayRef<int64_t>> optional_dim_list{
-      executorch::aten::ArrayRef<int64_t>{dims, size_t(input.dim() - 1)}};
-
-  // Actual quantization logic
-  // input, out are the input and output tensors
-  // channel_ix is the index along the axis dimension. 0 <= channel_ix <
-  // input.size(axis).
-  //   i.e. if the tensor has shape (N,C,H,W), axis being 1, then channel_ix
-  //   will be 0, 1, 2, ... C-1
-  // in_ix is the flat index of the element you are quantizing.
-  //   in other words you are quantizing in_data[in_ix]
+  // High-performance single loop with direct channel calculation
 #define QUANTIZE_IMPL(CTYPE_IN, CTYPE_OUT, out_dtype)                          \
-  case ScalarType::out_dtype:                                                  \
-    for (size_t channel_ix = 0; channel_ix < input.size(axis); ++channel_ix) { \
-      double _scale = scale_data[channel_ix];                                  \
-      int64_t _zero_point = zero_point_data[channel_ix];                       \
-      auto* out_data_ptr = out.mutable_data_ptr<CTYPE_OUT>();                  \
-      const auto* input_data_ptr = input.const_data_ptr<CTYPE_IN>();           \
-      apply_over_dim_list(                                                     \
-          [input_data_ptr,                                                     \
-           out_data_ptr,                                                       \
-           _scale,                                                             \
-           _zero_point,                                                        \
-           quant_min,                                                          \
-           quant_max](size_t in_ix) {                                          \
-            out_data_ptr[in_ix] = quantize_val<CTYPE_OUT, CTYPE_IN>(           \
-                _scale,                                                        \
-                _zero_point,                                                   \
-                input_data_ptr[in_ix],                                         \
-                quant_min,                                                     \
-                quant_max);                                                    \
-          },                                                                   \
-          input,                                                               \
-          optional_dim_list,                                                   \
-          channel_ix);                                                         \
-    }                                                                          \
-    break;
+  case ScalarType::out_dtype: {                                                \
+    auto* out_data_ptr = out.mutable_data_ptr<CTYPE_OUT>();                    \
+    const auto* input_data_ptr = input.const_data_ptr<CTYPE_IN>();             \
+    const int64_t input_numel = input.numel();                                 \
+    const int64_t axis_size = input.size(axis);                                \
+    /* Calculate the stride pattern for efficient channel index calculation */ \
+    int64_t axis_block_size = 1;                                               \
+    for (int64_t i = axis + 1; i < input.dim(); i++) {                         \
+      axis_block_size *= input.size(i);                                        \
+    }                                                                          \
+    /* Single loop over all elements */                                        \
+    for (int64_t i = 0; i < input_numel; i++) {                                \
+      /* Calculate which channel this element belongs to */                    \
+      int64_t channel_idx = (i / axis_block_size) % axis_size;                 \
+      /* Get quantization parameters for this channel */                       \
+      double _scale = scale_data[channel_idx];                                 \
+      int64_t _zero_point = zero_point_data[channel_idx];                      \
+      /* Apply quantization */                                                 \
+      out_data_ptr[i] = quantize_val<CTYPE_OUT, CTYPE_IN>(                     \
+          _scale, _zero_point, input_data_ptr[i], quant_min, quant_max);       \
+    }                                                                          \
+  } break;

#define CALCULATE_FLOAT_TYPE(CTYPE_IN, in_dtype) \
case ScalarType::in_dtype: \
switch (out.scalar_type()) { \
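The heart of the change is the flat-index-to-channel mapping: in a contiguous tensor, each run of axis_block_size consecutive elements (the product of all dimension sizes after axis) belongs to one channel, and successive runs cycle through the axis_size channels. Below is a minimal standalone sketch of that arithmetic, assuming a dense row-major layout (the same assumption the single-loop kernel makes); channel_of is a hypothetical helper, not part of the kernel:

#include <cassert>
#include <cstdint>
#include <vector>

// Sketch: recover the channel index from a flat index into a contiguous
// tensor, mirroring the (i / axis_block_size) % axis_size logic above.
int64_t channel_of(int64_t flat_ix, const std::vector<int64_t>& sizes, int64_t axis) {
  int64_t axis_block_size = 1;
  for (int64_t i = axis + 1; i < static_cast<int64_t>(sizes.size()); i++) {
    axis_block_size *= sizes[i];
  }
  return (flat_ix / axis_block_size) % sizes[axis];
}

int main() {
  // Shape {2, 3, 4} with axis = 1, matching the 3D test below:
  // axis_block_size = 4, so channel = (i / 4) % 3.
  std::vector<int64_t> sizes{2, 3, 4};
  assert(channel_of(0, sizes, 1) == 0);   // first row of batch 0
  assert(channel_of(5, sizes, 1) == 1);   // second channel, batch 0
  assert(channel_of(13, sizes, 1) == 0);  // wraps back to channel 0 in batch 1
  return 0;
}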
6 changes: 0 additions & 6 deletions kernels/quantized/cpu/targets.bzl
@@ -51,12 +51,6 @@ _QUANT_OPS = (
     ),
     op_target(
         name = "op_quantize",
-        deps = [
-            "//executorch/kernels/portable/cpu/util:reduce_util",
-        ],
-        _aten_mode_deps = [
-            "//executorch/kernels/portable/cpu/util:reduce_util_aten",
-        ],
     ),
 )

240 changes: 240 additions & 0 deletions kernels/quantized/test/op_quantize_test.cpp
@@ -206,3 +206,243 @@ TEST(OpQuantizeOutTest, QuantizePerChannel) {

EXPECT_TENSOR_EQ(out, expected);
}
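
All expected values in the new tests follow the affine quantization formula q = clamp(round(x / scale) + zero_point, quant_min, quant_max). Here is a reference sketch of that element-wise computation, assuming quantize_val rounds half to even (std::nearbyint under the default rounding mode), which is consistent with the 22.5 -> 22 expectation in the negative-axis test below; reference_quantize is a hypothetical helper, not the kernel's code:

#include <algorithm>
#include <cmath>
#include <cstdint>

// Reference sketch of per-channel affine quantization for one element.
// Assumption: the kernel's quantize_val behaves like round-half-to-even
// (std::nearbyint under the default FP rounding mode) followed by a clamp.
template <typename OutT>
OutT reference_quantize(double value, double scale, int64_t zero_point,
                        int64_t quant_min, int64_t quant_max) {
  int64_t q = static_cast<int64_t>(std::nearbyint(value / scale)) + zero_point;
  q = std::min(std::max(q, quant_min), quant_max);  // clamp into [qmin, qmax]
  return static_cast<OutT>(q);
}
// e.g. reference_quantize<uint8_t>(5.0, 2.0, 20, 0, 255) == 22, since 2.5 rounds to 2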

TEST(OpQuantizeOutTest, QuantizePerChannelAxis0) {
TensorFactory<ScalarType::Float> tf_float;
TensorFactory<ScalarType::Double> tf_double;
TensorFactory<ScalarType::Long> tf_long;

Tensor input = tf_float.full({3, 2}, 4);
Tensor scale = tf_double.make({3}, {0.5, 1.0, 2.0});
Tensor zero_point = tf_long.make({3}, {100, 50, 25});
int64_t quant_min = 0;
int64_t quant_max = 255;

TensorFactory<ScalarType::Byte> tfo;
Tensor out = tfo.zeros({3, 2});
// Channel 0: 4 / 0.5 + 100 = 108
// Channel 1: 4 / 1.0 + 50 = 54
// Channel 2: 4 / 2.0 + 25 = 27
Tensor expected = tfo.make({3, 2}, {108, 108, 54, 54, 27, 27});
quantize_per_channel_out(
input, scale, zero_point, 0, quant_min, quant_max, ScalarType::Byte, out);

EXPECT_TENSOR_EQ(out, expected);
}

TEST(OpQuantizeOutTest, QuantizePerChannel3D) {
TensorFactory<ScalarType::Float> tf_float;
TensorFactory<ScalarType::Double> tf_double;
TensorFactory<ScalarType::Long> tf_long;

// Test 3D tensor with axis=1 (middle dimension)
Tensor input = tf_float.full({2, 3, 4}, 6);
Tensor scale = tf_double.make({3}, {0.5, 1.0, 1.5});
Tensor zero_point = tf_long.make({3}, {10, 20, 30});
int64_t quant_min = -128;
int64_t quant_max = 127;

TensorFactory<ScalarType::Char> tfo;
Tensor out = tfo.zeros({2, 3, 4});
// Channel 0: 6 / 0.5 + 10 = 22
// Channel 1: 6 / 1.0 + 20 = 26
// Channel 2: 6 / 1.5 + 30 = 34
Tensor expected = tfo.make(
{2, 3, 4},
{
22, 22, 22, 22, // First batch, channel 0
26, 26, 26, 26, // First batch, channel 1
34, 34, 34, 34, // First batch, channel 2
22, 22, 22, 22, // Second batch, channel 0
26, 26, 26, 26, // Second batch, channel 1
34, 34, 34, 34 // Second batch, channel 2
});
quantize_per_channel_out(
input, scale, zero_point, 1, quant_min, quant_max, ScalarType::Char, out);

EXPECT_TENSOR_EQ(out, expected);
}

TEST(OpQuantizeOutTest, QuantizePerChannel4D) {
TensorFactory<ScalarType::Float> tf_float;
TensorFactory<ScalarType::Double> tf_double;
TensorFactory<ScalarType::Long> tf_long;

// Test 4D tensor with axis=2 (typical conv weight layout: N,C,H,W)
Tensor input = tf_float.full({2, 2, 3, 2}, 8);
Tensor scale = tf_double.make({3}, {0.25, 0.5, 1.0});
Tensor zero_point = tf_long.make({3}, {0, 10, 20});
int64_t quant_min = -128;
int64_t quant_max = 127;

TensorFactory<ScalarType::Char> tfo;
Tensor out = tfo.zeros({2, 2, 3, 2});
// Channel 0: 8 / 0.25 + 0 = 32
// Channel 1: 8 / 0.5 + 10 = 26
// Channel 2: 8 / 1.0 + 20 = 28
std::vector<int8_t> expected_data;
for (int n = 0; n < 2; n++) {
for (int c = 0; c < 2; c++) {
for (int h = 0; h < 3; h++) {
for (int w = 0; w < 2; w++) {
int8_t val = (h == 0) ? 32 : (h == 1) ? 26 : 28;
expected_data.push_back(val);
}
}
}
}
Tensor expected = tfo.make({2, 2, 3, 2}, expected_data);
quantize_per_channel_out(
input, scale, zero_point, 2, quant_min, quant_max, ScalarType::Char, out);

EXPECT_TENSOR_EQ(out, expected);
}

TEST(OpQuantizeOutTest, QuantizePerChannelNegativeAxis) {
TensorFactory<ScalarType::Float> tf_float;
TensorFactory<ScalarType::Double> tf_double;
TensorFactory<ScalarType::Long> tf_long;

Tensor input = tf_float.full({2, 3}, 5);
Tensor scale = tf_double.make({3}, {0.5, 1.0, 2.0});
Tensor zero_point = tf_long.make({3}, {0, 10, 20});
int64_t quant_min = 0;
int64_t quant_max = 255;

TensorFactory<ScalarType::Byte> tfo;
Tensor out = tfo.zeros({2, 3});
// Using axis=-1 should be equivalent to axis=1 for 2D tensor
// Channel 0: 5 / 0.5 + 0 = 10
// Channel 1: 5 / 1.0 + 10 = 15
// Channel 2: 5 / 2.0 + 20 = 22 (rounded from 22.5)
Tensor expected = tfo.make({2, 3}, {10, 15, 22, 10, 15, 22});
quantize_per_channel_out(
input,
scale,
zero_point,
-1,
quant_min,
quant_max,
ScalarType::Byte,
out);

EXPECT_TENSOR_EQ(out, expected);
}
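
This test exercises negative-axis handling; the expected values assume the usual normalization convention, sketched below (normalize_axis is a hypothetical helper, not the kernel's literal code):

#include <cstdint>

// Assumed normalization: a negative axis counts back from the last dimension,
// so for a rank-2 input, axis = -1 resolves to axis = 1.
int64_t normalize_axis(int64_t axis, int64_t rank) {
  return axis < 0 ? axis + rank : axis;
}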

TEST(OpQuantizeOutTest, QuantizePerChannelSingleChannel) {
TensorFactory<ScalarType::Float> tf_float;
TensorFactory<ScalarType::Double> tf_double;
TensorFactory<ScalarType::Long> tf_long;

Tensor input = tf_float.full({3, 1, 4}, 7);
Tensor scale = tf_double.make({1}, {0.5});
Tensor zero_point = tf_long.make({1}, {128});
int64_t quant_min = 0;
int64_t quant_max = 255;

TensorFactory<ScalarType::Byte> tfo;
Tensor out = tfo.zeros({3, 1, 4});
// Single channel: 7 / 0.5 + 128 = 142
Tensor expected = tfo.full({3, 1, 4}, 142);
quantize_per_channel_out(
input, scale, zero_point, 1, quant_min, quant_max, ScalarType::Byte, out);

EXPECT_TENSOR_EQ(out, expected);
}

TEST(OpQuantizeOutTest, QuantizePerChannelDifferentInputTypes) {
TensorFactory<ScalarType::Double> tf_double_input;
TensorFactory<ScalarType::Double> tf_double;
TensorFactory<ScalarType::Long> tf_long;

Tensor input = tf_double_input.full({2, 2}, 3.14159);
Tensor scale = tf_double.make({2}, {0.01, 0.02});
Tensor zero_point = tf_long.make({2}, {0, 100});
int64_t quant_min = -128;
int64_t quant_max = 127;

TensorFactory<ScalarType::Char> tfo;
Tensor out = tfo.zeros({2, 2});
// Channel 0: 3.14159 / 0.01 + 0 = 314 -> clamped to 127
// Channel 1: 3.14159 / 0.02 + 100 = 257 -> clamped to 127
Tensor expected = tfo.make({2, 2}, {127, 127, 127, 127});
quantize_per_channel_out(
input, scale, zero_point, 1, quant_min, quant_max, ScalarType::Char, out);

EXPECT_TENSOR_EQ(out, expected);
}

TEST(OpQuantizeOutTest, QuantizePerChannelDifferentOutputTypes) {
TensorFactory<ScalarType::Float> tf_float;
TensorFactory<ScalarType::Double> tf_double;
TensorFactory<ScalarType::Long> tf_long;

Tensor input = tf_float.full({2, 2}, 10);
Tensor scale = tf_double.make({2}, {1.0, 2.0});
Tensor zero_point = tf_long.make({2}, {1000, 2000});
int64_t quant_min = -32768;
int64_t quant_max = 32767;

// Test with 16-bit output
TensorFactory<ScalarType::Short> tfo;
Tensor out = tfo.zeros({2, 2});
// Channel 0: 10 / 1.0 + 1000 = 1010
// Channel 1: 10 / 2.0 + 2000 = 2005
Tensor expected = tfo.make({2, 2}, {1010, 2005, 1010, 2005});
quantize_per_channel_out(
input,
scale,
zero_point,
1,
quant_min,
quant_max,
ScalarType::Short,
out);

EXPECT_TENSOR_EQ(out, expected);
}

TEST(OpQuantizeOutTest, QuantizePerChannelMixedValues) {
TensorFactory<ScalarType::Float> tf_float;
TensorFactory<ScalarType::Double> tf_double;
TensorFactory<ScalarType::Long> tf_long;

// Test with different input values per position
Tensor input = tf_float.make({2, 3}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0});
Tensor scale = tf_double.make({3}, {0.5, 1.0, 1.5});
Tensor zero_point = tf_long.make({3}, {10, 20, 30});
int64_t quant_min = 0;
int64_t quant_max = 255;

TensorFactory<ScalarType::Byte> tfo;
Tensor out = tfo.zeros({2, 3});
// Row 0: [1.0/0.5+10, 2.0/1.0+20, 3.0/1.5+30] = [12, 22, 32]
// Row 1: [4.0/0.5+10, 5.0/1.0+20, 6.0/1.5+30] = [18, 25, 34]
Tensor expected = tfo.make({2, 3}, {12, 22, 32, 18, 25, 34});
quantize_per_channel_out(
input, scale, zero_point, 1, quant_min, quant_max, ScalarType::Byte, out);

EXPECT_TENSOR_EQ(out, expected);
}

TEST(OpQuantizeOutTest, QuantizePerChannelClampingBehavior) {
TensorFactory<ScalarType::Float> tf_float;
TensorFactory<ScalarType::Double> tf_double;
TensorFactory<ScalarType::Long> tf_long;

// Test values that will exceed quant_min/quant_max bounds
Tensor input = tf_float.make({1, 3}, {-100.0, 0.0, 100.0});
Tensor scale = tf_double.make({3}, {1.0, 1.0, 1.0});
Tensor zero_point = tf_long.make({3}, {0, 0, 0});
int64_t quant_min = -10;
int64_t quant_max = 10;

TensorFactory<ScalarType::Char> tfo;
Tensor out = tfo.zeros({1, 3});
// Values: [-100, 0, 100] should be clamped to [-10, 0, 10]
Tensor expected = tfo.make({1, 3}, {-10, 0, 10});
quantize_per_channel_out(
input, scale, zero_point, 1, quant_min, quant_max, ScalarType::Char, out);

EXPECT_TENSOR_EQ(out, expected);
}