Skip to content

Commit 9c3e967

Browse files
ethansfng authored and facebook-github-bot committed
Add unittest for quantized_relu_out (#12499)
Summary: Add a unittest for quantized_relu_out Reviewed By: hsharma35 Test Plan: Imported from GitHub, without a `Test Plan:` line. Rollback Plan: Differential Revision: D78175392 Pulled By: ethansfng
1 parent 953fa0e commit 9c3e967

File tree

2 files changed

+204
-0
lines changed

2 files changed

+204
-0
lines changed

backends/cadence/hifi/operators/operators.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,16 @@ namespace impl {
2323
namespace HiFi {
2424
namespace native {
2525

26+
// Dequantize the input tensor (PT2 version). NOTE(review): presumably
// quant_min/quant_max are not used in any computation, mirroring the
// comment on quantize_per_tensor_out below — confirm against the kernel
// implementation.
void dequantize_per_tensor_out(
    ::executorch::runtime::KernelRuntimeContext& ctx,
    const ::executorch::aten::Tensor& input,
    double scale,
    int64_t zero_point,
    int64_t quant_min,
    int64_t quant_max,
    ::executorch::aten::ScalarType dtype,
    ::executorch::aten::Tensor& out);
35+
2636
// Quantize the input tensor (PT2 version). Note that quant_<min,max> are not
2737
// used in any computation.
2838
void quantize_per_tensor_out(
@@ -42,6 +52,15 @@ ::executorch::aten::Tensor& div_out_mode(
4252
std::optional<std::string_view> mode,
4353
::executorch::aten::Tensor& out);
4454

55+
// Quantized ReLU: writes the activation of `input` into `output` in the
// quantized domain. NOTE(review): in_zero_point, out_multiplier and
// out_shift are passed as tensors (presumably 1-element for per-tensor
// quantization) — confirm expected shapes against the kernel implementation.
void quantized_relu_out(
    ::executorch::runtime::KernelRuntimeContext& ctx,
    const ::executorch::aten::Tensor& input,
    const ::executorch::aten::Tensor& in_zero_point,
    const int64_t out_zero_point,
    const ::executorch::aten::Tensor& out_multiplier,
    const ::executorch::aten::Tensor& out_shift,
    ::executorch::aten::Tensor& output);
63+
4564
void quantized_linear_out(
4665
__ET_UNUSED KernelRuntimeContext& ctx,
4766
const Tensor& in,
Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <gtest/gtest.h>
10+
#include <sys/times.h>
11+
12+
#include <executorch/kernels/test/TestUtil.h>
13+
#include <executorch/runtime/core/error.h>
14+
#include <executorch/runtime/core/exec_aten/exec_aten.h>
15+
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
16+
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
17+
#include <executorch/runtime/platform/runtime.h>
18+
19+
#include <executorch/backends/cadence/hifi/operators/operators.h>
20+
21+
namespace cadence {
22+
namespace impl {
23+
namespace HiFi {
24+
namespace native {
25+
namespace {
26+
27+
using ::executorch::aten::Scalar;
28+
using ::executorch::aten::ScalarType;
29+
using ::executorch::aten::Tensor;
30+
using ::executorch::aten::TensorImpl;
31+
using ::executorch::runtime::Error;
32+
using ::executorch::runtime::KernelRuntimeContext;
33+
using ::executorch::runtime::runtime_init;
34+
using ::executorch::runtime::testing::TensorFactory;
35+
using std::optional;
36+
using std::string_view;
37+
38+
class HiFiQuantizedReluTest : public OperatorTest {
39+
public:
40+
protected:
41+
void quantize_per_tensor_out(
42+
const Tensor& input,
43+
double scale,
44+
const int64_t zero_point,
45+
int64_t quant_min,
46+
int64_t quant_max,
47+
const ScalarType dtype,
48+
Tensor& out) {
49+
return ::cadence::impl::HiFi::native::quantize_per_tensor_out(
50+
context_, input, scale, zero_point, quant_min, quant_max, dtype, out);
51+
}
52+
53+
void quantized_relu_out(
54+
const Tensor& input,
55+
const Tensor& in_zero_point,
56+
const int64_t out_zero_point,
57+
const Tensor& out_multiplier,
58+
const Tensor& out_shift,
59+
Tensor& output) {
60+
return ::cadence::impl::HiFi::native::quantized_relu_out(
61+
context_,
62+
input,
63+
in_zero_point,
64+
out_zero_point,
65+
out_multiplier,
66+
out_shift,
67+
output);
68+
}
69+
70+
void dequantize_per_tensor_out(
71+
const Tensor& input,
72+
double scale,
73+
int64_t zero_point,
74+
int64_t quant_min,
75+
int64_t quant_max,
76+
ScalarType dtype,
77+
Tensor& out) {
78+
return ::cadence::impl::HiFi::native::dequantize_per_tensor_out(
79+
context_, input, scale, zero_point, quant_min, quant_max, dtype, out);
80+
}
81+
};
82+
83+
TEST_F(HiFiQuantizedReluTest, MultiDimensionalTest) {
  // Round-trip an all-(-1.0f) 4-D tensor through quantize -> quantized_relu
  // -> dequantize; ReLU of a strictly negative input must be all zeros.
  TensorFactory<ScalarType::Float> tf_float;
  TensorFactory<ScalarType::Char> tf_chars;
  TensorFactory<ScalarType::Int> tf_int;
  const std::vector<int32_t> sizes{2, 3, 5, 6};

  // Input-side quantization parameters (int8, zero point at 127).
  const double scale = 0.003921568859368563;
  const int64_t zero_point = 127;
  const int64_t quant_min = -128;
  const int64_t quant_max = 127;

  Tensor input_float = tf_float.full(sizes, -1.0f);
  Tensor quantized_input = tf_chars.zeros(sizes);
  quantize_per_tensor_out(
      input_float,
      scale,
      zero_point,
      quant_min,
      quant_max,
      ScalarType::Char,
      quantized_input);

  // Output-side requantization parameters for the ReLU kernel.
  Tensor in_zero_point = tf_chars.full({1}, 127);
  const int64_t out_zero_point = -128;
  Tensor out_multiplier = tf_int.full({1}, 1077952640);
  Tensor out_shift = tf_int.full({1}, 5);

  Tensor quantized_output = tf_chars.zeros(sizes);
  quantized_relu_out(
      quantized_input,
      in_zero_point,
      out_zero_point,
      out_multiplier,
      out_shift,
      quantized_output);

  // Map the quantized result back to float and compare against zeros.
  const double dequant_scale = 0.000244140625;
  const int64_t dequant_zero_point = -128;
  Tensor output_float = tf_float.zeros(sizes);
  dequantize_per_tensor_out(
      quantized_output,
      dequant_scale,
      dequant_zero_point,
      quant_min,
      quant_max,
      ScalarType::Float,
      output_float);

  EXPECT_TENSOR_EQ(output_float, tf_float.zeros(sizes));
}
131+
132+
TEST_F(HiFiQuantizedReluTest, OneDimensionalTest) {
  // Same round-trip as MultiDimensionalTest but on a 1-D tensor of 56
  // elements: quantize -> quantized_relu -> dequantize, expecting zeros.
  TensorFactory<ScalarType::Float> tf_float;
  TensorFactory<ScalarType::Char> tf_chars;
  TensorFactory<ScalarType::Int> tf_int;
  const std::vector<int32_t> sizes{56};

  // Input-side quantization parameters (int8, zero point at 127).
  const double scale = 0.003921568859368563;
  const int64_t zero_point = 127;
  const int64_t quant_min = -128;
  const int64_t quant_max = 127;

  Tensor input_float = tf_float.full(sizes, -1.0f);
  Tensor quantized_input = tf_chars.zeros(sizes);
  quantize_per_tensor_out(
      input_float,
      scale,
      zero_point,
      quant_min,
      quant_max,
      ScalarType::Char,
      quantized_input);

  // Output-side requantization parameters for the ReLU kernel.
  Tensor in_zero_point = tf_chars.full({1}, 127);
  const int64_t out_zero_point = -128;
  Tensor out_multiplier = tf_int.full({1}, 1077952640);
  Tensor out_shift = tf_int.full({1}, 5);

  Tensor quantized_output = tf_chars.zeros(sizes);
  quantized_relu_out(
      quantized_input,
      in_zero_point,
      out_zero_point,
      out_multiplier,
      out_shift,
      quantized_output);

  // Map the quantized result back to float and compare against zeros.
  const double dequant_scale = 0.000244140625;
  const int64_t dequant_zero_point = -128;
  Tensor output_float = tf_float.zeros(sizes);
  dequantize_per_tensor_out(
      quantized_output,
      dequant_scale,
      dequant_zero_point,
      quant_min,
      quant_max,
      ScalarType::Float,
      output_float);

  EXPECT_TENSOR_EQ(output_float, tf_float.zeros(sizes));
}
180+
181+
} // namespace
182+
} // namespace native
183+
} // namespace HiFi
184+
} // namespace impl
185+
} // namespace cadence

0 commit comments

Comments
 (0)