Skip to content

Commit 09a403d

Browse files
committed
Fix CI failure
1 parent c7ad7df commit 09a403d

File tree

1 file changed

+18
-8
lines changed

1 file changed

+18
-8
lines changed

onnxruntime/core/providers/qnn/builder/qnn_node_group/scale_softmax_fusion.cc

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,16 +29,26 @@ constexpr char kOpSoftmax[] = "Softmax";
2929
/// @param mul Multiply node unit
3030
/// @return The index of the scalar input (0 or 1) if found, otherwise std::nullopt
3131
std::optional<size_t> GetMulScalarInputIndex(const NodeUnit* mul) {
32-
const NodeArg* mul_x = mul->GetNode().InputDefs()[0];
3332
const NodeArg* mul_y = mul->GetNode().InputDefs()[1];
34-
auto mul_x_shape = utils::GetTensorProtoShape(*mul_x->Shape());
35-
auto mul_y_shape = utils::GetTensorProtoShape(*mul_y->Shape());
36-
bool is_x_scalar = mul_x_shape.NumDimensions() == 0;
37-
bool is_y_scalar = mul_y_shape.NumDimensions() == 0;
38-
if (!is_x_scalar && !is_y_scalar) {
39-
return std::nullopt;
33+
const NodeArg* mul_x = mul->GetNode().InputDefs()[0];
34+
auto y_shape_proto = mul_y->Shape();
35+
auto x_shape_proto = mul_x->Shape();
36+
bool is_y_scalar = false;
37+
if (y_shape_proto != nullptr) {
38+
auto y_shape = utils::GetTensorProtoShape(*y_shape_proto);
39+
is_y_scalar = y_shape.NumDimensions() == 0;
40+
}
41+
bool is_x_scalar = false;
42+
if (x_shape_proto != nullptr) {
43+
auto x_shape = utils::GetTensorProtoShape(*x_shape_proto);
44+
is_x_scalar = x_shape.NumDimensions() == 0;
45+
}
46+
if (is_y_scalar) {
47+
return 1U;
48+
} else if (is_x_scalar) {
49+
return 0U;
4050
}
41-
return is_y_scalar ? 1U : 0U;
51+
return std::nullopt;
4252
}
4353

4454
/// @brief Get the axis for softmax

0 commit comments

Comments
 (0)