|
| 1 | +// Copyright (c) Microsoft Corporation. All rights reserved. |
| 2 | +// Licensed under the MIT License. |
| 3 | + |
#include "core/providers/qnn/builder/qnn_node_group/scale_softmax_fusion.h"

#include <gsl/gsl>

#include <array>
#include <cassert>
#include <cstring>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "core/providers/qnn/builder/op_builder_factory.h"
#include "core/providers/qnn/builder/opbuilder/base_op_builder.h"
#include "core/providers/qnn/builder/qnn_model_wrapper.h"
#include "core/providers/qnn/builder/qnn_node_group/utils.h"
#include "core/providers/qnn/builder/qnn_utils.h"
| 19 | + |
| 20 | +namespace onnxruntime { |
| 21 | +namespace qnn { |
| 22 | +namespace { |
| 23 | + |
// ONNX op type names matched by this fusion pass.
constexpr char kOpMul[] = "Mul";
constexpr char kOpSoftmax[] = "Softmax";
| 26 | + |
| 27 | +/// @brief Get the index of the scalar input in the mul node |
| 28 | +/// @param mul Multiply node unit |
| 29 | +/// @return The index of the scalar input (0 or 1) if found, otherwise std::nullopt |
| 30 | +std::optional<size_t> GetMulScalarInputIndex(const NodeUnit* mul) { |
| 31 | + const NodeArg* mul_x = mul->GetNode().InputDefs()[0]; |
| 32 | + const NodeArg* mul_y = mul->GetNode().InputDefs()[1]; |
| 33 | + auto mul_x_shape = utils::GetTensorProtoShape(*mul_x->Shape()); |
| 34 | + auto mul_y_shape = utils::GetTensorProtoShape(*mul_y->Shape()); |
| 35 | + bool is_x_scalar = mul_x_shape.NumDimensions() == 0; |
| 36 | + bool is_y_scalar = mul_y_shape.NumDimensions() == 0; |
| 37 | + if (!is_x_scalar && !is_y_scalar) { |
| 38 | + return std::nullopt; |
| 39 | + } |
| 40 | + return is_y_scalar ? 1 : 0; |
| 41 | +} |
| 42 | + |
| 43 | +/// @brief Get the axis for softmax |
| 44 | +/// @param node_units The node units containing the softmax and mul nodes |
| 45 | +/// @return The axis for softmax |
| 46 | +std::optional<uint32_t> GetPositiveSoftmaxAxis(const NodeUnit* mul, const NodeUnit* softmax) { |
| 47 | + NodeAttrHelper softmax_attr_helper(softmax->GetNode()); |
| 48 | + std::optional<int64_t> param_axis = softmax_attr_helper.GetInt64(QNN_OP_SOFTMAX_PARAM_AXIS); |
| 49 | + if (!param_axis.has_value()) { |
| 50 | + return std::nullopt; |
| 51 | + } |
| 52 | + int64_t axis_value = param_axis.value(); |
| 53 | + if (axis_value < 0) { |
| 54 | + size_t input_scale_index = GetMulScalarInputIndex(mul).value(); |
| 55 | + size_t input_other_index = 1 - input_scale_index; |
| 56 | + int rank = mul->GetNode().InputDefs()[input_other_index]->Shape()->dim_size(); |
| 57 | + axis_value += static_cast<int64_t>(rank); |
| 58 | + } |
| 59 | + return static_cast<uint32_t>(axis_value); |
| 60 | +} |
| 61 | + |
| 62 | +/// @brief Identify scalar input from mul node if present |
| 63 | +/// @param mul Multiply node unit |
| 64 | +/// @return The scalar input float value if found, otherwise std::nullopt |
| 65 | +std::optional<float> ExtractScalarValueFromMul(const GraphViewer& graph_viewer, const NodeUnit* mul) { |
| 66 | + std::optional<size_t> input_scale_index = GetMulScalarInputIndex(mul); |
| 67 | + if (!input_scale_index.has_value()) { |
| 68 | + return std::nullopt; |
| 69 | + } |
| 70 | + const NodeArg* scalar_arg = mul->GetNode().InputDefs()[input_scale_index.value()]; |
| 71 | + if (!graph_viewer.IsConstantInitializer(scalar_arg->Name(), true)) { |
| 72 | + return std::nullopt; |
| 73 | + } |
| 74 | + const auto* scalar_tensor = graph_viewer.GetConstantInitializer(scalar_arg->Name()); |
| 75 | + if (!scalar_tensor) { |
| 76 | + return std::nullopt; |
| 77 | + } |
| 78 | + if (scalar_tensor->data_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { |
| 79 | + return std::nullopt; |
| 80 | + } |
| 81 | + const auto& raw_data = scalar_tensor->raw_data(); |
| 82 | + if (raw_data.size() != sizeof(float) || reinterpret_cast<uintptr_t>(raw_data.data()) % alignof(float) != 0) { |
| 83 | + return std::nullopt; |
| 84 | + } |
| 85 | + return *reinterpret_cast<const float*>(raw_data.data()); |
| 86 | +} |
| 87 | + |
| 88 | +/// @brief Create or validate the QNN node |
| 89 | +/// @param qnn_model_wrapper QNN model wrapper |
| 90 | +/// @param node_units The node units containing the softmax and mul nodes |
| 91 | +/// @param validate Whether to validate the QNN node |
| 92 | +/// @return Status |
| 93 | +Status CreateOrValidateOnQnn( |
| 94 | + QnnModelWrapper* qnn_model_wrapper, |
| 95 | + std::array<const NodeUnit*, 2> node_units, |
| 96 | + bool validate) { |
| 97 | + const NodeUnit* mul = node_units[0]; |
| 98 | + const NodeUnit* softmax = node_units[1]; |
| 99 | + ORT_RETURN_IF_NOT(mul->OpType() == kOpMul, |
| 100 | + "Expected scale node to be of type Mul, got ", mul->OpType()); |
| 101 | + ORT_RETURN_IF_NOT(softmax->OpType() == kOpSoftmax, |
| 102 | + "Expected softmax node to be of type Softmax, got ", softmax->OpType()); |
| 103 | + size_t input_scale_index = GetMulScalarInputIndex(mul).value(); |
| 104 | + size_t input_other_index = 1 - input_scale_index; |
| 105 | + const NodeUnitIODef& mul_input_other = mul->Inputs()[input_other_index]; |
| 106 | + const NodeUnitIODef& softmax_output = softmax->Outputs()[0]; |
| 107 | + |
| 108 | + std::vector<std::string> param_tensor_names; |
| 109 | + { // axis |
| 110 | + std::optional<uint32_t> axis = GetPositiveSoftmaxAxis(mul, softmax); |
| 111 | + if (axis.has_value()) { |
| 112 | + Qnn_Scalar_t axis_scalar = QNN_SCALAR_INIT; |
| 113 | + axis_scalar.dataType = QNN_DATATYPE_UINT_32; |
| 114 | + axis_scalar.uint32Value = axis.value(); |
| 115 | + QnnParamWrapper param_wrapper(softmax->Index(), |
| 116 | + softmax->Name(), |
| 117 | + QNN_OP_SOFTMAX_PARAM_AXIS, |
| 118 | + axis_scalar); |
| 119 | + ORT_RETURN_IF_NOT(qnn_model_wrapper->AddParamWrapper(std::move(param_wrapper)), "Failed to add param"); |
| 120 | + param_tensor_names.push_back(param_wrapper.GetParamTensorName()); |
| 121 | + } |
| 122 | + } |
| 123 | + { // beta |
| 124 | + NodeAttrHelper softmax_attr_helper(softmax->GetNode()); |
| 125 | + std::optional<float> beta = softmax_attr_helper.GetFloat(QNN_OP_SOFTMAX_PARAM_BETA); |
| 126 | + float scale = ExtractScalarValueFromMul(qnn_model_wrapper->GetGraphViewer(), mul).value_or(1.0f); |
| 127 | + Qnn_Scalar_t beta_scalar = QNN_SCALAR_INIT; |
| 128 | + beta_scalar.dataType = QNN_DATATYPE_FLOAT_32; |
| 129 | + beta_scalar.floatValue = scale * beta.value_or(1.0f); |
| 130 | + QnnParamWrapper param_wrapper(softmax->Index(), |
| 131 | + softmax->Name(), |
| 132 | + QNN_OP_SOFTMAX_PARAM_BETA, |
| 133 | + beta_scalar); |
| 134 | + ORT_RETURN_IF_NOT(qnn_model_wrapper->AddParamWrapper(std::move(param_wrapper)), "Failed to add param"); |
| 135 | + param_tensor_names.push_back(param_wrapper.GetParamTensorName()); |
| 136 | + } |
| 137 | + |
| 138 | + QnnTensorWrapper fused_softmax_input; |
| 139 | + QnnTensorWrapper fused_softmax_output; |
| 140 | + ORT_RETURN_IF_ERROR(qnn_model_wrapper->MakeTensorWrapper(mul_input_other, fused_softmax_input)); |
| 141 | + ORT_RETURN_IF_ERROR(qnn_model_wrapper->MakeTensorWrapper(softmax_output, fused_softmax_output)); |
| 142 | + |
| 143 | + if (validate) { |
| 144 | + ORT_RETURN_IF_ERROR(qnn_model_wrapper->ValidateQnnNode(softmax->Name(), |
| 145 | + QNN_OP_PACKAGE_NAME_QTI_AISW, |
| 146 | + QNN_OP_SOFTMAX, |
| 147 | + {fused_softmax_input.GetQnnTensor()}, |
| 148 | + {fused_softmax_output.GetQnnTensor()}, |
| 149 | + {})); |
| 150 | + } else { |
| 151 | + ORT_RETURN_IF_NOT(qnn_model_wrapper->AddTensorWrapper(std::move(fused_softmax_input)), "Failed to add input"); |
| 152 | + ORT_RETURN_IF_NOT(qnn_model_wrapper->AddTensorWrapper(std::move(fused_softmax_output)), "Failed to add output"); |
| 153 | + ORT_RETURN_IF_NOT(qnn_model_wrapper->CreateQnnNode(softmax->Name(), |
| 154 | + QNN_OP_PACKAGE_NAME_QTI_AISW, |
| 155 | + QNN_OP_SOFTMAX, |
| 156 | + {mul_input_other.node_arg.Name()}, |
| 157 | + {softmax_output.node_arg.Name()}, |
| 158 | + std::move(param_tensor_names), |
| 159 | + validate), |
| 160 | + "Failed to add fused " + std::string(kOpSoftmax) + " node."); |
| 161 | + } |
| 162 | + return Status::OK(); |
| 163 | +} |
| 164 | + |
| 165 | +} // namespace |
| 166 | + |
| 167 | +std::unique_ptr<IQnnNodeGroup> ScaleSoftmaxFusion::TryFusion( |
| 168 | + QnnModelWrapper& qnn_model_wrapper, |
| 169 | + const NodeUnit& mul_node_unit, |
| 170 | + const std::unordered_map<const Node*, const NodeUnit*>& node_to_node_unit, |
| 171 | + const std::unordered_map<const NodeUnit*, const IQnnNodeGroup*>& node_unit_to_qnn_node_group, |
| 172 | + [[maybe_unused]] const logging::Logger& logger) { |
| 173 | + if (mul_node_unit.OpType() != kOpMul || mul_node_unit.UnitType() != NodeUnit::Type::SingleNode) { |
| 174 | + return nullptr; |
| 175 | + } |
| 176 | + // Check if the mul node has a scalar input that can fold into the softmax's beta |
| 177 | + const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer(); |
| 178 | + std::optional<float> scalar = ExtractScalarValueFromMul(graph_viewer, &mul_node_unit); |
| 179 | + if (!scalar.has_value()) { |
| 180 | + return nullptr; |
| 181 | + } |
| 182 | + |
| 183 | + // Mul node must have a single Softmax node as child |
| 184 | + const std::array<std::string_view, 1> child_op_types{kOpSoftmax}; |
| 185 | + const NodeUnit* softmax = GetOnlyChildOfType(graph_viewer, mul_node_unit, child_op_types, |
| 186 | + node_to_node_unit, node_unit_to_qnn_node_group); |
| 187 | + if (softmax == nullptr) { |
| 188 | + return nullptr; |
| 189 | + } |
| 190 | + |
| 191 | + std::array<const NodeUnit*, 2> node_units{&mul_node_unit, softmax}; |
| 192 | + if (CreateOrValidateOnQnn(&qnn_model_wrapper, node_units, /*validate=*/true) != Status::OK()) { |
| 193 | + return nullptr; |
| 194 | + } |
| 195 | + |
| 196 | + return std::make_unique<ScaleSoftmaxFusion>(node_units); |
| 197 | +} |
| 198 | + |
/// @brief Check whether QNN accepts the fused Softmax node.
/// Delegates to CreateOrValidateOnQnn in validate-only mode (no graph mutation).
Status ScaleSoftmaxFusion::IsSupported(
    QnnModelWrapper& qnn_model_wrapper, [[maybe_unused]] const logging::Logger& logger) const {
  return CreateOrValidateOnQnn(&qnn_model_wrapper, node_units_, /*validate=*/true);
}
| 203 | + |
/// @brief Add the fused Softmax node (scale folded into beta) to the QNN model.
/// Delegates to CreateOrValidateOnQnn in creation mode.
Status ScaleSoftmaxFusion::AddToModelBuilder(
    QnnModelWrapper& qnn_model_wrapper, [[maybe_unused]] const logging::Logger& logger) const {
  return CreateOrValidateOnQnn(&qnn_model_wrapper, node_units_, /*validate=*/false);
}
| 208 | + |
| 209 | +} // namespace qnn |
| 210 | +} // namespace onnxruntime |
0 commit comments