@@ -9,7 +9,8 @@
 #include <utility>
 #include <string>
 #include <array>
-#include <vector>
+#include <memory>
+#include <unordered_map>

 #include "core/providers/qnn/builder/qnn_utils.h"
 #include "core/providers/qnn/builder/op_builder_factory.h"
@@ -37,11 +38,12 @@ std::optional<size_t> GetMulScalarInputIndex(const NodeUnit* mul) {
   if (!is_x_scalar && !is_y_scalar) {
     return std::nullopt;
   }
-  return is_y_scalar ? 1 : 0;
+  return is_y_scalar ? 1U : 0U;
 }

 /// @brief Get the axis for softmax
-/// @param node_units The node units containing the softmax and mul nodes
+/// @param mul Multiply node unit
+/// @param softmax Softmax node unit
 /// @return The axis for softmax
 std::optional<uint32_t> GetPositiveSoftmaxAxis(const NodeUnit* mul, const NodeUnit* softmax) {
   NodeAttrHelper softmax_attr_helper(softmax->GetNode());
@@ -52,7 +54,7 @@ std::optional<uint32_t> GetPositiveSoftmaxAxis(const NodeUnit* mul, const NodeUn
   int64_t axis_value = param_axis.value();
   if (axis_value < 0) {
     size_t input_scale_index = GetMulScalarInputIndex(mul).value();
-    size_t input_other_index = 1 - input_scale_index;
+    size_t input_other_index = 1U - input_scale_index;
     int rank = mul->GetNode().InputDefs()[input_other_index]->Shape()->dim_size();
     axis_value += static_cast<int64_t>(rank);
   }
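Reviewer note: the negative-axis branch above follows the usual ONNX convention where a negative axis counts back from the tensor rank. A self-contained sketch of that normalization (stand-in code for illustration, not part of this patch):

```cpp
#include <cstdint>
#include <optional>

// Maps a possibly-negative ONNX-style axis onto [0, rank), mirroring the
// adjustment done in GetPositiveSoftmaxAxis. Returns nullopt when the axis
// is out of range even after normalization.
std::optional<uint32_t> NormalizeAxis(int64_t axis_value, int rank) {
  if (axis_value < 0) {
    axis_value += static_cast<int64_t>(rank);  // e.g. axis=-1, rank=4 -> 3
  }
  if (axis_value < 0 || axis_value >= static_cast<int64_t>(rank)) {
    return std::nullopt;
  }
  return static_cast<uint32_t>(axis_value);
}
```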
@@ -92,7 +94,7 @@ std::optional<float> ExtractScalarValueFromMul(const GraphViewer& graph_viewer,
 /// @return Status
 Status CreateOrValidateOnQnn(
     QnnModelWrapper* qnn_model_wrapper,
-    std::array<const NodeUnit*, 2> node_units,
+    gsl::span<const NodeUnit* const> node_units,
     bool validate) {
   const NodeUnit* mul = node_units[0];
   const NodeUnit* softmax = node_units[1];
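Switching the parameter from `std::array<const NodeUnit*, 2>` taken by value to `gsl::span<const NodeUnit* const>` lets `CreateOrValidateOnQnn` accept any contiguous buffer of node-unit pointers without copying, and the `* const` element type keeps the callee from reseating them. A minimal sketch of the same signature shape, with a hypothetical `Widget` standing in for `NodeUnit`:

```cpp
#include <array>
#include <vector>
#include <gsl/span>

struct Widget {};

// A span of const pointers: the callee can read the elements but not
// reseat them, and it does not care which container owns the storage.
int CountNonNull(gsl::span<const Widget* const> items) {
  int n = 0;
  for (const Widget* w : items) {
    if (w != nullptr) ++n;
  }
  return n;
}

int main() {
  Widget a, b;
  std::array<const Widget*, 2> arr{&a, &b};
  std::vector<const Widget*> vec{&a, nullptr, &b};
  // Both containers convert to the span parameter without copying elements.
  return CountNonNull(arr) + CountNonNull(vec);
}
```

This implicit conversion is what lets `TryFusion` below build the argument from a local `std::array`.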
@@ -101,7 +103,7 @@ Status CreateOrValidateOnQnn(
   ORT_RETURN_IF_NOT(softmax->OpType() == kOpSoftmax,
                     "Expected softmax node to be of type Softmax, got ", softmax->OpType());
   size_t input_scale_index = GetMulScalarInputIndex(mul).value();
-  size_t input_other_index = 1 - input_scale_index;
+  size_t input_other_index = 1U - input_scale_index;
   const NodeUnitIODef& mul_input_other = mul->Inputs()[input_other_index];
   const NodeUnitIODef& softmax_output = softmax->Outputs()[0];

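The `1U - input_scale_index` idiom selects the other element of the two-input pair (it flips 0 and 1), and the `1U` literal keeps the subtraction unsigned. A two-line illustration:

```cpp
#include <cassert>
#include <cstddef>

int main() {
  const size_t scale_index = 1U;                // index of the scalar Mul input
  const size_t other_index = 1U - scale_index;  // flips 1 -> 0 (and 0 -> 1)
  assert(other_index == 0U);
  return 0;
}
```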
@@ -188,22 +190,26 @@ std::unique_ptr<IQnnNodeGroup> ScaleSoftmaxFusion::TryFusion(
     return nullptr;
   }

-  std::array<const NodeUnit*, 2> node_units{&mul_node_unit, softmax};
+  std::array<const NodeUnit*, 2> node_unit_array{&mul_node_unit, softmax};
+  auto node_units = gsl::make_span<const NodeUnit*>(node_unit_array.data(), 2);
   if (CreateOrValidateOnQnn(&qnn_model_wrapper, node_units, /*validate=*/true) != Status::OK()) {
     return nullptr;
   }
-
   return std::make_unique<ScaleSoftmaxFusion>(node_units);
 }

+gsl::span<const NodeUnit* const> ScaleSoftmaxFusion::GetNodeUnits() const {
+  return gsl::span<const NodeUnit* const>{node_units_.data(), node_units_.size()};
+}
+
 Status ScaleSoftmaxFusion::IsSupported(
     QnnModelWrapper& qnn_model_wrapper, [[maybe_unused]] const logging::Logger& logger) const {
-  return CreateOrValidateOnQnn(&qnn_model_wrapper, node_units_, /*validate=*/true);
+  return CreateOrValidateOnQnn(&qnn_model_wrapper, GetNodeUnits(), /*validate=*/true);
 }

 Status ScaleSoftmaxFusion::AddToModelBuilder(
     QnnModelWrapper& qnn_model_wrapper, [[maybe_unused]] const logging::Logger& logger) const {
-  return CreateOrValidateOnQnn(&qnn_model_wrapper, node_units_, /*validate=*/false);
+  return CreateOrValidateOnQnn(&qnn_model_wrapper, GetNodeUnits(), /*validate=*/false);
 }

 }  // namespace qnn
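The new `GetNodeUnits()` accessor returns a non-owning, read-only view over the fusion's stored pointers. The `node_units_` member is declared in the header, which is not part of this diff; it is presumably a `std::array<const NodeUnit*, 2>`. The pattern in isolation, again with a stand-in type:

```cpp
#include <array>
#include <cstddef>
#include <gsl/span>

struct Widget {};

class Group {
 public:
  explicit Group(gsl::span<const Widget* const> items) {
    // Copy the pointers into owned storage; the span itself is not kept.
    for (size_t i = 0; i < items_.size() && i < items.size(); ++i) {
      items_[i] = items[i];
    }
  }

  // Non-owning, read-only view over the stored pointers. In a const member
  // function, items_.data() yields const Widget* const*, which is exactly
  // the span's pointer type.
  gsl::span<const Widget* const> GetItems() const {
    return gsl::span<const Widget* const>{items_.data(), items_.size()};
  }

 private:
  std::array<const Widget*, 2> items_{};
};
```

The returned span is only valid while the owning object is alive, which holds here because `IsSupported` and `AddToModelBuilder` call the accessor on `this`.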