Skip to content

Commit 5f59b76

Browse files
authored
Fix uint16 support for quantize_per_tensor.
Differential Revision: D73517183 Pull Request resolved: #10398
1 parent a82fa8f commit 5f59b76

File tree

2 files changed

+80
-11
lines changed

2 files changed

+80
-11
lines changed

backends/cadence/hifi/operators/op_quantize_per_tensor.cpp

Lines changed: 53 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,26 @@ namespace cadence {
2020
namespace impl {
2121
namespace HiFi {
2222
namespace native {
23-
23+
namespace {
2424
using ::executorch::aten::ScalarType;
2525
using ::executorch::aten::Tensor;
2626
using ::executorch::runtime::KernelRuntimeContext;
2727

28+
// Validates the quantization bounds for an output element type T: the check
// passes only when quant_min equals std::numeric_limits<T>::min() and
// quant_max equals std::numeric_limits<T>::max(); otherwise ET_KERNEL_CHECK is
// invoked with InvalidArgument on ctx.
// NOTE(review): the empty final macro argument suggests a failed check only
// returns from this helper, so the caller presumably continues after the
// failure is recorded on ctx -- confirm against the ET_KERNEL_CHECK definition.
29+
template <typename T>
30+
void check_quant_min_and_max(
31+
KernelRuntimeContext& ctx,
32+
const int64_t quant_min,
33+
const int64_t quant_max) {
34+
ET_KERNEL_CHECK(
35+
ctx,
36+
std::numeric_limits<T>::min() == quant_min &&
37+
std::numeric_limits<T>::max() == quant_max,
38+
InvalidArgument, );
39+
}
40+
41+
} // namespace
42+
2843
// Quantize the input tensor (PT2 version). Note that quant_<min,max> are only
2944
// checked against the output dtype's limits and are not used in any computation.
3045
void quantize_per_tensor_out(
@@ -36,15 +51,43 @@ void quantize_per_tensor_out(
3651
__ET_UNUSED int64_t quant_max,
3752
const ScalarType dtype,
3853
Tensor& out) {
39-
// Add checks for dtype quant min/max bounds.
40-
ET_SWITCH_REALB_TYPES(
41-
out.scalar_type(), ctx, "quantize_per_tensor", OUT_DTYPE, [&]() {
42-
ET_KERNEL_CHECK(
43-
ctx,
44-
std::numeric_limits<OUT_DTYPE>::min() == quant_min &&
45-
std::numeric_limits<OUT_DTYPE>::max() == quant_max,
46-
InvalidArgument, );
47-
});
54+
// Check for input scalar type.
55+
ET_KERNEL_CHECK_MSG(
56+
ctx,
57+
input.scalar_type() == ScalarType::Float,
58+
InvalidType,
59+
,
60+
"Input tensor for quantize_per_tensor.out should be type %s, but got %s",
61+
::torch::executor::toString(ScalarType::Float),
62+
::torch::executor::toString(input.scalar_type()));
63+
64+
// Check quant min/max for output types.
65+
switch (out.scalar_type()) {
66+
case ScalarType::Byte:
67+
check_quant_min_and_max<uint8_t>(ctx, quant_min, quant_max);
68+
break;
69+
case ScalarType::Char:
70+
check_quant_min_and_max<int8_t>(ctx, quant_min, quant_max);
71+
break;
72+
case ScalarType::Short:
73+
check_quant_min_and_max<int16_t>(ctx, quant_min, quant_max);
74+
break;
75+
case ScalarType::Bits16:
76+
case ScalarType::UInt16:
77+
check_quant_min_and_max<uint16_t>(ctx, quant_min, quant_max);
78+
break;
79+
case ScalarType::Int:
80+
check_quant_min_and_max<int32_t>(ctx, quant_min, quant_max);
81+
break;
82+
default:
83+
ET_KERNEL_CHECK_MSG(
84+
ctx,
85+
false,
86+
InvalidType,
87+
,
88+
"Unhandled output dtype %s",
89+
::torch::executor::toString(out.scalar_type()));
90+
}
4891

4992
const float* input_data = input.const_data_ptr<float>();
5093
const size_t numel = out.numel();

backends/cadence/hifi/operators/tests/test_op_quantize_per_tensor.cpp

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ TEST_F(HiFiQuantizePerTensorTest, ThrowKernelFailureForQuantMaxLessThanLimit) {
106106
out));
107107
}
108108

109-
TEST_F(HiFiQuantizePerTensorTest, CheckSingleElementQuantize) {
109+
TEST_F(HiFiQuantizePerTensorTest, CheckSingleElementIntQuantize) {
110110
TensorFactory<ScalarType::Float> tf;
111111
const std::vector<int> sizes{1};
112112
constexpr ScalarType kOutDtype = ScalarType::Int;
@@ -132,6 +132,32 @@ TEST_F(HiFiQuantizePerTensorTest, CheckSingleElementQuantize) {
132132
EXPECT_TENSOR_EQ(out, tf_out.make(sizes, {kExpectedOutputValue}));
133133
}
134134

135+
// Quantizes a one-element float tensor into a UInt16 output tensor, passing the
// full uint16_t range as quant_min/quant_max, and checks the value written to
// `out`.
TEST_F(HiFiQuantizePerTensorTest, CheckSingleElementUInt16Quantize) {
136+
TensorFactory<ScalarType::Float> tf;
137+
const std::vector<int> sizes{1};
138+
constexpr ScalarType kOutDtype = ScalarType::UInt16;
139+
TensorFactory<kOutDtype> tf_out;
140+
Tensor out = tf_out.zeros(sizes);
141+
// Some arbitrary values for scalar args. The quant bounds are set to the
// exact limits of uint16_t.
142+
constexpr double kScale = 0.01;
143+
constexpr int64_t kZeroPoint = 32768;
144+
constexpr int64_t kQuantMin = std::numeric_limits<uint16_t>::min();
145+
constexpr int64_t kQuantMax = std::numeric_limits<uint16_t>::max();
146+
constexpr float kInputValue = 100.0f;
147+
// Expected result is input / scale + zero_point (100 / 0.01 + 32768 = 42768),
// which fits in uint16_t.
// NOTE(review): static_cast truncates toward zero; this assumes the kernel's
// rounding gives the same integer for this input -- confirm.
constexpr uint16_t kExpectedOutputValue =
148+
static_cast<uint16_t>(kInputValue / kScale + kZeroPoint);
149+
150+
quantize_per_tensor_out(
151+
tf.make(sizes, {kInputValue}),
152+
kScale,
153+
kZeroPoint,
154+
kQuantMin,
155+
kQuantMax,
156+
kOutDtype,
157+
out);
158+
EXPECT_TENSOR_EQ(out, tf_out.make(sizes, {kExpectedOutputValue}));
159+
}
160+
135161
} // namespace
136162
} // namespace native
137163
} // namespace HiFi

0 commit comments

Comments
 (0)