diff --git a/include/onnxruntime/core/optimizer/graph_transformer_utils.h b/include/onnxruntime/core/optimizer/graph_transformer_utils.h index 31b0f22340510..6f07ead935f4a 100644 --- a/include/onnxruntime/core/optimizer/graph_transformer_utils.h +++ b/include/onnxruntime/core/optimizer/graph_transformer_utils.h @@ -36,7 +36,8 @@ namespace optimizer_utils { TODO: This is visible for testing at the moment, but we should rather make it private. */ InlinedVector> GenerateRewriteRules( TransformerLevel level, - const InlinedHashSet& rules_to_disable = {}); + const InlinedHashSet& rules_to_disable = {}, + const bool enable_cast_chain_elimination = false); /** Given a TransformerLevel, this method generates a name for the rule-based graph transformer of that level. */ std::string GenerateRuleBasedTransformerName(TransformerLevel level); @@ -45,7 +46,8 @@ std::string GenerateRuleBasedTransformerName(TransformerLevel level); std::unique_ptr GenerateRuleBasedGraphTransformer( TransformerLevel level, const InlinedHashSet& rules_to_disable, - const InlinedHashSet& compatible_execution_providers); + const InlinedHashSet& compatible_execution_providers, + const bool enable_cast_chain_elimination = false); /** Generates all predefined (both rule-based and non-rule-based) transformers for this level. Any transformers or rewrite rules named in rules_and_transformers_to_disable will be excluded. */ diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index 5497d7c71a393..775d4c9c88dfe 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -67,6 +67,10 @@ static const char* const kOrtSessionOptionsEnableQuantQDQCleanup = "session.enab // GeluApproximation has side effects which may change the inference results. It is disabled by default due to this. static const char* const kOrtSessionOptionsEnableGeluApproximation = "optimization.enable_gelu_approximation"; +// Enable or disable Cast chain elimination in graph optimization. "0": disable; "1": enable. The default is "0". +// CastElimination with chain elimination has side effects which may change the inference results. It is disabled by default due to this. +static const char* const kOrtSessionOptionsEnableCastChainElimination = "optimization.enable_cast_chain_elimination"; + // This setting controls whether to enable AheadOfTime function inlining. // AOT function inlining examines the graph and attempts to inline as many locally defined functions in the model // as possible with the help of enabled execution providers. diff --git a/onnxruntime/core/graph/graph_utils.cc b/onnxruntime/core/graph/graph_utils.cc index 640a6c6d4232a..cc48df4444951 100644 --- a/onnxruntime/core/graph/graph_utils.cc +++ b/onnxruntime/core/graph/graph_utils.cc @@ -610,6 +610,11 @@ bool IsGraphInput(const Graph& graph, const NodeArg* input) { return std::find(graph_inputs.begin(), graph_inputs.end(), input) != graph_inputs.end(); } +bool IsGraphOutput(const Graph& graph, const NodeArg* output) { + const auto& graph_outputs = graph.GetOutputs(); + return std::find(graph_outputs.begin(), graph_outputs.end(), output) != graph_outputs.end(); +} + bool IsInitializer(const Graph& graph, const std::string& name, bool check_outer_scope) { bool is_initializer = false; const ONNX_NAMESPACE::TensorProto* initializer = nullptr; diff --git a/onnxruntime/core/graph/graph_utils.h b/onnxruntime/core/graph/graph_utils.h index 0b713196203d6..8710519cdc865 100644 --- a/onnxruntime/core/graph/graph_utils.h +++ b/onnxruntime/core/graph/graph_utils.h @@ -132,6 +132,9 @@ bool IsOutputUsed(const Node& node, int index); /** Returns true if the graph has the given input.*/ bool IsGraphInput(const Graph& graph, const NodeArg* input); +/** Returns true if the graph has the given output.*/ +bool IsGraphOutput(const Graph& graph, const NodeArg* output); + /** returns true if 'name' is an initializer in 'graph', or an ancestor graph if check_outer_scope is true. @param check_outer_scope If true and 'graph' is a subgraph, check ancestor graph/s for 'name' if not found in 'graph'. */ diff --git a/onnxruntime/core/optimizer/cast_chain_elimination.cc b/onnxruntime/core/optimizer/cast_chain_elimination.cc new file mode 100644 index 0000000000000..42ed52f1aa6ce --- /dev/null +++ b/onnxruntime/core/optimizer/cast_chain_elimination.cc @@ -0,0 +1,66 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/logging/logging.h" +#include "core/optimizer/rewrite_rule.h" +#include "core/optimizer/cast_chain_elimination.h" +#include "core/optimizer/utils.h" +#include "core/graph/graph.h" +#include "core/graph/graph_utils.h" + +namespace onnxruntime { + +Status CastChainElimination::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger&) const { + auto nextNodeIt = node.OutputNodesBegin(); + Node* next = graph.GetNode(nextNodeIt->Index()); + + // We can remove the current node. + graph_utils::RemoveNodeOutputEdges(graph, node); + + NodeArg* last_node_output_def = node.MutableOutputDefs()[0]; + const std::string& last_node_output_tensor_name = last_node_output_def->Name(); + + // Find the matching def slot, so we can wire the final node to the input of the removeable node. + int slot = -1; + + auto& inputs = next->MutableInputDefs(); + for (int i = 0, n = static_cast(inputs.size()); i < n; ++i) { + if (inputs[i]->Name() == last_node_output_tensor_name) { + slot = i; + break; + } + } + + next->MutableInputDefs()[slot] = node.MutableInputDefs()[0]; + + graph_utils::MoveAllNodeInputEdges(graph, node, *next); + + graph.RemoveNode(node.Index()); + + rule_effect = RewriteRuleEffect::kRemovedCurrentNode; + + return Status::OK(); +} + +bool CastChainElimination::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const { + if (!graph_utils::CanRemoveNode(graph, node, logger)) { + return false; + } + + // Skip nodes that don't have 1 output edge. + if (node.GetOutputEdgesCount() != 1) { + return false; + } + + const auto nextNodeIt = node.OutputNodesBegin(); + + const Node* next = graph.GetNode(nextNodeIt->Index()); + + // Skip if the next node is not of type Cast. + if (next->OpType() != "Cast") { + return false; + } + + return true; +} +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/cast_chain_elimination.h b/onnxruntime/core/optimizer/cast_chain_elimination.h new file mode 100644 index 0000000000000..f3c6478969934 --- /dev/null +++ b/onnxruntime/core/optimizer/cast_chain_elimination.h @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/optimizer/rewrite_rule.h" + +namespace onnxruntime { + +/** +@Class CastElimination +The transform that will try to find the longest chain of the type Cast where the 'to' attribute has the same data type as the input of the first Cast node in the chain. +E.g. +A ('float32') -> Cast (to='float16') -> Cast (to='int4') -> Cast (to='float32') -> Cast (to='float16') -> B +will reduce to + A ('float32') -> Cast (to='float16') -> B + +All the Cast nodes throughout the path need to have one input and one output to be considered for the fusion. +*/ +class CastChainElimination : public RewriteRule { + public: + CastChainElimination() noexcept : RewriteRule("CastChainElimination") {} + + std::vector TargetOpTypes() const noexcept override { + return {"Cast"}; + } + + private: + bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override; + + Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/cast_elimination.cc b/onnxruntime/core/optimizer/cast_elimination.cc index bbcd93472e5b0..64bbacfa22a75 100644 --- a/onnxruntime/core/optimizer/cast_elimination.cc +++ b/onnxruntime/core/optimizer/cast_elimination.cc @@ -31,4 +31,4 @@ bool CastElimination::SatisfyCondition(const Graph& graph, const Node& node, con return optimizer_utils::IsAttributeWithExpectedValue(node, "to", static_cast(input_type->tensor_type().elem_type())); } -} // namespace onnxruntime +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/optimizer/cast_elimination.h b/onnxruntime/core/optimizer/cast_elimination.h index f1b880d678767..66837f1f9aa01 100644 --- a/onnxruntime/core/optimizer/cast_elimination.h +++ b/onnxruntime/core/optimizer/cast_elimination.h @@ -28,4 +28,4 @@ class CastElimination : public RewriteRule { Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override; }; -} // namespace onnxruntime +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index b03959e4f067b..6c87adc67c04d 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -24,6 +24,7 @@ #include "core/optimizer/bias_gelu_fusion.h" #include "core/optimizer/bias_softmax_fusion.h" #include "core/optimizer/cast_elimination.h" +#include "core/optimizer/cast_chain_elimination.h" #include "core/optimizer/common_subexpression_elimination.h" #include "core/optimizer/constant_folding.h" #include "core/optimizer/constant_sharing.h" @@ -114,8 +115,10 @@ std::string GenerateRuleBasedTransformerName(TransformerLevel level) { InlinedVector> GenerateRewriteRules( TransformerLevel level, - const InlinedHashSet& rules_to_disable) { + const InlinedHashSet& rules_to_disable, + const bool enable_cast_chain_elimination) { InlinedVector> rules; + switch (level) { case TransformerLevel::Level1: rules.push_back(std::make_unique()); @@ -124,6 +127,9 @@ InlinedVector> GenerateRewriteRules( rules.push_back(std::make_unique()); rules.push_back(std::make_unique()); rules.push_back(std::make_unique()); + if (enable_cast_chain_elimination) { + rules.push_back(std::make_unique()); + } rules.push_back(std::make_unique()); rules.push_back(std::make_unique()); rules.push_back(std::make_unique()); @@ -171,8 +177,9 @@ InlinedVector> GenerateRewriteRules( std::unique_ptr GenerateRuleBasedGraphTransformer( TransformerLevel level, const InlinedHashSet& rules_to_disable, - const InlinedHashSet& compatible_execution_providers) { - auto rewrite_rules_to_register = GenerateRewriteRules(level, rules_to_disable); + const InlinedHashSet& compatible_execution_providers, + const bool enable_cast_chain_elimination) { + auto rewrite_rules_to_register = GenerateRewriteRules(level, rules_to_disable, enable_cast_chain_elimination); if (rewrite_rules_to_register.empty()) { return nullptr; } @@ -198,6 +205,8 @@ InlinedVector> GenerateTransformers( InlinedVector> transformers; const bool disable_quant_qdq = session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsDisableQuantQDQ, "0") == "1"; + const bool enable_cast_chain_elimination = + session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsEnableCastChainElimination, "0") == "1"; #ifndef DISABLE_CONTRIB_OPS const InlinedHashSet cpu_ep = {onnxruntime::kCpuExecutionProvider}; const InlinedHashSet cpu_acl_eps = {onnxruntime::kCpuExecutionProvider, @@ -211,7 +220,7 @@ InlinedVector> GenerateTransformers( // RewriteRule optimizations are the simplest (they generally remove unnecessary nodes and are cheap to run) // so run them first so there is potentially less for the more intensive optimizations like ConstantFolding, // CommonSubexpressionElimination and TransposeOptimizer to do. - auto rule_transformer = GenerateRuleBasedGraphTransformer(level, rules_and_transformers_to_disable, {}); + auto rule_transformer = GenerateRuleBasedGraphTransformer(level, rules_and_transformers_to_disable, {}, enable_cast_chain_elimination); if (rule_transformer != nullptr) { transformers.emplace_back(std::move(rule_transformer)); } @@ -265,7 +274,7 @@ InlinedVector> GenerateTransformers( } break; case TransformerLevel::Level2: { - auto rule_transformer = GenerateRuleBasedGraphTransformer(level, rules_and_transformers_to_disable, {}); + auto rule_transformer = GenerateRuleBasedGraphTransformer(level, rules_and_transformers_to_disable, {}, enable_cast_chain_elimination); if (rule_transformer != nullptr) { transformers.emplace_back(std::move(rule_transformer)); } diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 8b3f55c7df756..35d50cbec678f 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -25,6 +25,7 @@ #include "core/optimizer/bias_gelu_fusion.h" #include "core/optimizer/bias_softmax_fusion.h" #include "core/optimizer/cast_elimination.h" +#include "core/optimizer/cast_chain_elimination.h" #include "core/optimizer/common_subexpression_elimination.h" #include "core/optimizer/concat_slice_elimination.h" #include "core/optimizer/constant_folding.h" @@ -4362,7 +4363,7 @@ TEST_F(GraphTransformationTests, ExpandElimination) { ASSERT_TRUE(op_to_count["Expand"] == 3); } -TEST_F(GraphTransformationTests, CastElimination) { +TEST_F(GraphTransformationTests, CastEliminationSimple) { constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "cast_elimination.onnx"; std::shared_ptr model; ASSERT_TRUE(Model::Load(model_uri, model, nullptr, *logger_).IsOK()); @@ -4380,6 +4381,25 @@ TEST_F(GraphTransformationTests, CastElimination) { ASSERT_TRUE(op_to_count["Cast"] == 4); } +TEST_F(GraphTransformationTests, CastChainEliminationRepeatedPattern) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "cast_elimination_complex.onnx"; + + std::shared_ptr model; + ASSERT_TRUE(Model::Load(model_uri, model, nullptr, *logger_).IsOK()); + Graph& graph = model->MainGraph(); + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Cast"] == 7); + + auto rule_transformer_L1 = std::make_unique("RuleTransformer1"); + ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique())); + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + op_to_count = CountOpsInGraph(graph); + ASSERT_TRUE(op_to_count["Cast"] == 3); +} + TEST_F(GraphTransformationTests, PreShapeNodeElimination) { constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "pre_shape_node_elimination.onnx"; std::shared_ptr model; diff --git a/onnxruntime/test/testdata/transform/cast_elimination_complex.onnx b/onnxruntime/test/testdata/transform/cast_elimination_complex.onnx new file mode 100644 index 0000000000000..a76af02be7187 --- /dev/null +++ b/onnxruntime/test/testdata/transform/cast_elimination_complex.onnx @@ -0,0 +1,40 @@ + cast_chain_generator:Ô +, +XX_fp16Cast_X_to_fp16"Cast* +to +  +1 +X_fp16X_fp32Cast_X_to_fp32"Cast* +to  +, +YY_fp32Cast_Y_to_fp32"Cast* +to  +" +X_fp32 +Y_fp32t0_sumAdd"Add +* +t0_sumt1_castCast_1"Cast* +to +  ++ +t1_castt2_castCast_2"Cast* +to  ++ +t2_castt3_castCast_3"Cast* +to  ++ +t3_castt4_castCast_4"Cast* +to +  +& +t4_castZOutputIdentity"IdentityCastChainGraphZ +X +  +NZ +Y +  +Nb +Z +  + +NB \ No newline at end of file