Skip to content

Commit 3c449ee

Browse files
committed
Address review feedback
1 parent ddaa6b5 commit 3c449ee

File tree

3 files changed

+22
-14
lines changed

3 files changed

+22
-14
lines changed

onnxruntime/core/providers/qnn/builder/qnn_node_group/scale_softmax_fusion.cc

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
#include <utility>
1010
#include <string>
1111
#include <array>
12-
#include <vector>
12+
#include <memory>
13+
#include <unordered_map>
1314

1415
#include "core/providers/qnn/builder/qnn_utils.h"
1516
#include "core/providers/qnn/builder/op_builder_factory.h"
@@ -37,11 +38,12 @@ std::optional<size_t> GetMulScalarInputIndex(const NodeUnit* mul) {
3738
if (!is_x_scalar && !is_y_scalar) {
3839
return std::nullopt;
3940
}
40-
return is_y_scalar ? 1 : 0;
41+
return is_y_scalar ? 1U : 0U;
4142
}
4243

4344
/// @brief Get the axis for softmax
44-
/// @param node_units The node units containing the softmax and mul nodes
45+
/// @param mul Multiply node unit
46+
/// @param softmax Softmax node unit
4547
/// @return The axis for softmax
4648
std::optional<uint32_t> GetPositiveSoftmaxAxis(const NodeUnit* mul, const NodeUnit* softmax) {
4749
NodeAttrHelper softmax_attr_helper(softmax->GetNode());
@@ -52,7 +54,7 @@ std::optional<uint32_t> GetPositiveSoftmaxAxis(const NodeUnit* mul, const NodeUn
5254
int64_t axis_value = param_axis.value();
5355
if (axis_value < 0) {
5456
size_t input_scale_index = GetMulScalarInputIndex(mul).value();
55-
size_t input_other_index = 1 - input_scale_index;
57+
size_t input_other_index = 1U - input_scale_index;
5658
int rank = mul->GetNode().InputDefs()[input_other_index]->Shape()->dim_size();
5759
axis_value += static_cast<int64_t>(rank);
5860
}
@@ -92,7 +94,7 @@ std::optional<float> ExtractScalarValueFromMul(const GraphViewer& graph_viewer,
9294
/// @return Status
9395
Status CreateOrValidateOnQnn(
9496
QnnModelWrapper* qnn_model_wrapper,
95-
std::array<const NodeUnit*, 2> node_units,
97+
gsl::span<const NodeUnit* const> node_units,
9698
bool validate) {
9799
const NodeUnit* mul = node_units[0];
98100
const NodeUnit* softmax = node_units[1];
@@ -101,7 +103,7 @@ Status CreateOrValidateOnQnn(
101103
ORT_RETURN_IF_NOT(softmax->OpType() == kOpSoftmax,
102104
"Expected softmax node to be of type Softmax, got ", softmax->OpType());
103105
size_t input_scale_index = GetMulScalarInputIndex(mul).value();
104-
size_t input_other_index = 1 - input_scale_index;
106+
size_t input_other_index = 1U - input_scale_index;
105107
const NodeUnitIODef& mul_input_other = mul->Inputs()[input_other_index];
106108
const NodeUnitIODef& softmax_output = softmax->Outputs()[0];
107109

@@ -188,22 +190,26 @@ std::unique_ptr<IQnnNodeGroup> ScaleSoftmaxFusion::TryFusion(
188190
return nullptr;
189191
}
190192

191-
std::array<const NodeUnit*, 2> node_units{&mul_node_unit, softmax};
193+
std::array<const NodeUnit*, 2> node_unit_array{&mul_node_unit, softmax};
194+
auto node_units = gsl::make_span<const NodeUnit*>(node_unit_array.data(), 2);
192195
if (CreateOrValidateOnQnn(&qnn_model_wrapper, node_units, /*validate=*/true) != Status::OK()) {
193196
return nullptr;
194197
}
195-
196198
return std::make_unique<ScaleSoftmaxFusion>(node_units);
197199
}
198200

201+
gsl::span<const NodeUnit* const> ScaleSoftmaxFusion::GetNodeUnits() const {
202+
return gsl::span<const NodeUnit* const>{node_units_.data(), node_units_.size()};
203+
}
204+
199205
Status ScaleSoftmaxFusion::IsSupported(
200206
QnnModelWrapper& qnn_model_wrapper, [[maybe_unused]] const logging::Logger& logger) const {
201-
return CreateOrValidateOnQnn(&qnn_model_wrapper, node_units_, /*validate=*/true);
207+
return CreateOrValidateOnQnn(&qnn_model_wrapper, GetNodeUnits(), /*validate=*/true);
202208
}
203209

204210
Status ScaleSoftmaxFusion::AddToModelBuilder(
205211
QnnModelWrapper& qnn_model_wrapper, [[maybe_unused]] const logging::Logger& logger) const {
206-
return CreateOrValidateOnQnn(&qnn_model_wrapper, node_units_, /*validate=*/false);
212+
return CreateOrValidateOnQnn(&qnn_model_wrapper, GetNodeUnits(), /*validate=*/false);
207213
}
208214

209215
} // namespace qnn

onnxruntime/core/providers/qnn/builder/qnn_node_group/scale_softmax_fusion.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,16 @@ class QnnModelWrapper;
2222
/// </summary>
2323
class ScaleSoftmaxFusion : public IQnnNodeGroup {
2424
public:
25-
explicit ScaleSoftmaxFusion(const std::array<const NodeUnit*, 2>& node_units) : node_units_(node_units) {}
25+
explicit ScaleSoftmaxFusion(gsl::span<const NodeUnit* const> node_units) {
26+
ORT_ENFORCE(node_units.size() == 2, "Pattern expect exactly 2 NodeUnits.");
27+
node_units_[0] = node_units[0];
28+
node_units_[1] = node_units[1];
29+
}
2630
ORT_DISALLOW_COPY_AND_ASSIGNMENT(ScaleSoftmaxFusion);
2731

2832
Status IsSupported(QnnModelWrapper& qnn_model_wrapper, const logging::Logger& logger) const override;
2933
Status AddToModelBuilder(QnnModelWrapper& qnn_model_wrapper, const logging::Logger& logger) const override;
30-
gsl::span<const NodeUnit* const> GetNodeUnits() const override { return node_units_; }
34+
gsl::span<const NodeUnit* const> GetNodeUnits() const override;
3135
const NodeUnit* GetTargetNodeUnit() const override { return node_units_[1]; }
3236
std::string_view Type() const override { return "ScaleSoftmaxFusion"; }
3337

onnxruntime/test/providers/qnn/qnn_node_group/scale_softmax_fusion_test.cc

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,6 @@ TEST_F(QnnHTPBackendTests, ScaleSoftmaxFusionScalarInitializer) {
7171
/*opset_version=*/13,
7272
/*expected_ep_assignment=*/ExpectedEPNodeAssignment::All,
7373
/*fp32_abs_err=*/1e-2f);
74-
75-
7674
}
7775

7876
TEST_F(QnnHTPBackendTests, ScaleSoftmaxFusionScalarConstant) {

0 commit comments

Comments
 (0)