@@ -20,11 +20,26 @@ namespace cadence {
20
20
namespace impl {
21
21
namespace HiFi {
22
22
namespace native {
23
-
23
+ namespace {
24
24
using ::executorch::aten::ScalarType;
25
25
using ::executorch::aten::Tensor;
26
26
using ::executorch::runtime::KernelRuntimeContext;
27
27
28
+ // Add checks for dtype quant min/max bounds.
29
+ template <typename T>
30
+ void check_quant_min_and_max (
31
+ KernelRuntimeContext& ctx,
32
+ const int64_t quant_min,
33
+ const int64_t quant_max) {
34
+ ET_KERNEL_CHECK (
35
+ ctx,
36
+ std::numeric_limits<T>::min () == quant_min &&
37
+ std::numeric_limits<T>::max () == quant_max,
38
+ InvalidArgument, );
39
+ }
40
+
41
+ } // namespace
42
+
28
43
// Quantize the input tensor (PT2 version). Note that quant_<min,max> are not
29
44
// used in any computation.
30
45
void quantize_per_tensor_out (
@@ -36,15 +51,43 @@ void quantize_per_tensor_out(
36
51
__ET_UNUSED int64_t quant_max,
37
52
const ScalarType dtype,
38
53
Tensor& out) {
39
- // Add checks for dtype quant min/max bounds.
40
- ET_SWITCH_REALB_TYPES (
41
- out.scalar_type (), ctx, " quantize_per_tensor" , OUT_DTYPE, [&]() {
42
- ET_KERNEL_CHECK (
43
- ctx,
44
- std::numeric_limits<OUT_DTYPE>::min () == quant_min &&
45
- std::numeric_limits<OUT_DTYPE>::max () == quant_max,
46
- InvalidArgument, );
47
- });
54
+ // Check for input scalar type.
55
+ ET_KERNEL_CHECK_MSG (
56
+ ctx,
57
+ input.scalar_type () == ScalarType::Float,
58
+ InvalidType,
59
+ ,
60
+ " Input tensor for quantize_per_tensor.out should be type %s, but got %s" ,
61
+ ::torch::executor::toString (ScalarType::Float),
62
+ ::torch::executor::toString(input.scalar_type()));
63
+
64
+ // Check quant min/max for output types.
65
+ switch (out.scalar_type ()) {
66
+ case ScalarType::Byte:
67
+ check_quant_min_and_max<uint8_t >(ctx, quant_min, quant_max);
68
+ break ;
69
+ case ScalarType::Char:
70
+ check_quant_min_and_max<int8_t >(ctx, quant_min, quant_max);
71
+ break ;
72
+ case ScalarType::Short:
73
+ check_quant_min_and_max<int16_t >(ctx, quant_min, quant_max);
74
+ break ;
75
+ case ScalarType::Bits16:
76
+ case ScalarType::UInt16:
77
+ check_quant_min_and_max<uint16_t >(ctx, quant_min, quant_max);
78
+ break ;
79
+ case ScalarType::Int:
80
+ check_quant_min_and_max<int32_t >(ctx, quant_min, quant_max);
81
+ break ;
82
+ default :
83
+ ET_KERNEL_CHECK_MSG (
84
+ ctx,
85
+ false ,
86
+ InvalidType,
87
+ ,
88
+ " Unhandled output dtype %s" ,
89
+ ::torch::executor::toString (out.scalar_type()));
90
+ }
48
91
49
92
const float * input_data = input.const_data_ptr <float >();
50
93
const size_t numel = out.numel ();
0 commit comments