diff --git a/include/onnxruntime/core/optimizer/graph_transformer_utils.h b/include/onnxruntime/core/optimizer/graph_transformer_utils.h index 31b0f22340510..71b603f166c1c 100644 --- a/include/onnxruntime/core/optimizer/graph_transformer_utils.h +++ b/include/onnxruntime/core/optimizer/graph_transformer_utils.h @@ -36,7 +36,8 @@ namespace optimizer_utils { TODO: This is visible for testing at the moment, but we should rather make it private. */ InlinedVector> GenerateRewriteRules( TransformerLevel level, - const InlinedHashSet& rules_to_disable = {}); + const InlinedHashSet& rules_to_disable = {}, + bool enable_cast_chain_elimination = false); /** Given a TransformerLevel, this method generates a name for the rule-based graph transformer of that level. */ std::string GenerateRuleBasedTransformerName(TransformerLevel level); diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index 5497d7c71a393..775d4c9c88dfe 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -67,6 +67,10 @@ static const char* const kOrtSessionOptionsEnableQuantQDQCleanup = "session.enab // GeluApproximation has side effects which may change the inference results. It is disabled by default due to this. static const char* const kOrtSessionOptionsEnableGeluApproximation = "optimization.enable_gelu_approximation"; +// Enable or disable Cast chain elimination in graph optimization. "0": disable; "1": enable. The default is "0". +// CastElimination with chain elimination has side effects which may change the inference results. It is disabled by default due to this. +static const char* const kOrtSessionOptionsEnableCastChainElimination = "optimization.enable_cast_chain_elimination"; + // This setting controls whether to enable AheadOfTime function inlining. // AOT function inlining examines the graph and attempts to inline as many locally defined functions in the model // as possible with the help of enabled execution providers. diff --git a/onnxruntime/core/optimizer/cast_elimination.cc b/onnxruntime/core/optimizer/cast_elimination.cc index bbcd93472e5b0..7830a40f9a634 100644 --- a/onnxruntime/core/optimizer/cast_elimination.cc +++ b/onnxruntime/core/optimizer/cast_elimination.cc @@ -11,8 +11,94 @@ namespace onnxruntime { Status CastElimination::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger&) const { - if (graph_utils::RemoveNode(graph, node)) { + const auto* input_type = node.InputDefs()[0]->TypeAsProto(); + if (input_type == nullptr || !input_type->tensor_type().has_elem_type()) { + return Status::OK(); + } + + // Check if we can immediately remove a very common case (casting to the same type as the input). + if (optimizer_utils::IsAttributeWithExpectedValue(node, "to", static_cast(input_type->tensor_type().elem_type()))) { + graph_utils::RemoveNode(graph, node); rule_effect = RewriteRuleEffect::kRemovedCurrentNode; + + return Status::OK(); + } + + // Check if we can continue + if (!enable_chain_elimination_) { + return Status::OK(); + } + + // If not, find the longest chain that casts to the input type, if it exists. + Node* current = &node; + Node* final_non_cast_node = &node; + int matching_elem_type = input_type->tensor_type().elem_type(); + + while (current->OpType() == "Cast") { + const auto& to_attr = current->GetAttributes().at("to"); + + // A rare case when the Cast node output is branching out. + // We don't really want to deal with this complexity, hence we will skip it. + if (current->GetOutputEdgesCount() > 1) { + return Status::OK(); + } + + auto it = current->OutputNodesBegin(); + if (it == current->OutputNodesEnd()) { + break; + } + current = const_cast(&(*it)); + + // We found the repeating pattern. + if (to_attr.i() == matching_elem_type) { + final_non_cast_node = current; + } + } + + // No repeating pattern was found. + if (node.Index() == final_non_cast_node->Index()) { + return Status::OK(); + } + + std::vector to_remove; + current = &node; + + // Collect nodes for removal. + while (current != final_non_cast_node && current->OpType() == "Cast") { + to_remove.push_back(current); + auto it = current->OutputNodesBegin(); + if (it == current->OutputNodesEnd()) + break; + current = const_cast(&*it); + } + + rule_effect = RewriteRuleEffect::kRemovedCurrentNode; + + // First remove all outbound edges. + for (Node* n : to_remove) { + graph_utils::RemoveNodeOutputEdges(graph, *n); + } + + NodeArg* last_node_output_def = to_remove.back()->MutableOutputDefs()[0]; + const std::string& last_node_output_tensor_name = last_node_output_def->Name(); + + // Find the matching def slot, so we can wire the final node to the input of the first removeable node. + int slot = -1; + auto& inputs = final_non_cast_node->MutableInputDefs(); + for (int i = 0, n = static_cast(inputs.size()); i < n; ++i) { + if (inputs[i]->Name() == last_node_output_tensor_name) { + slot = i; + break; + } + } + + final_non_cast_node->MutableInputDefs()[slot] = to_remove[0]->MutableInputDefs()[0]; + + graph_utils::MoveAllNodeInputEdges(graph, *to_remove[0], *final_non_cast_node); + + // Finally, remove the nodes itself. + for (Node* n : to_remove) { + graph.RemoveNode(n->Index()); } return Status::OK(); @@ -22,13 +108,7 @@ bool CastElimination::SatisfyCondition(const Graph& graph, const Node& node, con if (!graph_utils::CanRemoveNode(graph, node, logger)) { return false; } - - const auto* input_type = node.InputDefs()[0]->TypeAsProto(); - if (input_type == nullptr || !input_type->tensor_type().has_elem_type()) { - return false; - } - - return optimizer_utils::IsAttributeWithExpectedValue(node, "to", static_cast(input_type->tensor_type().elem_type())); + return true; } } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/cast_elimination.h b/onnxruntime/core/optimizer/cast_elimination.h index f1b880d678767..9412ad934a083 100644 --- a/onnxruntime/core/optimizer/cast_elimination.h +++ b/onnxruntime/core/optimizer/cast_elimination.h @@ -11,12 +11,22 @@ namespace onnxruntime { @Class CastElimination Rewrite rule that eliminates Cast nodes if 'to' attribute has same data type as input tensor data type. +Additionally, it will try to find the longest chain where the 'to' attribute has the same data type as the input of the first Cast node in the chain. +E.g. +A ('float32') -> Cast (to='float16') -> Cast (to='int4') -> Cast (to='float32') -> Cast (to='float16') -> B +will reduce to + A ('float32') -> Cast (to='float16') -> B + +All the Cast nodes throughout the path need to have one input and one output to be considered for the fusion. It is attempted to be triggered only on nodes with op type "Cast". */ class CastElimination : public RewriteRule { + private: + const bool enable_chain_elimination_; + public: - CastElimination() noexcept : RewriteRule("CastElimination") {} + CastElimination(bool enable_chain_elimination = false) noexcept : RewriteRule("CastElimination"), enable_chain_elimination_(enable_chain_elimination) {} std::vector TargetOpTypes() const noexcept override { return {"Cast"}; diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index b03959e4f067b..ec3c89c31a336 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -114,7 +114,8 @@ std::string GenerateRuleBasedTransformerName(TransformerLevel level) { InlinedVector> GenerateRewriteRules( TransformerLevel level, - const InlinedHashSet& rules_to_disable) { + const InlinedHashSet& rules_to_disable, + bool enable_cast_chain_elimination) { InlinedVector> rules; switch (level) { case TransformerLevel::Level1: @@ -123,7 +124,7 @@ InlinedVector> GenerateRewriteRules( rules.push_back(std::make_unique()); rules.push_back(std::make_unique()); rules.push_back(std::make_unique()); - rules.push_back(std::make_unique()); + rules.push_back(std::make_unique(enable_cast_chain_elimination)); rules.push_back(std::make_unique()); rules.push_back(std::make_unique()); rules.push_back(std::make_unique()); @@ -171,8 +172,11 @@ InlinedVector> GenerateRewriteRules( std::unique_ptr GenerateRuleBasedGraphTransformer( TransformerLevel level, const InlinedHashSet& rules_to_disable, - const InlinedHashSet& compatible_execution_providers) { - auto rewrite_rules_to_register = GenerateRewriteRules(level, rules_to_disable); + const InlinedHashSet& compatible_execution_providers, + const SessionOptions& session_options) { + const bool enable_cast_chain_elimination = session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsEnableCastChainElimination, "0") == "1"; + + auto rewrite_rules_to_register = GenerateRewriteRules(level, rules_to_disable, enable_cast_chain_elimination); if (rewrite_rules_to_register.empty()) { return nullptr; } @@ -211,7 +215,7 @@ InlinedVector> GenerateTransformers( // RewriteRule optimizations are the simplest (they generally remove unnecessary nodes and are cheap to run) // so run them first so there is potentially less for the more intensive optimizations like ConstantFolding, // CommonSubexpressionElimination and TransposeOptimizer to do. - auto rule_transformer = GenerateRuleBasedGraphTransformer(level, rules_and_transformers_to_disable, {}); + auto rule_transformer = GenerateRuleBasedGraphTransformer(level, rules_and_transformers_to_disable, {}, session_options); if (rule_transformer != nullptr) { transformers.emplace_back(std::move(rule_transformer)); } @@ -265,7 +269,7 @@ InlinedVector> GenerateTransformers( } break; case TransformerLevel::Level2: { - auto rule_transformer = GenerateRuleBasedGraphTransformer(level, rules_and_transformers_to_disable, {}); + auto rule_transformer = GenerateRuleBasedGraphTransformer(level, rules_and_transformers_to_disable, {}, session_options); if (rule_transformer != nullptr) { transformers.emplace_back(std::move(rule_transformer)); } diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 8b3f55c7df756..367934f406428 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -4362,7 +4362,7 @@ TEST_F(GraphTransformationTests, ExpandElimination) { ASSERT_TRUE(op_to_count["Expand"] == 3); } -TEST_F(GraphTransformationTests, CastElimination) { +TEST_F(GraphTransformationTests, CastEliminationSimple) { constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "cast_elimination.onnx"; std::shared_ptr model; ASSERT_TRUE(Model::Load(model_uri, model, nullptr, *logger_).IsOK()); @@ -4380,6 +4380,25 @@ TEST_F(GraphTransformationTests, CastElimination) { ASSERT_TRUE(op_to_count["Cast"] == 4); } +TEST_F(GraphTransformationTests, CastEliminationRepeatedPattern) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "cast_elimination_complex.onnx"; + + std::shared_ptr model; + ASSERT_TRUE(Model::Load(model_uri, model, nullptr, *logger_).IsOK()); + Graph& graph = model->MainGraph(); + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Cast"] == 7); + + auto rule_transformer_L1 = std::make_unique("RuleTransformer1"); + ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique(true))); + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Cast"] == 1); +} + TEST_F(GraphTransformationTests, PreShapeNodeElimination) { constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "pre_shape_node_elimination.onnx"; std::shared_ptr model; diff --git a/onnxruntime/test/testdata/transform/cast_elimination_complex.onnx b/onnxruntime/test/testdata/transform/cast_elimination_complex.onnx new file mode 100644 index 0000000000000..a76af02be7187 --- /dev/null +++ b/onnxruntime/test/testdata/transform/cast_elimination_complex.onnx @@ -0,0 +1,40 @@ + cast_chain_generator:Ô +, +XX_fp16Cast_X_to_fp16"Cast* +to +  +1 +X_fp16X_fp32Cast_X_to_fp32"Cast* +to  +, +YY_fp32Cast_Y_to_fp32"Cast* +to  +" +X_fp32 +Y_fp32t0_sumAdd"Add +* +t0_sumt1_castCast_1"Cast* +to +  ++ +t1_castt2_castCast_2"Cast* +to  ++ +t2_castt3_castCast_3"Cast* +to  ++ +t3_castt4_castCast_4"Cast* +to +  +& +t4_castZOutputIdentity"IdentityCastChainGraphZ +X +  +NZ +Y +  +Nb +Z +  + +NB \ No newline at end of file