
Commit 91e9a52

ethansfng authored and facebook-github-bot committed
Add unittest for quantized_relu_out, dequantize_per_tensor_out (#12499)
Summary: Add unit tests for quantized_relu_out and dequantize_per_tensor_out.
Reviewed By: hsharma35
Test Plan: Imported from GitHub, without a `Test Plan:` line.
Rollback Plan:
Differential Revision: D78175392
Pulled By: ethansfng
1 parent 0012ffa · commit 91e9a52

File tree: 3 files changed, +232 −0 lines changed


backends/cadence/hifi/operators/operators.h

Lines changed: 19 additions & 0 deletions
@@ -23,6 +23,16 @@ namespace impl {
 namespace HiFi {
 namespace native {
 
+void dequantize_per_tensor_out(
+    ::executorch::runtime::KernelRuntimeContext& ctx,
+    const ::executorch::aten::Tensor& input,
+    double scale,
+    int64_t zero_point,
+    int64_t quant_min,
+    int64_t quant_max,
+    ::executorch::aten::ScalarType dtype,
+    ::executorch::aten::Tensor& out);
+
 // Quantize the input tensor (PT2 version). Note that quant_<min,max> are not
 // used in any computation.
 void quantize_per_tensor_out(
@@ -42,6 +52,15 @@ ::executorch::aten::Tensor& div_out_mode(
     std::optional<std::string_view> mode,
     ::executorch::aten::Tensor& out);
 
+void quantized_relu_out(
+    ::executorch::runtime::KernelRuntimeContext& ctx,
+    const ::executorch::aten::Tensor& input,
+    const ::executorch::aten::Tensor& in_zero_point,
+    const int64_t out_zero_point,
+    const ::executorch::aten::Tensor& out_multiplier,
+    const ::executorch::aten::Tensor& out_shift,
+    ::executorch::aten::Tensor& output);
+
 void quantized_linear_out(
     __ET_UNUSED KernelRuntimeContext& ctx,
     const Tensor& in,
New file: unit test for dequantize_per_tensor_out

Lines changed: 104 additions & 0 deletions
@@ -0,0 +1,104 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <gtest/gtest.h>
#include <sys/times.h>

#include <executorch/kernels/test/TestUtil.h>
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
#include <executorch/runtime/platform/runtime.h>

#include <executorch/backends/cadence/hifi/operators/operators.h>

namespace cadence {
namespace impl {
namespace HiFi {
namespace native {
namespace {

using ::executorch::aten::Scalar;
using ::executorch::aten::ScalarType;
using ::executorch::aten::Tensor;
using ::executorch::aten::TensorImpl;
using ::executorch::runtime::Error;
using ::executorch::runtime::KernelRuntimeContext;
using ::executorch::runtime::runtime_init;
using ::executorch::runtime::testing::TensorFactory;
using std::optional;
using std::string_view;

class HiFiDequantizePerTensorTest : public OperatorTest {
 public:
 protected:
  void dequantize_per_tensor_out(
      const Tensor& input,
      double scale,
      int64_t zero_point,
      int64_t quant_min,
      int64_t quant_max,
      ScalarType dtype,
      Tensor& out) {
    return ::cadence::impl::HiFi::native::dequantize_per_tensor_out(
        context_, input, scale, zero_point, quant_min, quant_max, dtype, out);
  }
};

TEST_F(HiFiDequantizePerTensorTest, MultiDimensionalTest) {
  TensorFactory<ScalarType::Float> tf_float;
  TensorFactory<ScalarType::Char> tf_chars;
  const std::vector<int32_t> sizes{2, 3, 5, 6};
  Tensor quantized_tensor = tf_chars.full(sizes, -128);
  Tensor output_float = tf_float.zeros(sizes);
  double dequant_scale = 0.000244140625;
  int64_t dequant_zero_point = -128;
  int64_t quant_min = -128;
  int64_t quant_max = 127;

  dequantize_per_tensor_out(
      quantized_tensor,
      dequant_scale,
      dequant_zero_point,
      quant_min,
      quant_max,
      ScalarType::Float,
      output_float);

  EXPECT_TENSOR_EQ(output_float, tf_float.zeros(sizes));
}

TEST_F(HiFiDequantizePerTensorTest, OneDimensionalTest) {
  TensorFactory<ScalarType::Float> tf_float;
  TensorFactory<ScalarType::Char> tf_chars;
  const std::vector<int32_t> sizes{56};
  Tensor quantized_tensor = tf_chars.full(sizes, -128);
  Tensor output_float = tf_float.zeros(sizes);
  double dequant_scale = 0.000244140625;
  int64_t dequant_zero_point = -128;
  int64_t quant_min = -128;
  int64_t quant_max = 127;

  dequantize_per_tensor_out(
      quantized_tensor,
      dequant_scale,
      dequant_zero_point,
      quant_min,
      quant_max,
      ScalarType::Float,
      output_float);

  EXPECT_TENSOR_EQ(output_float, tf_float.zeros(sizes));
}

} // namespace
} // namespace native
} // namespace HiFi
} // namespace impl
} // namespace cadence
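
Why both tests expect an all-zeros output: per-tensor affine dequantization computes value = (q - zero_point) * scale, so an input of -128 against a zero point of -128 maps to exactly 0.0 regardless of the scale (0.000244140625 = 1/4096 here). A minimal standalone sketch of that arithmetic in plain C++ (dequantize_value is an illustrative helper, not part of the HiFi kernel):

#include <cassert>
#include <cstdint>

// Reference affine dequantization for a single int8 value. This mirrors the
// arithmetic the tests rely on; it is not the HiFi kernel itself.
float dequantize_value(int8_t q, double scale, int64_t zero_point) {
  return static_cast<float>((static_cast<int64_t>(q) - zero_point) * scale);
}

int main() {
  const double scale = 0.000244140625; // 1 / 4096, the scale used in the tests
  const int64_t zero_point = -128;
  // Every element of the quantized input is -128, so each dequantized value
  // is (-128 - (-128)) * scale == 0.0f, which is why the tests compare the
  // output against tf_float.zeros(sizes).
  assert(dequantize_value(-128, scale, zero_point) == 0.0f);
  return 0;
}
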
New file: unit test for quantized_relu_out

Lines changed: 109 additions & 0 deletions
@@ -0,0 +1,109 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <gtest/gtest.h>
#include <sys/times.h>

#include <executorch/kernels/test/TestUtil.h>
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
#include <executorch/runtime/platform/runtime.h>

#include <executorch/backends/cadence/hifi/operators/operators.h>

namespace cadence {
namespace impl {
namespace HiFi {
namespace native {
namespace {

using ::executorch::aten::Scalar;
using ::executorch::aten::ScalarType;
using ::executorch::aten::Tensor;
using ::executorch::aten::TensorImpl;
using ::executorch::runtime::Error;
using ::executorch::runtime::KernelRuntimeContext;
using ::executorch::runtime::runtime_init;
using ::executorch::runtime::testing::TensorFactory;
using std::optional;
using std::string_view;

class HiFiQuantizedReluTest : public OperatorTest {
 public:
 protected:
  void quantized_relu_out(
      const Tensor& input,
      const Tensor& in_zero_point,
      const int64_t out_zero_point,
      const Tensor& out_multiplier,
      const Tensor& out_shift,
      Tensor& output) {
    return ::cadence::impl::HiFi::native::quantized_relu_out(
        context_,
        input,
        in_zero_point,
        out_zero_point,
        out_multiplier,
        out_shift,
        output);
  }
};

TEST_F(HiFiQuantizedReluTest, MultiDimensionalTest) {
  TensorFactory<ScalarType::Char> tf_chars;
  const std::vector<int32_t> sizes{2, 3, 5, 6};
  Tensor quantized_input = tf_chars.full(sizes, -128);
  Tensor quantized_output = tf_chars.full(sizes, 100);
  Tensor in_zero_point = tf_chars.full({1}, 127);
  int64_t out_zero_point = -128;
  Tensor out_multiplier =
      TensorFactory<ScalarType::Int>().full({1}, 1077952640);
  Tensor out_shift = TensorFactory<ScalarType::Int>().full({1}, 5);

  quantized_relu_out(
      quantized_input,
      in_zero_point,
      out_zero_point,
      out_multiplier,
      out_shift,
      quantized_output);

  Tensor expected_output = tf_chars.full(sizes, -128);
  EXPECT_TENSOR_EQ(quantized_output, expected_output);
}

TEST_F(HiFiQuantizedReluTest, OneDimensionalTest) {
  TensorFactory<ScalarType::Char> tf_chars;
  const std::vector<int32_t> sizes{56};
  Tensor quantized_input = tf_chars.full(sizes, -128);
  Tensor quantized_output = tf_chars.full(sizes, 100);
  Tensor in_zero_point = tf_chars.full({1}, 127);
  int64_t out_zero_point = -128;
  Tensor out_multiplier =
      TensorFactory<ScalarType::Int>().full({1}, 1077952640);
  Tensor out_shift = TensorFactory<ScalarType::Int>().full({1}, 5);

  quantized_relu_out(
      quantized_input,
      in_zero_point,
      out_zero_point,
      out_multiplier,
      out_shift,
      quantized_output);

  Tensor expected_output = tf_chars.full(sizes, -128);
  EXPECT_TENSOR_EQ(quantized_output, expected_output);
}

} // namespace
} // namespace native
} // namespace HiFi
} // namespace impl
} // namespace cadence
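
Why both tests expect an output tensor full of -128: in the usual quantized-ReLU formulation, any input at or below in_zero_point is clamped to zero before requantization, so the result collapses to out_zero_point no matter what out_multiplier and out_shift hold. A standalone sketch of that reasoning follows (assumed semantics for illustration; quantized_relu_ref is a hypothetical reference, not the HiFi implementation):

#include <algorithm>
#include <cassert>
#include <cstdint>

// Illustrative model of quantized ReLU (an assumed formulation, not the HiFi
// kernel): clamp the input at the input zero point, rescale the positive part
// with out_multiplier / out_shift, then re-center at the output zero point.
int8_t quantized_relu_ref(
    int8_t x,
    int8_t in_zero_point,
    int64_t out_zero_point,
    int32_t out_multiplier,
    int32_t out_shift) {
  // ReLU in the quantized domain: values at or below in_zero_point map to 0.
  const int64_t relu =
      std::max<int64_t>(static_cast<int64_t>(x) - in_zero_point, 0);
  // Requantization; the exact fixed-point rounding convention is an
  // assumption, but it is irrelevant here because relu == 0 in both tests.
  const double rescale =
      static_cast<double>(out_multiplier) / (1LL << 31) * (1 << out_shift);
  return static_cast<int8_t>(
      static_cast<int64_t>(relu * rescale) + out_zero_point);
}

int main() {
  // Input -128 with in_zero_point 127 gives relu == 0, so the result is
  // exactly out_zero_point (-128), the value both tests expect.
  assert(quantized_relu_ref(-128, 127, -128, 1077952640, 5) == -128);
  return 0;
}
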
