@@ -29,16 +29,26 @@ constexpr char kOpSoftmax[] = "Softmax";
29
29
// / @param mul Multiply node unit
30
30
// / @return The index of the scalar input (0 or 1) if found, otherwise std::nullopt
31
31
std::optional<size_t > GetMulScalarInputIndex (const NodeUnit* mul) {
32
- const NodeArg* mul_x = mul->GetNode ().InputDefs ()[0 ];
33
32
const NodeArg* mul_y = mul->GetNode ().InputDefs ()[1 ];
34
- auto mul_x_shape = utils::GetTensorProtoShape (*mul_x->Shape ());
35
- auto mul_y_shape = utils::GetTensorProtoShape (*mul_y->Shape ());
36
- bool is_x_scalar = mul_x_shape.NumDimensions () == 0 ;
37
- bool is_y_scalar = mul_y_shape.NumDimensions () == 0 ;
38
- if (!is_x_scalar && !is_y_scalar) {
39
- return std::nullopt;
33
+ const NodeArg* mul_x = mul->GetNode ().InputDefs ()[0 ];
34
+ auto y_shape_proto = mul_y->Shape ();
35
+ auto x_shape_proto = mul_x->Shape ();
36
+ bool is_y_scalar = false ;
37
+ if (y_shape_proto != nullptr ) {
38
+ auto y_shape = utils::GetTensorProtoShape (*y_shape_proto);
39
+ is_y_scalar = y_shape.NumDimensions () == 0 ;
40
+ }
41
+ bool is_x_scalar = false ;
42
+ if (x_shape_proto != nullptr ) {
43
+ auto x_shape = utils::GetTensorProtoShape (*x_shape_proto);
44
+ is_x_scalar = x_shape.NumDimensions () == 0 ;
45
+ }
46
+ if (is_y_scalar) {
47
+ return 1U ;
48
+ } else if (is_x_scalar) {
49
+ return 0U ;
40
50
}
41
- return is_y_scalar ? 1U : 0U ;
51
+ return std::nullopt ;
42
52
}
43
53
44
54
// / @brief Get the axis for softmax
0 commit comments