Skip to content

Commit f9739c2

Browse files
authored
[QNN EP] Fuse scale into softmax (#24809)
QNN's [Softmax op defines a pre-scale parameter (`beta`)](https://docs.qualcomm.com/bundle/publicresource/topics/80-63442-50/MasterOpDef.html#softmax), into which we can fold a constant scalar multiply.
1 parent 801006d commit f9739c2

File tree

6 files changed

+431
-1
lines changed

6 files changed

+431
-1
lines changed

cmake/onnxruntime_unittests.cmake

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -724,6 +724,7 @@ endif()
724724
# or reduced op builds.
725725
if(onnxruntime_USE_QNN AND NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_REDUCED_OPS_BUILD)
726726
list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/providers/qnn/*)
727+
list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/providers/qnn/qnn_node_group/*)
727728
list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_qnn)
728729
list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_qnn)
729730
if(NOT onnxruntime_BUILD_QNN_EP_STATIC_LIB)

onnxruntime/core/optimizer/bias_softmax_fusion.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ bool TrySelectInputAndBiasWithAlignment(Node& add_node, Node& softmax_node, Node
135135
new_axis = (int)HandleNegativeAxis(axis, rank);
136136

137137
// The axis attribute for Softmax in OpSet-11 and OpSet-13 are different.
138-
// Details in function documentatin.
138+
// Details in function documentation.
139139
if (is_since_opset_13 && new_axis != rank - 1) return false;
140140

141141
int singlebatch_rank = rank - new_axis;

onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.h"
1616
#include "core/providers/qnn/builder/qnn_node_group/qnn_node_group.h"
1717
#include "core/providers/qnn/builder/qnn_node_group/reshape_gemm_fusion.h"
18+
#include "core/providers/qnn/builder/qnn_node_group/scale_softmax_fusion.h"
1819
#include "core/providers/qnn/builder/qnn_utils.h"
1920
#include "core/providers/qnn/ort_api.h"
2021

@@ -90,6 +91,7 @@ static std::unique_ptr<IQnnNodeGroup> TryQnnFusions(
9091
{"DequantizeLinear", DQQFusion::TryFusion},
9192
{"HardSigmoid", HardSigmoidMulFusion::TryFusion},
9293
{"Gemm", ReshapeGemmFusion::TryFusion},
94+
{"Mul", ScaleSoftmaxFusion::TryFusion},
9395
};
9496

9597
// For now, all fusions involve standalone node units (i.e., no wrapping DQ/Q nodes).
Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include "core/providers/qnn/builder/qnn_node_group/scale_softmax_fusion.h"

#include <gsl/gsl>

#include <array>
#include <cstring>
#include <memory>
#include <optional>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "core/providers/qnn/builder/qnn_utils.h"
#include "core/providers/qnn/builder/op_builder_factory.h"
#include "core/providers/qnn/builder/qnn_node_group/utils.h"
#include "core/providers/qnn/builder/qnn_model_wrapper.h"
#include "core/providers/qnn/builder/opbuilder/base_op_builder.h"
20+
21+
namespace onnxruntime {
22+
namespace qnn {
23+
namespace {
24+
25+
constexpr char kOpMul[] = "Mul";
26+
constexpr char kOpSoftmax[] = "Softmax";
27+
28+
/// @brief Get the index of the scalar input in the mul node
29+
/// @param mul Multiply node unit
30+
/// @return The index of the scalar input (0 or 1) if found, otherwise std::nullopt
31+
std::optional<size_t> GetMulScalarInputIndex(const NodeUnit* mul) {
32+
const NodeArg* mul_y = mul->GetNode().InputDefs()[1];
33+
const NodeArg* mul_x = mul->GetNode().InputDefs()[0];
34+
auto y_shape_proto = mul_y->Shape();
35+
auto x_shape_proto = mul_x->Shape();
36+
bool is_y_scalar = false;
37+
if (y_shape_proto != nullptr) {
38+
auto y_shape = utils::GetTensorProtoShape(*y_shape_proto);
39+
is_y_scalar = y_shape.NumDimensions() == 0;
40+
}
41+
bool is_x_scalar = false;
42+
if (x_shape_proto != nullptr) {
43+
auto x_shape = utils::GetTensorProtoShape(*x_shape_proto);
44+
is_x_scalar = x_shape.NumDimensions() == 0;
45+
}
46+
if (is_y_scalar) {
47+
return 1U;
48+
} else if (is_x_scalar) {
49+
return 0U;
50+
}
51+
return std::nullopt;
52+
}
53+
54+
/// @brief Get the axis for softmax
55+
/// @param mul Multiply node unit
56+
/// @param softmax Softmax node unit
57+
/// @return The axis for softmax
58+
std::optional<uint32_t> GetPositiveSoftmaxAxis(const NodeUnit* mul, const NodeUnit* softmax) {
59+
NodeAttrHelper softmax_attr_helper(softmax->GetNode());
60+
std::optional<int64_t> param_axis = softmax_attr_helper.GetInt64(QNN_OP_SOFTMAX_PARAM_AXIS);
61+
if (!param_axis.has_value()) {
62+
return std::nullopt;
63+
}
64+
int64_t axis_value = param_axis.value();
65+
if (axis_value < 0) {
66+
size_t input_scale_index = GetMulScalarInputIndex(mul).value();
67+
size_t input_other_index = 1U - input_scale_index;
68+
int rank = mul->GetNode().InputDefs()[input_other_index]->Shape()->dim_size();
69+
axis_value += static_cast<int64_t>(rank);
70+
}
71+
return static_cast<uint32_t>(axis_value);
72+
}
73+
74+
/// @brief Identify scalar input from mul node if present
75+
/// @param mul Multiply node unit
76+
/// @return The scalar input float value if found, otherwise std::nullopt
77+
std::optional<float> ExtractScalarValueFromMul(const GraphViewer& graph_viewer, const NodeUnit* mul) {
78+
std::optional<size_t> input_scale_index = GetMulScalarInputIndex(mul);
79+
if (!input_scale_index.has_value()) {
80+
return std::nullopt;
81+
}
82+
const NodeArg* scalar_arg = mul->GetNode().InputDefs()[input_scale_index.value()];
83+
if (!graph_viewer.IsConstantInitializer(scalar_arg->Name(), true)) {
84+
return std::nullopt;
85+
}
86+
const auto* scalar_tensor = graph_viewer.GetConstantInitializer(scalar_arg->Name());
87+
if (!scalar_tensor) {
88+
return std::nullopt;
89+
}
90+
if (scalar_tensor->data_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
91+
return std::nullopt;
92+
}
93+
const auto& raw_data = scalar_tensor->raw_data();
94+
if (raw_data.size() != sizeof(float) || reinterpret_cast<uintptr_t>(raw_data.data()) % alignof(float) != 0) {
95+
return std::nullopt;
96+
}
97+
return *reinterpret_cast<const float*>(raw_data.data());
98+
}
99+
100+
/// @brief Create or validate the QNN node
101+
/// @param qnn_model_wrapper QNN model wrapper
102+
/// @param node_units The node units containing the softmax and mul nodes
103+
/// @param validate Whether to validate the QNN node
104+
/// @return Status
105+
Status CreateOrValidateOnQnn(
106+
QnnModelWrapper* qnn_model_wrapper,
107+
gsl::span<const NodeUnit* const> node_units,
108+
bool validate) {
109+
const NodeUnit* mul = node_units[0];
110+
const NodeUnit* softmax = node_units[1];
111+
ORT_RETURN_IF_NOT(mul->OpType() == kOpMul,
112+
"Expected scale node to be of type Mul, got ", mul->OpType());
113+
ORT_RETURN_IF_NOT(softmax->OpType() == kOpSoftmax,
114+
"Expected softmax node to be of type Softmax, got ", softmax->OpType());
115+
size_t input_scale_index = GetMulScalarInputIndex(mul).value();
116+
size_t input_other_index = 1U - input_scale_index;
117+
const NodeUnitIODef& mul_input_other = mul->Inputs()[input_other_index];
118+
const NodeUnitIODef& softmax_output = softmax->Outputs()[0];
119+
120+
std::vector<std::string> param_tensor_names;
121+
{ // axis
122+
std::optional<uint32_t> axis = GetPositiveSoftmaxAxis(mul, softmax);
123+
if (axis.has_value()) {
124+
Qnn_Scalar_t axis_scalar = QNN_SCALAR_INIT;
125+
axis_scalar.dataType = QNN_DATATYPE_UINT_32;
126+
axis_scalar.uint32Value = axis.value();
127+
QnnParamWrapper param_wrapper(softmax->Index(),
128+
softmax->Name(),
129+
QNN_OP_SOFTMAX_PARAM_AXIS,
130+
axis_scalar);
131+
ORT_RETURN_IF_NOT(qnn_model_wrapper->AddParamWrapper(std::move(param_wrapper)), "Failed to add param");
132+
param_tensor_names.push_back(param_wrapper.GetParamTensorName());
133+
}
134+
}
135+
{ // beta
136+
NodeAttrHelper softmax_attr_helper(softmax->GetNode());
137+
std::optional<float> beta = softmax_attr_helper.GetFloat(QNN_OP_SOFTMAX_PARAM_BETA);
138+
float scale = ExtractScalarValueFromMul(qnn_model_wrapper->GetGraphViewer(), mul).value_or(1.0f);
139+
Qnn_Scalar_t beta_scalar = QNN_SCALAR_INIT;
140+
beta_scalar.dataType = QNN_DATATYPE_FLOAT_32;
141+
beta_scalar.floatValue = scale * beta.value_or(1.0f);
142+
QnnParamWrapper param_wrapper(softmax->Index(),
143+
softmax->Name(),
144+
QNN_OP_SOFTMAX_PARAM_BETA,
145+
beta_scalar);
146+
ORT_RETURN_IF_NOT(qnn_model_wrapper->AddParamWrapper(std::move(param_wrapper)), "Failed to add param");
147+
param_tensor_names.push_back(param_wrapper.GetParamTensorName());
148+
}
149+
150+
QnnTensorWrapper fused_softmax_input;
151+
QnnTensorWrapper fused_softmax_output;
152+
ORT_RETURN_IF_ERROR(qnn_model_wrapper->MakeTensorWrapper(mul_input_other, fused_softmax_input));
153+
ORT_RETURN_IF_ERROR(qnn_model_wrapper->MakeTensorWrapper(softmax_output, fused_softmax_output));
154+
155+
if (validate) {
156+
ORT_RETURN_IF_ERROR(qnn_model_wrapper->ValidateQnnNode(softmax->Name(),
157+
QNN_OP_PACKAGE_NAME_QTI_AISW,
158+
QNN_OP_SOFTMAX,
159+
{fused_softmax_input.GetQnnTensor()},
160+
{fused_softmax_output.GetQnnTensor()},
161+
{}));
162+
} else {
163+
ORT_RETURN_IF_NOT(qnn_model_wrapper->AddTensorWrapper(std::move(fused_softmax_input)), "Failed to add input");
164+
ORT_RETURN_IF_NOT(qnn_model_wrapper->AddTensorWrapper(std::move(fused_softmax_output)), "Failed to add output");
165+
ORT_RETURN_IF_NOT(qnn_model_wrapper->CreateQnnNode(softmax->Name(),
166+
QNN_OP_PACKAGE_NAME_QTI_AISW,
167+
QNN_OP_SOFTMAX,
168+
{mul_input_other.node_arg.Name()},
169+
{softmax_output.node_arg.Name()},
170+
std::move(param_tensor_names),
171+
validate),
172+
"Failed to add fused " + std::string(kOpSoftmax) + " node.");
173+
}
174+
return Status::OK();
175+
}
176+
177+
} // namespace
178+
179+
std::unique_ptr<IQnnNodeGroup> ScaleSoftmaxFusion::TryFusion(
180+
QnnModelWrapper& qnn_model_wrapper,
181+
const NodeUnit& mul_node_unit,
182+
const std::unordered_map<const Node*, const NodeUnit*>& node_to_node_unit,
183+
const std::unordered_map<const NodeUnit*, const IQnnNodeGroup*>& node_unit_to_qnn_node_group,
184+
[[maybe_unused]] const logging::Logger& logger) {
185+
if (mul_node_unit.OpType() != kOpMul || mul_node_unit.UnitType() != NodeUnit::Type::SingleNode) {
186+
return nullptr;
187+
}
188+
// Check if the mul node has a scalar input that can fold into the softmax's beta
189+
const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer();
190+
std::optional<float> scalar = ExtractScalarValueFromMul(graph_viewer, &mul_node_unit);
191+
if (!scalar.has_value()) {
192+
return nullptr;
193+
}
194+
195+
// Mul node must have a single Softmax node as child
196+
const std::array<std::string_view, 1> child_op_types{kOpSoftmax};
197+
const NodeUnit* softmax = GetOnlyChildOfType(graph_viewer, mul_node_unit, child_op_types,
198+
node_to_node_unit, node_unit_to_qnn_node_group);
199+
if (softmax == nullptr) {
200+
return nullptr;
201+
}
202+
203+
std::array<const NodeUnit*, 2> node_unit_array{&mul_node_unit, softmax};
204+
auto node_units = gsl::make_span<const NodeUnit*>(node_unit_array.data(), 2);
205+
if (CreateOrValidateOnQnn(&qnn_model_wrapper, node_units, /*validate=*/true) != Status::OK()) {
206+
return nullptr;
207+
}
208+
return std::make_unique<ScaleSoftmaxFusion>(node_units);
209+
}
210+
211+
gsl::span<const NodeUnit* const> ScaleSoftmaxFusion::GetNodeUnits() const {
  // Non-owning view over the two fused node units ([0] Mul, [1] Softmax).
  const NodeUnit* const* first = node_units_.data();
  return gsl::span<const NodeUnit* const>(first, node_units_.size());
}
214+
215+
Status ScaleSoftmaxFusion::IsSupported(
    QnnModelWrapper& qnn_model_wrapper, [[maybe_unused]] const logging::Logger& logger) const {
  // Validation-only pass: nothing is added to the QNN model.
  constexpr bool kValidateOnly = true;
  return CreateOrValidateOnQnn(&qnn_model_wrapper, GetNodeUnits(), kValidateOnly);
}
219+
220+
Status ScaleSoftmaxFusion::AddToModelBuilder(
    QnnModelWrapper& qnn_model_wrapper, [[maybe_unused]] const logging::Logger& logger) const {
  // Materialize the fused Softmax node into the QNN model.
  constexpr bool kValidateOnly = false;
  return CreateOrValidateOnQnn(&qnn_model_wrapper, GetNodeUnits(), kValidateOnly);
}
224+
225+
} // namespace qnn
226+
} // namespace onnxruntime
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#include <gsl/gsl>
7+
#include <array>
8+
#include <memory>
9+
#include <unordered_map>
10+
#include <vector>
11+
12+
#include "core/providers/qnn/builder/qnn_node_group/qnn_node_group.h"
13+
#include "core/providers/qnn/ort_api.h"
14+
15+
namespace onnxruntime {
16+
namespace qnn {
17+
18+
class QnnModelWrapper;
19+
20+
/// <summary>
21+
/// Represents a fusion of pattern: Softmax(Mul(x, scalar_scale)) => QnnSoftmax(x, beta=scalar_scale)
22+
/// </summary>
23+
class ScaleSoftmaxFusion : public IQnnNodeGroup {
24+
public:
25+
explicit ScaleSoftmaxFusion(gsl::span<const NodeUnit* const> node_units) {
26+
ORT_ENFORCE(node_units.size() == 2, "Pattern expect exactly 2 NodeUnits.");
27+
node_units_[0] = node_units[0];
28+
node_units_[1] = node_units[1];
29+
}
30+
ORT_DISALLOW_COPY_AND_ASSIGNMENT(ScaleSoftmaxFusion);
31+
32+
Status IsSupported(QnnModelWrapper& qnn_model_wrapper, const logging::Logger& logger) const override;
33+
Status AddToModelBuilder(QnnModelWrapper& qnn_model_wrapper, const logging::Logger& logger) const override;
34+
gsl::span<const NodeUnit* const> GetNodeUnits() const override;
35+
const NodeUnit* GetTargetNodeUnit() const override { return node_units_[1]; }
36+
std::string_view Type() const override { return "ScaleSoftmaxFusion"; }
37+
38+
/// <summary>
39+
/// Traverses graph to check if the given starting NodeUnit is part of a valid Softmax -> Mul sequence.
40+
/// If so, returns a IQnnNodeGroup that contains the Softmax and Mul NodeUnits.
41+
/// </summary>
42+
static std::unique_ptr<IQnnNodeGroup> TryFusion(
43+
QnnModelWrapper& qnn_model_wrapper,
44+
const NodeUnit& mul_node_unit,
45+
const std::unordered_map<const Node*, const NodeUnit*>& node_to_node_unit,
46+
const std::unordered_map<const NodeUnit*, const IQnnNodeGroup*>& node_unit_to_qnn_node_group,
47+
const logging::Logger& logger);
48+
49+
private:
50+
std::array<const NodeUnit*, 2> node_units_;
51+
};
52+
53+
} // namespace qnn
54+
} // namespace onnxruntime

0 commit comments

Comments
 (0)