@@ -29,16 +29,26 @@ constexpr char kOpSoftmax[] = "Softmax";
2929// / @param mul Multiply node unit
3030// / @return The index of the scalar input (0 or 1) if found, otherwise std::nullopt
3131std::optional<size_t > GetMulScalarInputIndex (const NodeUnit* mul) {
32- const NodeArg* mul_x = mul->GetNode ().InputDefs ()[0 ];
3332 const NodeArg* mul_y = mul->GetNode ().InputDefs ()[1 ];
34- auto mul_x_shape = utils::GetTensorProtoShape (*mul_x->Shape ());
35- auto mul_y_shape = utils::GetTensorProtoShape (*mul_y->Shape ());
36- bool is_x_scalar = mul_x_shape.NumDimensions () == 0 ;
37- bool is_y_scalar = mul_y_shape.NumDimensions () == 0 ;
38- if (!is_x_scalar && !is_y_scalar) {
39- return std::nullopt ;
33+ const NodeArg* mul_x = mul->GetNode ().InputDefs ()[0 ];
34+ auto y_shape_proto = mul_y->Shape ();
35+ auto x_shape_proto = mul_x->Shape ();
36+ bool is_y_scalar = false ;
37+ if (y_shape_proto != nullptr ) {
38+ auto y_shape = utils::GetTensorProtoShape (*y_shape_proto);
39+ is_y_scalar = y_shape.NumDimensions () == 0 ;
40+ }
41+ bool is_x_scalar = false ;
42+ if (x_shape_proto != nullptr ) {
43+ auto x_shape = utils::GetTensorProtoShape (*x_shape_proto);
44+ is_x_scalar = x_shape.NumDimensions () == 0 ;
45+ }
46+ if (is_y_scalar) {
47+ return 1U ;
48+ } else if (is_x_scalar) {
49+ return 0U ;
4050 }
41- return is_y_scalar ? 1U : 0U ;
51+ return std:: nullopt ;
4252}
4353
4454// / @brief Get the axis for softmax
0 commit comments