99#include < utility>
1010#include < string>
1111#include < array>
12- #include < vector>
12+ #include < memory>
13+ #include < unordered_map>
1314
1415#include " core/providers/qnn/builder/qnn_utils.h"
1516#include " core/providers/qnn/builder/op_builder_factory.h"
@@ -37,11 +38,12 @@ std::optional<size_t> GetMulScalarInputIndex(const NodeUnit* mul) {
3738 if (!is_x_scalar && !is_y_scalar) {
3839 return std::nullopt ;
3940 }
40- return is_y_scalar ? 1 : 0 ;
41+ return is_y_scalar ? 1U : 0U ;
4142}
4243
/// @brief Get the axis for softmax
/// @param mul Multiply node unit
/// @param softmax Softmax node unit
4648std::optional<uint32_t > GetPositiveSoftmaxAxis (const NodeUnit* mul, const NodeUnit* softmax) {
4749 NodeAttrHelper softmax_attr_helper (softmax->GetNode ());
@@ -52,7 +54,7 @@ std::optional<uint32_t> GetPositiveSoftmaxAxis(const NodeUnit* mul, const NodeUn
5254 int64_t axis_value = param_axis.value ();
5355 if (axis_value < 0 ) {
5456 size_t input_scale_index = GetMulScalarInputIndex (mul).value ();
55- size_t input_other_index = 1 - input_scale_index;
57+ size_t input_other_index = 1U - input_scale_index;
5658 int rank = mul->GetNode ().InputDefs ()[input_other_index]->Shape ()->dim_size ();
5759 axis_value += static_cast <int64_t >(rank);
5860 }
@@ -92,7 +94,7 @@ std::optional<float> ExtractScalarValueFromMul(const GraphViewer& graph_viewer,
9294// / @return Status
9395Status CreateOrValidateOnQnn (
9496 QnnModelWrapper* qnn_model_wrapper,
95- std::array <const NodeUnit*, 2 > node_units,
97+ gsl::span <const NodeUnit* const > node_units,
9698 bool validate) {
9799 const NodeUnit* mul = node_units[0 ];
98100 const NodeUnit* softmax = node_units[1 ];
@@ -101,7 +103,7 @@ Status CreateOrValidateOnQnn(
101103 ORT_RETURN_IF_NOT (softmax->OpType () == kOpSoftmax ,
102104 " Expected softmax node to be of type Softmax, got " , softmax->OpType ());
103105 size_t input_scale_index = GetMulScalarInputIndex (mul).value ();
104- size_t input_other_index = 1 - input_scale_index;
106+ size_t input_other_index = 1U - input_scale_index;
105107 const NodeUnitIODef& mul_input_other = mul->Inputs ()[input_other_index];
106108 const NodeUnitIODef& softmax_output = softmax->Outputs ()[0 ];
107109
@@ -188,22 +190,26 @@ std::unique_ptr<IQnnNodeGroup> ScaleSoftmaxFusion::TryFusion(
188190 return nullptr ;
189191 }
190192
191- std::array<const NodeUnit*, 2 > node_units{&mul_node_unit, softmax};
193+ std::array<const NodeUnit*, 2 > node_unit_array{&mul_node_unit, softmax};
194+ auto node_units = gsl::make_span<const NodeUnit*>(node_unit_array.data (), 2 );
192195 if (CreateOrValidateOnQnn (&qnn_model_wrapper, node_units, /* validate=*/ true ) != Status::OK ()) {
193196 return nullptr ;
194197 }
195-
196198 return std::make_unique<ScaleSoftmaxFusion>(node_units);
197199}
198200
201+ gsl::span<const NodeUnit* const > ScaleSoftmaxFusion::GetNodeUnits () const {
202+ return gsl::span<const NodeUnit* const >{node_units_.data (), node_units_.size ()};
203+ }
204+
199205Status ScaleSoftmaxFusion::IsSupported (
200206 QnnModelWrapper& qnn_model_wrapper, [[maybe_unused]] const logging::Logger& logger) const {
201- return CreateOrValidateOnQnn (&qnn_model_wrapper, node_units_ , /* validate=*/ true );
207+ return CreateOrValidateOnQnn (&qnn_model_wrapper, GetNodeUnits () , /* validate=*/ true );
202208}
203209
204210Status ScaleSoftmaxFusion::AddToModelBuilder (
205211 QnnModelWrapper& qnn_model_wrapper, [[maybe_unused]] const logging::Logger& logger) const {
206- return CreateOrValidateOnQnn (&qnn_model_wrapper, node_units_ , /* validate=*/ false );
212+ return CreateOrValidateOnQnn (&qnn_model_wrapper, GetNodeUnits () , /* validate=*/ false );
207213}
208214
209215} // namespace qnn
0 commit comments