
Commit 0fca5ef

quic-tirupath (Tirupathi Reddy T) authored and committed

[QNN EP] Fix 16x16 MatMul translation
- QNN's 16x16 FC doesn't support an asymmetric int16 weight.
- QNN's 16x16 MatMul doesn't support an asymmetric int16 weight initializer.
- Insert a Convert op to convert the asymmetric uint16 weight to a symmetric int16 weight.
- Add unit tests to verify the 16x16 MatMul translations.
1 parent 915a999 commit 0fca5ef
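For context on the fix: the inserted Convert op re-expresses the weight's real-valued range under symmetric int16 parameters (offset forced to 0). Below is a minimal standalone sketch of that requantization math, mirroring the helper logic visible in the diff; it is not the ORT implementation, and both the QNN convention used (real = scale * (quantized + offset), with offset being the negated zero point) and the exact symmetric scale formula are assumptions for illustration.

// Standalone sketch (not the ORT helper): derive symmetric int16 quant params
// from asymmetric uint16 params covering the same real-valued range.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

struct QuantParams {
  float scale;
  int32_t offset;  // QNN-style offset: the negated zero point (assumption)
};

QuantParams ToSymmetricInt16(float in_scale, int32_t in_offset) {
  // Dequantize the uint16 endpoints to recover the represented real range,
  // using the assumed QNN convention real = scale * (quantized + offset).
  const double value_min = in_scale * (0.0 + in_offset);
  const double value_max = in_scale * (65535.0 + in_offset);
  // Symmetric int16 pins the offset to 0; choose the scale so the larger
  // magnitude endpoint still fits into [-32767, 32767] (formula assumed).
  const double abs_max = std::max(std::fabs(value_min), std::fabs(value_max));
  return QuantParams{static_cast<float>(abs_max / 32767.0), 0};
}

int main() {
  // Example range taken from the new unit test below: roughly [-6.772, 7.258].
  const float in_scale = (7.258f - (-6.772f)) / 65535.0f;
  const int32_t in_offset = static_cast<int32_t>(std::lround(-6.772f / in_scale));
  const QuantParams sym = ToSymmetricInt16(in_scale, in_offset);
  std::printf("symmetric int16: scale=%g offset=%d\n", sym.scale, sym.offset);
  return 0;
}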

File tree

2 files changed: +129 -56 lines changed

onnxruntime/core/providers/qnn/builder/opbuilder/matmul_op_builder.cc
Lines changed: 77 additions & 56 deletions

@@ -49,50 +49,6 @@ class MatMulOpBuilder : public BaseOpBuilder {
 };
 
 namespace {
-
-// Inserts a QNN Convert operator to convert from one quantization type (e.g., uint16) to another (e.g., uint8).
-Status InsertConvertOp(QnnModelWrapper& qnn_model_wrapper,
-                       const std::string& convert_input_name,
-                       const std::string& convert_output_name,
-                       Qnn_DataType_t input_qnn_data_type,
-                       Qnn_DataType_t output_qnn_data_type,
-                       int32_t input_offset,
-                       float input_scale,
-                       const std::vector<uint32_t>& output_shape,
-                       bool do_op_validation) {
-  // Assume input is already handled.
-  float qmin = 0.0f;
-  float qmax = 255.0f;
-  ORT_RETURN_IF_ERROR(qnn::utils::GetQminQmax(input_qnn_data_type, qmin, qmax));
-  double value_min = qnn::utils::Dequantize(input_offset, input_scale, qmin);
-  double value_max = qnn::utils::Dequantize(input_offset, input_scale, qmax);
-  float scale = 0.0f;
-  int32_t offset = 0;
-  ORT_RETURN_IF_ERROR(qnn::utils::GetQuantParams(static_cast<float>(value_min),
-                                                 static_cast<float>(value_max),
-                                                 output_qnn_data_type,
-                                                 scale,
-                                                 offset));
-
-  std::vector<uint32_t> output_shape_copy = output_shape;
-  QnnTensorWrapper convert_output_tensorwrapper(convert_output_name,
-                                                QNN_TENSOR_TYPE_NATIVE,
-                                                output_qnn_data_type,
-                                                QnnQuantParamsWrapper(scale, offset),
-                                                std::move(output_shape_copy));
-  ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(convert_output_tensorwrapper)), "Failed to add tensor.");
-
-  ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(convert_output_name,
-                                                    QNN_OP_PACKAGE_NAME_QTI_AISW,
-                                                    "Convert",
-                                                    {convert_input_name},
-                                                    {convert_output_name},
-                                                    {},
-                                                    do_op_validation),
-                    "Failed to add node.");
-  return Status::OK();
-}
-
 inline bool IsQuant16bit(Qnn_DataType_t qnn_data_type) {
   return qnn_data_type == QNN_DATATYPE_UFIXED_POINT_16 || qnn_data_type == QNN_DATATYPE_SFIXED_POINT_16;
 }
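The deleted local helper resurfaces later in this diff as utils::InsertConvertOp with one extra boolean argument ahead of do_op_validation. The actual utils header is not part of this diff; inferred from the call sites alone, its declaration presumably looks like this (the flag's parameter name is a guess):

// Presumed declaration of the relocated helper (sketch inferred from the call
// sites in this diff, not copied from onnxruntime's qnn utils header).
Status InsertConvertOp(QnnModelWrapper& qnn_model_wrapper,
                       const std::string& convert_input_name,
                       const std::string& convert_output_name,
                       Qnn_DataType_t input_qnn_data_type,
                       Qnn_DataType_t output_qnn_data_type,
                       int32_t input_offset,
                       float input_scale,
                       const std::vector<uint32_t>& output_shape,
                       bool is_symmetric,  // guessed name: true pins the output offset to 0
                       bool do_op_validation);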
@@ -253,7 +209,8 @@ Status MatMulOpBuilder::ProcessInputsForQnnMatMul(QnnModelWrapper& qnn_model_wra
   }
   input_names.emplace_back(input_1_name);
 
-  // Workaround that inserts a QNN Convert op before input[1] (converts from quantized uint16 to quantized uint8)
+  // Workaround that inserts a QNN Convert op before input[1] (converts from quantized uint16 to quantized uint8
+  // OR converts from asymmetric quantized uint16 to symmetric quantized uint16)
   // to avoid a QNN validation failure.
   //
   // QNN graph WITHOUT workaround (fails validation):
@@ -262,12 +219,18 @@ Status MatMulOpBuilder::ProcessInputsForQnnMatMul(QnnModelWrapper& qnn_model_wra
   //                     |
   // input_1_uint16 -----+
   //
-  // QNN graph WITH workaround (passes validation):
+  // For Dynamic weights, QNN graph WITH workaround (passes validation):
   // input_0_uint16 ----------------------> MatMul ---> output_uint16
   //                                           ^
   //                                           |
   // input_1_uint16 --> Convert(to uint8) -----+
-  if (!input_info_0.is_initializer && !input_info_1.is_initializer &&
+  //
+  // For Static weights, QNN graph WITH workaround (passes validation):
+  // input_0_uint16 --------------------------------> MatMul ---> output_uint16
+  //                                                     ^
+  //                                                     |
+  // input_1_uint16 --> Convert(to symmetric int16) -----+
+  if (!input_info_0.is_initializer &&
       input_info_0.qnn_data_type == input_info_1.qnn_data_type &&
       input_info_0.qnn_data_type == QNN_DATATYPE_UFIXED_POINT_16) {
     ORT_RETURN_IF_NOT(input_info_1.quant_param.IsPerTensor(),
@@ -282,15 +245,29 @@ Status MatMulOpBuilder::ProcessInputsForQnnMatMul(QnnModelWrapper& qnn_model_wra
     if (reshape_input_1) {
       input_1_shape = {input_info_1.shape[0], 1};
     }
-    ORT_RETURN_IF_ERROR(InsertConvertOp(qnn_model_wrapper,
-                                        convert_input_name,
-                                        convert_output_name,
-                                        input_info_1.qnn_data_type,
-                                        QNN_DATATYPE_UFIXED_POINT_8,
-                                        quant_param.scaleOffsetEncoding.offset,
-                                        quant_param.scaleOffsetEncoding.scale,
-                                        input_1_shape,
-                                        do_op_validation));
+    if (!input_info_1.is_initializer) {
+      ORT_RETURN_IF_ERROR(utils::InsertConvertOp(qnn_model_wrapper,
+                                                 convert_input_name,
+                                                 convert_output_name,
+                                                 input_info_1.qnn_data_type,
+                                                 QNN_DATATYPE_UFIXED_POINT_8,
+                                                 quant_param.scaleOffsetEncoding.offset,
+                                                 quant_param.scaleOffsetEncoding.scale,
+                                                 input_1_shape,
+                                                 false,  // asymmetric
+                                                 do_op_validation));
+    } else {
+      ORT_RETURN_IF_ERROR(utils::InsertConvertOp(qnn_model_wrapper,
+                                                 convert_input_name,
+                                                 convert_output_name,
+                                                 input_info_1.qnn_data_type,
+                                                 QNN_DATATYPE_SFIXED_POINT_16,
+                                                 quant_param.scaleOffsetEncoding.offset,
+                                                 quant_param.scaleOffsetEncoding.scale,
+                                                 input_1_shape,
+                                                 true,  // symmetric
+                                                 do_op_validation));
+    }
     input_names.push_back(convert_output_name);
   }
   return Status::OK();
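Design note on the two branches above: a dynamic (non-initializer) input[1] keeps the pre-existing uint16-to-uint8 Convert, while a static (initializer) weight is instead converted to symmetric int16 with the new flag set. This matches the commit message: QNN's 16x16 MatMul accepts a symmetric int16 weight initializer but not an asymmetric one.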
@@ -355,6 +332,50 @@ Status MatMulOpBuilder::ProcessInputsForQnnFullyConnected(QnnModelWrapper& qnn_m
                                      qnn_model_wrapper.IsGraphInput(org_input_1_name), false));
   }
   input_names.emplace_back(input_1_name);
+
+  // Workaround that inserts a QNN Convert op before input[1] (converts from quantized uint16 to signed symmetric int16)
+  // to avoid a QNN validation failure.
+  //
+  // QNN graph WITHOUT workaround (fails validation):
+  // input_0_uint16 ---> FC ---> output_uint16
+  //                      ^
+  //                      |
+  // input_1_uint16 -----+
+  //
+  // QNN graph WITH workaround (passes validation):
+  // input_0_uint16 ----------------------> FC ---> output_uint16
+  //                                           ^
+  //                                           |
+  // input_1_uint16 --> Convert(to int16) -----+
+
+  std::string weight_input_name = input_names.back();
+  const auto& weight_tensor_wrapper = qnn_model_wrapper.GetQnnTensorWrapper(weight_input_name);
+
+  if (weight_tensor_wrapper.GetTensorDataType() == QNN_DATATYPE_UFIXED_POINT_16) {
+    const auto& quant_param_wrapper = weight_tensor_wrapper.GetQnnQuantParams();
+    const Qnn_QuantizeParams_t& quant_param = quant_param_wrapper.Get();
+    const auto& transformed_input1_shape = weight_tensor_wrapper.GetTensorDims();
+
+    ORT_RETURN_IF_NOT(quant_param_wrapper.IsPerTensor(),
+                      "FC's INT16 weight inputs only support INT16 per-tensor quantization");
+
+    // Pop the FC weight. Insert the Convert op after the weight.
+    input_names.pop_back();
+    const std::string& conv_output_name = node_unit.Outputs()[0].node_arg.Name();
+    std::string convert_output_name = weight_input_name + "_convert_" + conv_output_name;
+
+    ORT_RETURN_IF_ERROR(utils::InsertConvertOp(qnn_model_wrapper,
+                                               weight_input_name,
+                                               convert_output_name,
+                                               QNN_DATATYPE_UFIXED_POINT_16,
+                                               QNN_DATATYPE_SFIXED_POINT_16,
+                                               quant_param.scaleOffsetEncoding.offset,
+                                               quant_param.scaleOffsetEncoding.scale,
+                                               transformed_input1_shape,
+                                               true,  // symmetric
+                                               do_op_validation));
+    input_names.push_back(convert_output_name);
+  }
   return Status::OK();
 }
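Note on the new FC path: the Convert output tensor name concatenates the weight name, "_convert_", and the node's first output name, so the converted weight stays unique per consuming node even when the same initializer feeds several FCs. With an illustrative weight "W" feeding a node whose output is "Y" (hypothetical names), the inserted tensor would be named "W_convert_Y".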

onnxruntime/test/providers/qnn/matmul_test.cpp
Lines changed: 52 additions & 0 deletions

@@ -340,6 +340,58 @@ TEST_F(QnnHTPBackendTests, MatMulOp_QDQ_Regression_uint16_dynamic_inputs) {
   }
 }
 
+// Tests MatMul with two uint16 (quantized) inputs with weight as static.
+// This exercises a workaround in QNN EP that inserts a QNN Convert op before input[1] (converts from uint16 to sint16).
+// This workaround prevents a validation error for this specific MatMul configuration.
+// Got specific shapes and input ranges (quant params) from customer model.
+TEST_F(QnnHTPBackendTests, MatMulOp_QDQ_Regression_uint16_static_weight) {
+  ProviderOptions provider_options;
+  provider_options["backend_type"] = "htp";
+  provider_options["offload_graph_io_quantization"] = "0";
+
+  // Test with rank 4 inputs
+  {
+    std::vector<int64_t> shape_0 = {1, 12, 512, 96};
+    TestInputDef<float> input0_def(
+        {1, 12, 512, 96}, false,
+        GetFloatDataInRange(-5.087f, 4.992f,
+                            static_cast<size_t>(std::accumulate(shape_0.begin(), shape_0.end(), static_cast<int64_t>(1),
+                                                                std::multiplies<int64_t>()))));
+    std::vector<int64_t> shape_1 = {1, 12, 96, 512};
+    TestInputDef<float> input1_def(
+        shape_1, true,
+        GetFloatDataInRange(-6.772f, 7.258f,
+                            static_cast<size_t>(std::accumulate(shape_1.begin(), shape_1.end(), static_cast<int64_t>(1),
+                                                                std::multiplies<int64_t>()))));
+
+    TestQDQModelAccuracy(
+        BuildMatMulOpTestCase(input0_def, input1_def),
+        BuildMatMulOpQDQTestCase<uint16_t, uint16_t, uint16_t>(input0_def, input1_def, false),
+        provider_options, 21, ExpectedEPNodeAssignment::All, QDQTolerance());
+  }
+
+  // Test with input[1] as rank 1
+  {
+    std::vector<int64_t> shape_0 = {1, 12, 512, 96};
+    TestInputDef<float> input0_def(
+        {1, 12, 512, 96}, false,
+        GetFloatDataInRange(-5.087f, 4.992f,
+                            static_cast<size_t>(std::accumulate(shape_0.begin(), shape_0.end(), static_cast<int64_t>(1),
+                                                                std::multiplies<int64_t>()))));
+    std::vector<int64_t> shape_1 = {96};
+    TestInputDef<float> input1_def(
+        shape_1, true,
+        GetFloatDataInRange(-6.772f, 7.258f,
+                            static_cast<size_t>(std::accumulate(shape_1.begin(), shape_1.end(), static_cast<int64_t>(1),
+                                                                std::multiplies<int64_t>()))));
+
+    TestQDQModelAccuracy(
+        BuildMatMulOpTestCase(input0_def, input1_def),
+        BuildMatMulOpQDQTestCase<uint16_t, uint16_t, uint16_t>(input0_def, input1_def, false),
+        provider_options, 21, ExpectedEPNodeAssignment::All, QDQTolerance());
+  }
+}
+
 #endif  // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)
 
 }  // namespace test
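Aside: the std::accumulate expressions in both test blocks just compute the number of elements as the product of the shape's dimensions. A hypothetical helper (not part of the test suite) showing the same computation:

// Hypothetical helper equivalent to the std::accumulate expressions above:
// number of elements = product of the shape's dimensions.
#include <cstddef>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

size_t NumElements(const std::vector<int64_t>& shape) {
  return static_cast<size_t>(std::accumulate(shape.begin(), shape.end(),
                                             static_cast<int64_t>(1),
                                             std::multiplies<int64_t>()));
}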
