Commit 2e8583a

qti-yuduo1duo authored and committed

[QNN-EP] Fuse pre-scale (multiply) into Softmax op

1 parent d23eb9e commit 2e8583a

File tree

6 files changed: +409 -1 lines changed
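
Why the fusion is valid: QNN's Softmax op carries a beta parameter that scales the input inside the exponent,

    Softmax(x; beta)_i = exp(beta * x_i) / sum_j exp(beta * x_j)

so a Mul by a constant scalar s feeding a Softmax can be absorbed by multiplying beta by s:

    Softmax(s * x; beta) = Softmax(x; s * beta)

The fusion below matches exactly this Mul(x, s) -> Softmax pattern, drops the Mul, and emits a single QNN Softmax whose beta is scale * beta (see beta_scalar.floatValue in scale_softmax_fusion.cc).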

cmake/onnxruntime_unittests.cmake

Lines changed: 1 addition & 0 deletions

@@ -724,6 +724,7 @@ endif()
   # or reduced op builds.
   if(onnxruntime_USE_QNN AND NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_REDUCED_OPS_BUILD)
     list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/providers/qnn/*)
+    list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/providers/qnn/qnn_node_group/*)
     list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_qnn)
     list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_qnn)
     if(NOT onnxruntime_BUILD_QNN_EP_STATIC_LIB)

onnxruntime/core/optimizer/bias_softmax_fusion.cc

Lines changed: 1 addition & 1 deletion

@@ -135,7 +135,7 @@ bool TrySelectInputAndBiasWithAlignment(Node& add_node, Node& softmax_node, Node
   new_axis = (int)HandleNegativeAxis(axis, rank);

   // The axis attribute for Softmax in OpSet-11 and OpSet-13 are different.
-  // Details in function documentatin.
+  // Details in function documentation.
   if (is_since_opset_13 && new_axis != rank - 1) return false;

   int singlebatch_rank = rank - new_axis;

onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc

Lines changed: 2 additions & 0 deletions

@@ -15,6 +15,7 @@
 #include "core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h"
 #include "core/providers/qnn/builder/qnn_node_group/qnn_node_group.h"
 #include "core/providers/qnn/builder/qnn_node_group/reshape_gemm_fusion.h"
+#include "core/providers/qnn/builder/qnn_node_group/scale_softmax_fusion.h"
 #include "core/providers/qnn/builder/qnn_utils.h"
 #include "core/providers/qnn/ort_api.h"

@@ -90,6 +91,7 @@ static std::unique_ptr<IQnnNodeGroup> TryQnnFusions(
     {"DequantizeLinear", DQQFusion::TryFusion},
     {"HardSigmoid", HardSigmoidMulFusion::TryFusion},
     {"Gemm", ReshapeGemmFusion::TryFusion},
+    {"Mul", ScaleSoftmaxFusion::TryFusion},
   };

   // For now, all fusions involve standalone node units (i.e., no wrapping DQ/Q nodes).
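
With the new {"Mul", ScaleSoftmaxFusion::TryFusion} entry, every standalone Mul node unit is offered to the fusion before regular per-op translation; TryFusion (below) returns nullptr whenever the pattern does not match, leaving non-matching Mul nodes untouched.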
onnxruntime/core/providers/qnn/builder/qnn_node_group/scale_softmax_fusion.cc

Lines changed: 210 additions & 0 deletions

@@ -0,0 +1,210 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/qnn/builder/qnn_node_group/scale_softmax_fusion.h"

#include <gsl/gsl>
#include <cassert>
#include <optional>
#include <utility>
#include <string>
#include <array>
#include <vector>

#include "core/providers/qnn/builder/qnn_utils.h"
#include "core/providers/qnn/builder/op_builder_factory.h"
#include "core/providers/qnn/builder/qnn_node_group/utils.h"
#include "core/providers/qnn/builder/qnn_model_wrapper.h"
#include "core/providers/qnn/builder/opbuilder/base_op_builder.h"

namespace onnxruntime {
namespace qnn {
namespace {

constexpr char kOpMul[] = "Mul";
constexpr char kOpSoftmax[] = "Softmax";

/// @brief Get the index of the scalar input in the Mul node
/// @param mul Multiply node unit
/// @return The index of the scalar input (0 or 1) if found, otherwise std::nullopt
std::optional<size_t> GetMulScalarInputIndex(const NodeUnit* mul) {
  const NodeArg* mul_x = mul->GetNode().InputDefs()[0];
  const NodeArg* mul_y = mul->GetNode().InputDefs()[1];
  auto mul_x_shape = utils::GetTensorProtoShape(*mul_x->Shape());
  auto mul_y_shape = utils::GetTensorProtoShape(*mul_y->Shape());
  bool is_x_scalar = mul_x_shape.NumDimensions() == 0;
  bool is_y_scalar = mul_y_shape.NumDimensions() == 0;
  if (!is_x_scalar && !is_y_scalar) {
    return std::nullopt;
  }
  return is_y_scalar ? 1 : 0;
}
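
// Note: only a rank-0 input qualifies as the scalar above; a shape-[1] tensor
// would broadcast the same way but is not matched by this check.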

/// @brief Get the softmax axis as a non-negative value
/// @param mul Multiply node unit whose non-scalar input feeds the softmax
/// @param softmax Softmax node unit
/// @return The non-negative axis for softmax if the attribute is present, otherwise std::nullopt
std::optional<uint32_t> GetPositiveSoftmaxAxis(const NodeUnit* mul, const NodeUnit* softmax) {
  NodeAttrHelper softmax_attr_helper(softmax->GetNode());
  std::optional<int64_t> param_axis = softmax_attr_helper.GetInt64(QNN_OP_SOFTMAX_PARAM_AXIS);
  if (!param_axis.has_value()) {
    return std::nullopt;
  }
  int64_t axis_value = param_axis.value();
  if (axis_value < 0) {
    // ONNX allows negative axes; normalize against the rank of the non-scalar input.
    size_t input_scale_index = GetMulScalarInputIndex(mul).value();
    size_t input_other_index = 1 - input_scale_index;
    int rank = mul->GetNode().InputDefs()[input_other_index]->Shape()->dim_size();
    axis_value += static_cast<int64_t>(rank);
  }
  return static_cast<uint32_t>(axis_value);
}
/// @brief Extract the float value of the Mul node's scalar input, if present
/// @param graph_viewer Graph viewer used to look up constant initializers
/// @param mul Multiply node unit
/// @return The scalar input's float value if found, otherwise std::nullopt
std::optional<float> ExtractScalarValueFromMul(const GraphViewer& graph_viewer, const NodeUnit* mul) {
  std::optional<size_t> input_scale_index = GetMulScalarInputIndex(mul);
  if (!input_scale_index.has_value()) {
    return std::nullopt;
  }
  const NodeArg* scalar_arg = mul->GetNode().InputDefs()[input_scale_index.value()];
  if (!graph_viewer.IsConstantInitializer(scalar_arg->Name(), true)) {
    return std::nullopt;
  }
  const auto* scalar_tensor = graph_viewer.GetConstantInitializer(scalar_arg->Name());
  if (!scalar_tensor) {
    return std::nullopt;
  }
  if (scalar_tensor->data_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
    return std::nullopt;
  }
  const auto& raw_data = scalar_tensor->raw_data();
  if (raw_data.size() != sizeof(float) || reinterpret_cast<uintptr_t>(raw_data.data()) % alignof(float) != 0) {
    return std::nullopt;
  }
  return *reinterpret_cast<const float*>(raw_data.data());
}
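
// Note: TensorProto::raw_data() is a raw byte buffer with no alignment guarantee
// for float access, so the alignment check above guards the reinterpret_cast;
// copying the four bytes out with memcpy would be an equivalent alternative.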

/// @brief Create or validate the fused QNN Softmax node
/// @param qnn_model_wrapper QNN model wrapper
/// @param node_units The node units containing the Mul and Softmax nodes
/// @param validate Whether to validate the QNN node instead of creating it
/// @return Status
Status CreateOrValidateOnQnn(
    QnnModelWrapper* qnn_model_wrapper,
    std::array<const NodeUnit*, 2> node_units,
    bool validate) {
  const NodeUnit* mul = node_units[0];
  const NodeUnit* softmax = node_units[1];
  ORT_RETURN_IF_NOT(mul->OpType() == kOpMul,
                    "Expected scale node to be of type Mul, got ", mul->OpType());
  ORT_RETURN_IF_NOT(softmax->OpType() == kOpSoftmax,
                    "Expected softmax node to be of type Softmax, got ", softmax->OpType());
  size_t input_scale_index = GetMulScalarInputIndex(mul).value();
  size_t input_other_index = 1 - input_scale_index;
  const NodeUnitIODef& mul_input_other = mul->Inputs()[input_other_index];
  const NodeUnitIODef& softmax_output = softmax->Outputs()[0];

  std::vector<std::string> param_tensor_names;
  {  // axis
    std::optional<uint32_t> axis = GetPositiveSoftmaxAxis(mul, softmax);
    if (axis.has_value()) {
      Qnn_Scalar_t axis_scalar = QNN_SCALAR_INIT;
      axis_scalar.dataType = QNN_DATATYPE_UINT_32;
      axis_scalar.uint32Value = axis.value();
      QnnParamWrapper param_wrapper(softmax->Index(),
                                    softmax->Name(),
                                    QNN_OP_SOFTMAX_PARAM_AXIS,
                                    axis_scalar);
      // Take the tensor name before the wrapper is moved into the model.
      param_tensor_names.push_back(param_wrapper.GetParamTensorName());
      ORT_RETURN_IF_NOT(qnn_model_wrapper->AddParamWrapper(std::move(param_wrapper)), "Failed to add param");
    }
  }
  {  // beta: fold the Mul's scalar scale into Softmax's beta
    NodeAttrHelper softmax_attr_helper(softmax->GetNode());
    std::optional<float> beta = softmax_attr_helper.GetFloat(QNN_OP_SOFTMAX_PARAM_BETA);
    float scale = ExtractScalarValueFromMul(qnn_model_wrapper->GetGraphViewer(), mul).value_or(1.0f);
    Qnn_Scalar_t beta_scalar = QNN_SCALAR_INIT;
    beta_scalar.dataType = QNN_DATATYPE_FLOAT_32;
    beta_scalar.floatValue = scale * beta.value_or(1.0f);
    QnnParamWrapper param_wrapper(softmax->Index(),
                                  softmax->Name(),
                                  QNN_OP_SOFTMAX_PARAM_BETA,
                                  beta_scalar);
    param_tensor_names.push_back(param_wrapper.GetParamTensorName());
    ORT_RETURN_IF_NOT(qnn_model_wrapper->AddParamWrapper(std::move(param_wrapper)), "Failed to add param");
  }

  QnnTensorWrapper fused_softmax_input;
  QnnTensorWrapper fused_softmax_output;
  ORT_RETURN_IF_ERROR(qnn_model_wrapper->MakeTensorWrapper(mul_input_other, fused_softmax_input));
  ORT_RETURN_IF_ERROR(qnn_model_wrapper->MakeTensorWrapper(softmax_output, fused_softmax_output));

  if (validate) {
    ORT_RETURN_IF_ERROR(qnn_model_wrapper->ValidateQnnNode(softmax->Name(),
                                                           QNN_OP_PACKAGE_NAME_QTI_AISW,
                                                           QNN_OP_SOFTMAX,
                                                           {fused_softmax_input.GetQnnTensor()},
                                                           {fused_softmax_output.GetQnnTensor()},
                                                           {}));
  } else {
    ORT_RETURN_IF_NOT(qnn_model_wrapper->AddTensorWrapper(std::move(fused_softmax_input)), "Failed to add input");
    ORT_RETURN_IF_NOT(qnn_model_wrapper->AddTensorWrapper(std::move(fused_softmax_output)), "Failed to add output");
    ORT_RETURN_IF_NOT(qnn_model_wrapper->CreateQnnNode(softmax->Name(),
                                                       QNN_OP_PACKAGE_NAME_QTI_AISW,
                                                       QNN_OP_SOFTMAX,
                                                       {mul_input_other.node_arg.Name()},
                                                       {softmax_output.node_arg.Name()},
                                                       std::move(param_tensor_names),
                                                       validate),
                      "Failed to add fused " + std::string(kOpSoftmax) + " node.");
  }
  return Status::OK();
}

}  // namespace

std::unique_ptr<IQnnNodeGroup> ScaleSoftmaxFusion::TryFusion(
    QnnModelWrapper& qnn_model_wrapper,
    const NodeUnit& mul_node_unit,
    const std::unordered_map<const Node*, const NodeUnit*>& node_to_node_unit,
    const std::unordered_map<const NodeUnit*, const IQnnNodeGroup*>& node_unit_to_qnn_node_group,
    [[maybe_unused]] const logging::Logger& logger) {
  if (mul_node_unit.OpType() != kOpMul || mul_node_unit.UnitType() != NodeUnit::Type::SingleNode) {
    return nullptr;
  }
  // Check if the Mul node has a constant scalar input that can fold into the Softmax's beta
  const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer();
  std::optional<float> scalar = ExtractScalarValueFromMul(graph_viewer, &mul_node_unit);
  if (!scalar.has_value()) {
    return nullptr;
  }

  // The Mul node must have a single Softmax node as its only child
  const std::array<std::string_view, 1> child_op_types{kOpSoftmax};
  const NodeUnit* softmax = GetOnlyChildOfType(graph_viewer, mul_node_unit, child_op_types,
                                               node_to_node_unit, node_unit_to_qnn_node_group);
  if (softmax == nullptr) {
    return nullptr;
  }

  std::array<const NodeUnit*, 2> node_units{&mul_node_unit, softmax};
  if (CreateOrValidateOnQnn(&qnn_model_wrapper, node_units, /*validate=*/true) != Status::OK()) {
    return nullptr;
  }

  return std::make_unique<ScaleSoftmaxFusion>(node_units);
}

Status ScaleSoftmaxFusion::IsSupported(
    QnnModelWrapper& qnn_model_wrapper, [[maybe_unused]] const logging::Logger& logger) const {
  return CreateOrValidateOnQnn(&qnn_model_wrapper, node_units_, /*validate=*/true);
}

Status ScaleSoftmaxFusion::AddToModelBuilder(
    QnnModelWrapper& qnn_model_wrapper, [[maybe_unused]] const logging::Logger& logger) const {
  return CreateOrValidateOnQnn(&qnn_model_wrapper, node_units_, /*validate=*/false);
}

}  // namespace qnn
}  // namespace onnxruntime
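
As a quick sanity check of the rewrite (editorial addition, not part of the commit), the following standalone C++ sketch verifies numerically that scaling a softmax's input is equivalent to folding the scale into a QNN-style beta; SoftmaxBeta is a local helper, not an onnxruntime API:

    // Verifies: softmax(scale * x) == softmax of x with beta = scale.
    #include <cmath>
    #include <cstdio>
    #include <vector>

    // QNN-style softmax: y_i = exp(beta * x_i) / sum_j exp(beta * x_j)
    static std::vector<float> SoftmaxBeta(const std::vector<float>& x, float beta) {
      std::vector<float> y(x.size());
      float sum = 0.0f;
      for (size_t i = 0; i < x.size(); ++i) {
        y[i] = std::exp(beta * x[i]);
        sum += y[i];
      }
      for (float& v : y) v /= sum;
      return y;
    }

    int main() {
      const std::vector<float> x{0.5f, -1.25f, 2.0f, 0.0f};
      const float scale = 0.125f;  // e.g., a 1/sqrt(d) attention pre-scale

      // Unfused graph: Mul(x, scale), then Softmax with the default beta of 1.
      std::vector<float> scaled = x;
      for (float& v : scaled) v *= scale;
      const std::vector<float> unfused = SoftmaxBeta(scaled, 1.0f);

      // Fused graph: Softmax on the original input with beta = scale.
      const std::vector<float> fused = SoftmaxBeta(x, scale);

      // The two columns printed here agree to within float rounding.
      for (size_t i = 0; i < x.size(); ++i) {
        std::printf("%zu: unfused=%.6f fused=%.6f\n", i, unfused[i], fused[i]);
      }
      return 0;
    }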
onnxruntime/core/providers/qnn/builder/qnn_node_group/scale_softmax_fusion.h

Lines changed: 50 additions & 0 deletions

@@ -0,0 +1,50 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <gsl/gsl>
#include <array>
#include <memory>
#include <unordered_map>
#include <vector>

#include "core/providers/qnn/builder/qnn_node_group/qnn_node_group.h"
#include "core/providers/qnn/ort_api.h"

namespace onnxruntime {
namespace qnn {

class QnnModelWrapper;

/// <summary>
/// Represents a fusion of the pattern: Softmax(Mul(x, scalar_scale)) => QnnSoftmax(x, beta=scalar_scale)
/// </summary>
class ScaleSoftmaxFusion : public IQnnNodeGroup {
 public:
  explicit ScaleSoftmaxFusion(const std::array<const NodeUnit*, 2>& node_units) : node_units_(node_units) {}
  ORT_DISALLOW_COPY_AND_ASSIGNMENT(ScaleSoftmaxFusion);

  Status IsSupported(QnnModelWrapper& qnn_model_wrapper, const logging::Logger& logger) const override;
  Status AddToModelBuilder(QnnModelWrapper& qnn_model_wrapper, const logging::Logger& logger) const override;
  gsl::span<const NodeUnit* const> GetNodeUnits() const override { return node_units_; }
  const NodeUnit* GetTargetNodeUnit() const override { return node_units_[1]; }
  std::string_view Type() const override { return "ScaleSoftmaxFusion"; }

  /// <summary>
  /// Traverses the graph to check whether the given starting NodeUnit is part of a valid Mul -> Softmax sequence.
  /// If so, returns an IQnnNodeGroup that contains the Mul and Softmax NodeUnits.
  /// </summary>
  static std::unique_ptr<IQnnNodeGroup> TryFusion(
      QnnModelWrapper& qnn_model_wrapper,
      const NodeUnit& mul_node_unit,
      const std::unordered_map<const Node*, const NodeUnit*>& node_to_node_unit,
      const std::unordered_map<const NodeUnit*, const IQnnNodeGroup*>& node_unit_to_qnn_node_group,
      const logging::Logger& logger);

 private:
  std::array<const NodeUnit*, 2> node_units_;
};

}  // namespace qnn
}  // namespace onnxruntime
