|
| 1 | +// Copyright (c) Microsoft Corporation. All rights reserved. |
| 2 | +// Licensed under the MIT License. |
| 3 | + |
#include "core/providers/qnn/builder/qnn_node_group/scale_softmax_fusion.h"

#include <gsl/gsl>

#include <array>
#include <cstring>
#include <memory>
#include <optional>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "core/providers/qnn/builder/qnn_utils.h"
#include "core/providers/qnn/builder/op_builder_factory.h"
#include "core/providers/qnn/builder/qnn_node_group/utils.h"
#include "core/providers/qnn/builder/qnn_model_wrapper.h"
#include "core/providers/qnn/builder/opbuilder/base_op_builder.h"
| 20 | + |
| 21 | +namespace onnxruntime { |
| 22 | +namespace qnn { |
| 23 | +namespace { |
| 24 | + |
// ONNX op type names matched by this fusion pass.
constexpr char kOpMul[] = "Mul";
constexpr char kOpSoftmax[] = "Softmax";
| 27 | + |
| 28 | +/// @brief Get the index of the scalar input in the mul node |
| 29 | +/// @param mul Multiply node unit |
| 30 | +/// @return The index of the scalar input (0 or 1) if found, otherwise std::nullopt |
| 31 | +std::optional<size_t> GetMulScalarInputIndex(const NodeUnit* mul) { |
| 32 | + const NodeArg* mul_y = mul->GetNode().InputDefs()[1]; |
| 33 | + const NodeArg* mul_x = mul->GetNode().InputDefs()[0]; |
| 34 | + auto y_shape_proto = mul_y->Shape(); |
| 35 | + auto x_shape_proto = mul_x->Shape(); |
| 36 | + bool is_y_scalar = false; |
| 37 | + if (y_shape_proto != nullptr) { |
| 38 | + auto y_shape = utils::GetTensorProtoShape(*y_shape_proto); |
| 39 | + is_y_scalar = y_shape.NumDimensions() == 0; |
| 40 | + } |
| 41 | + bool is_x_scalar = false; |
| 42 | + if (x_shape_proto != nullptr) { |
| 43 | + auto x_shape = utils::GetTensorProtoShape(*x_shape_proto); |
| 44 | + is_x_scalar = x_shape.NumDimensions() == 0; |
| 45 | + } |
| 46 | + if (is_y_scalar) { |
| 47 | + return 1U; |
| 48 | + } else if (is_x_scalar) { |
| 49 | + return 0U; |
| 50 | + } |
| 51 | + return std::nullopt; |
| 52 | +} |
| 53 | + |
| 54 | +/// @brief Get the axis for softmax |
| 55 | +/// @param mul Multiply node unit |
| 56 | +/// @param softmax Softmax node unit |
| 57 | +/// @return The axis for softmax |
| 58 | +std::optional<uint32_t> GetPositiveSoftmaxAxis(const NodeUnit* mul, const NodeUnit* softmax) { |
| 59 | + NodeAttrHelper softmax_attr_helper(softmax->GetNode()); |
| 60 | + std::optional<int64_t> param_axis = softmax_attr_helper.GetInt64(QNN_OP_SOFTMAX_PARAM_AXIS); |
| 61 | + if (!param_axis.has_value()) { |
| 62 | + return std::nullopt; |
| 63 | + } |
| 64 | + int64_t axis_value = param_axis.value(); |
| 65 | + if (axis_value < 0) { |
| 66 | + size_t input_scale_index = GetMulScalarInputIndex(mul).value(); |
| 67 | + size_t input_other_index = 1U - input_scale_index; |
| 68 | + int rank = mul->GetNode().InputDefs()[input_other_index]->Shape()->dim_size(); |
| 69 | + axis_value += static_cast<int64_t>(rank); |
| 70 | + } |
| 71 | + return static_cast<uint32_t>(axis_value); |
| 72 | +} |
| 73 | + |
| 74 | +/// @brief Identify scalar input from mul node if present |
| 75 | +/// @param mul Multiply node unit |
| 76 | +/// @return The scalar input float value if found, otherwise std::nullopt |
| 77 | +std::optional<float> ExtractScalarValueFromMul(const GraphViewer& graph_viewer, const NodeUnit* mul) { |
| 78 | + std::optional<size_t> input_scale_index = GetMulScalarInputIndex(mul); |
| 79 | + if (!input_scale_index.has_value()) { |
| 80 | + return std::nullopt; |
| 81 | + } |
| 82 | + const NodeArg* scalar_arg = mul->GetNode().InputDefs()[input_scale_index.value()]; |
| 83 | + if (!graph_viewer.IsConstantInitializer(scalar_arg->Name(), true)) { |
| 84 | + return std::nullopt; |
| 85 | + } |
| 86 | + const auto* scalar_tensor = graph_viewer.GetConstantInitializer(scalar_arg->Name()); |
| 87 | + if (!scalar_tensor) { |
| 88 | + return std::nullopt; |
| 89 | + } |
| 90 | + if (scalar_tensor->data_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { |
| 91 | + return std::nullopt; |
| 92 | + } |
| 93 | + const auto& raw_data = scalar_tensor->raw_data(); |
| 94 | + if (raw_data.size() != sizeof(float) || reinterpret_cast<uintptr_t>(raw_data.data()) % alignof(float) != 0) { |
| 95 | + return std::nullopt; |
| 96 | + } |
| 97 | + return *reinterpret_cast<const float*>(raw_data.data()); |
| 98 | +} |
| 99 | + |
| 100 | +/// @brief Create or validate the QNN node |
| 101 | +/// @param qnn_model_wrapper QNN model wrapper |
| 102 | +/// @param node_units The node units containing the softmax and mul nodes |
| 103 | +/// @param validate Whether to validate the QNN node |
| 104 | +/// @return Status |
| 105 | +Status CreateOrValidateOnQnn( |
| 106 | + QnnModelWrapper* qnn_model_wrapper, |
| 107 | + gsl::span<const NodeUnit* const> node_units, |
| 108 | + bool validate) { |
| 109 | + const NodeUnit* mul = node_units[0]; |
| 110 | + const NodeUnit* softmax = node_units[1]; |
| 111 | + ORT_RETURN_IF_NOT(mul->OpType() == kOpMul, |
| 112 | + "Expected scale node to be of type Mul, got ", mul->OpType()); |
| 113 | + ORT_RETURN_IF_NOT(softmax->OpType() == kOpSoftmax, |
| 114 | + "Expected softmax node to be of type Softmax, got ", softmax->OpType()); |
| 115 | + size_t input_scale_index = GetMulScalarInputIndex(mul).value(); |
| 116 | + size_t input_other_index = 1U - input_scale_index; |
| 117 | + const NodeUnitIODef& mul_input_other = mul->Inputs()[input_other_index]; |
| 118 | + const NodeUnitIODef& softmax_output = softmax->Outputs()[0]; |
| 119 | + |
| 120 | + std::vector<std::string> param_tensor_names; |
| 121 | + { // axis |
| 122 | + std::optional<uint32_t> axis = GetPositiveSoftmaxAxis(mul, softmax); |
| 123 | + if (axis.has_value()) { |
| 124 | + Qnn_Scalar_t axis_scalar = QNN_SCALAR_INIT; |
| 125 | + axis_scalar.dataType = QNN_DATATYPE_UINT_32; |
| 126 | + axis_scalar.uint32Value = axis.value(); |
| 127 | + QnnParamWrapper param_wrapper(softmax->Index(), |
| 128 | + softmax->Name(), |
| 129 | + QNN_OP_SOFTMAX_PARAM_AXIS, |
| 130 | + axis_scalar); |
| 131 | + ORT_RETURN_IF_NOT(qnn_model_wrapper->AddParamWrapper(std::move(param_wrapper)), "Failed to add param"); |
| 132 | + param_tensor_names.push_back(param_wrapper.GetParamTensorName()); |
| 133 | + } |
| 134 | + } |
| 135 | + { // beta |
| 136 | + NodeAttrHelper softmax_attr_helper(softmax->GetNode()); |
| 137 | + std::optional<float> beta = softmax_attr_helper.GetFloat(QNN_OP_SOFTMAX_PARAM_BETA); |
| 138 | + float scale = ExtractScalarValueFromMul(qnn_model_wrapper->GetGraphViewer(), mul).value_or(1.0f); |
| 139 | + Qnn_Scalar_t beta_scalar = QNN_SCALAR_INIT; |
| 140 | + beta_scalar.dataType = QNN_DATATYPE_FLOAT_32; |
| 141 | + beta_scalar.floatValue = scale * beta.value_or(1.0f); |
| 142 | + QnnParamWrapper param_wrapper(softmax->Index(), |
| 143 | + softmax->Name(), |
| 144 | + QNN_OP_SOFTMAX_PARAM_BETA, |
| 145 | + beta_scalar); |
| 146 | + ORT_RETURN_IF_NOT(qnn_model_wrapper->AddParamWrapper(std::move(param_wrapper)), "Failed to add param"); |
| 147 | + param_tensor_names.push_back(param_wrapper.GetParamTensorName()); |
| 148 | + } |
| 149 | + |
| 150 | + QnnTensorWrapper fused_softmax_input; |
| 151 | + QnnTensorWrapper fused_softmax_output; |
| 152 | + ORT_RETURN_IF_ERROR(qnn_model_wrapper->MakeTensorWrapper(mul_input_other, fused_softmax_input)); |
| 153 | + ORT_RETURN_IF_ERROR(qnn_model_wrapper->MakeTensorWrapper(softmax_output, fused_softmax_output)); |
| 154 | + |
| 155 | + if (validate) { |
| 156 | + ORT_RETURN_IF_ERROR(qnn_model_wrapper->ValidateQnnNode(softmax->Name(), |
| 157 | + QNN_OP_PACKAGE_NAME_QTI_AISW, |
| 158 | + QNN_OP_SOFTMAX, |
| 159 | + {fused_softmax_input.GetQnnTensor()}, |
| 160 | + {fused_softmax_output.GetQnnTensor()}, |
| 161 | + {})); |
| 162 | + } else { |
| 163 | + ORT_RETURN_IF_NOT(qnn_model_wrapper->AddTensorWrapper(std::move(fused_softmax_input)), "Failed to add input"); |
| 164 | + ORT_RETURN_IF_NOT(qnn_model_wrapper->AddTensorWrapper(std::move(fused_softmax_output)), "Failed to add output"); |
| 165 | + ORT_RETURN_IF_NOT(qnn_model_wrapper->CreateQnnNode(softmax->Name(), |
| 166 | + QNN_OP_PACKAGE_NAME_QTI_AISW, |
| 167 | + QNN_OP_SOFTMAX, |
| 168 | + {mul_input_other.node_arg.Name()}, |
| 169 | + {softmax_output.node_arg.Name()}, |
| 170 | + std::move(param_tensor_names), |
| 171 | + validate), |
| 172 | + "Failed to add fused " + std::string(kOpSoftmax) + " node."); |
| 173 | + } |
| 174 | + return Status::OK(); |
| 175 | +} |
| 176 | + |
| 177 | +} // namespace |
| 178 | + |
| 179 | +std::unique_ptr<IQnnNodeGroup> ScaleSoftmaxFusion::TryFusion( |
| 180 | + QnnModelWrapper& qnn_model_wrapper, |
| 181 | + const NodeUnit& mul_node_unit, |
| 182 | + const std::unordered_map<const Node*, const NodeUnit*>& node_to_node_unit, |
| 183 | + const std::unordered_map<const NodeUnit*, const IQnnNodeGroup*>& node_unit_to_qnn_node_group, |
| 184 | + [[maybe_unused]] const logging::Logger& logger) { |
| 185 | + if (mul_node_unit.OpType() != kOpMul || mul_node_unit.UnitType() != NodeUnit::Type::SingleNode) { |
| 186 | + return nullptr; |
| 187 | + } |
| 188 | + // Check if the mul node has a scalar input that can fold into the softmax's beta |
| 189 | + const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer(); |
| 190 | + std::optional<float> scalar = ExtractScalarValueFromMul(graph_viewer, &mul_node_unit); |
| 191 | + if (!scalar.has_value()) { |
| 192 | + return nullptr; |
| 193 | + } |
| 194 | + |
| 195 | + // Mul node must have a single Softmax node as child |
| 196 | + const std::array<std::string_view, 1> child_op_types{kOpSoftmax}; |
| 197 | + const NodeUnit* softmax = GetOnlyChildOfType(graph_viewer, mul_node_unit, child_op_types, |
| 198 | + node_to_node_unit, node_unit_to_qnn_node_group); |
| 199 | + if (softmax == nullptr) { |
| 200 | + return nullptr; |
| 201 | + } |
| 202 | + |
| 203 | + std::array<const NodeUnit*, 2> node_unit_array{&mul_node_unit, softmax}; |
| 204 | + auto node_units = gsl::make_span<const NodeUnit*>(node_unit_array.data(), 2); |
| 205 | + if (CreateOrValidateOnQnn(&qnn_model_wrapper, node_units, /*validate=*/true) != Status::OK()) { |
| 206 | + return nullptr; |
| 207 | + } |
| 208 | + return std::make_unique<ScaleSoftmaxFusion>(node_units); |
| 209 | +} |
| 210 | + |
| 211 | +gsl::span<const NodeUnit* const> ScaleSoftmaxFusion::GetNodeUnits() const { |
| 212 | + return gsl::span<const NodeUnit* const>{node_units_.data(), node_units_.size()}; |
| 213 | +} |
| 214 | + |
| 215 | +Status ScaleSoftmaxFusion::IsSupported( |
| 216 | + QnnModelWrapper& qnn_model_wrapper, [[maybe_unused]] const logging::Logger& logger) const { |
| 217 | + return CreateOrValidateOnQnn(&qnn_model_wrapper, GetNodeUnits(), /*validate=*/true); |
| 218 | +} |
| 219 | + |
| 220 | +Status ScaleSoftmaxFusion::AddToModelBuilder( |
| 221 | + QnnModelWrapper& qnn_model_wrapper, [[maybe_unused]] const logging::Logger& logger) const { |
| 222 | + return CreateOrValidateOnQnn(&qnn_model_wrapper, GetNodeUnits(), /*validate=*/false); |
| 223 | +} |
| 224 | + |
| 225 | +} // namespace qnn |
| 226 | +} // namespace onnxruntime |