Cast Nodes Fusion #24842

Open · wants to merge 9 commits into base: main
Changes from 8 commits

91 changes: 83 additions & 8 deletions onnxruntime/core/optimizer/cast_elimination.cc
@@ -11,8 +11,89 @@
namespace onnxruntime {

Status CastElimination::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger&) const {
  const auto* input_type = node.InputDefs()[0]->TypeAsProto();
  if (input_type == nullptr || !input_type->tensor_type().has_elem_type()) {
    return Status::OK();
  }

  // Fast path for the most common case: the Cast target matches the input type, so the node is a no-op.
  if (optimizer_utils::IsAttributeWithExpectedValue(node, "to", static_cast<int64_t>(input_type->tensor_type().elem_type()))) {
    if (graph_utils::RemoveNode(graph, node)) {
      rule_effect = RewriteRuleEffect::kRemovedCurrentNode;
    }
    return Status::OK();
  }

  // Otherwise, walk the straight-line chain of Cast nodes and look for the last one
  // whose 'to' type matches the input type of this node, if such a chain exists.
  Node* current = &node;
  Node* final_non_cast_node = &node;  // First node after the last type-restoring Cast (may itself be a Cast).
  const int matching_elem_type = input_type->tensor_type().elem_type();

  while (current->OpType() == "Cast") {
    const auto& to_attr = current->GetAttributes().at("to");

    // Skip the rare case where a Cast output branches out to multiple consumers;
    // rewiring several downstream nodes is not worth the added complexity here.
    if (current->GetOutputEdgesCount() > 1) {
      return Status::OK();
    }

    auto it = current->OutputNodesBegin();
    if (it == current->OutputNodesEnd()) {
      break;
    }
    current = const_cast<Node*>(&(*it));

    // This Cast restores the original element type, so the chain up to here is removable.
    if (to_attr.i() == matching_elem_type) {
      final_non_cast_node = current;
    }
  }

  // No removable pattern was found.
  if (node.Index() == final_non_cast_node->Index()) {
    return Status::OK();
  }

  std::vector<Node*> to_remove;
  current = &node;

  // Collect the Cast nodes to be removed.
  while (current != final_non_cast_node && current->OpType() == "Cast") {
    to_remove.push_back(current);
    auto it = current->OutputNodesBegin();
    if (it == current->OutputNodesEnd())
      break;
    current = const_cast<Node*>(&*it);
  }

  NodeArg* last_node_output_def = to_remove.back()->MutableOutputDefs()[0];
  const std::string& last_node_output_tensor_name = last_node_output_def->Name();

  // Find the input def slot on the final node that consumes the last removable Cast's
  // output, so we can wire the final node to the input of the first removable node.
  int slot = -1;
  auto& inputs = final_non_cast_node->MutableInputDefs();
  for (int i = 0, n = static_cast<int>(inputs.size()); i < n; ++i) {
    if (inputs[i]->Name() == last_node_output_tensor_name) {
      slot = i;
      break;
    }
  }

  // Bail out before mutating the graph if the expected slot is not found.
  if (slot < 0) {
    return Status::OK();
  }

  // First remove all outbound edges of the nodes being deleted.
  for (Node* n : to_remove) {
    graph_utils::RemoveNodeOutputEdges(graph, *n);
  }

  final_non_cast_node->MutableInputDefs()[slot] = to_remove[0]->MutableInputDefs()[0];

  graph_utils::MoveAllNodeInputEdges(graph, *to_remove[0], *final_non_cast_node);

  // Finally, remove the nodes themselves.
  for (Node* n : to_remove) {
    graph.RemoveNode(n->Index());
  }

  rule_effect = RewriteRuleEffect::kRemovedCurrentNode;

  return Status::OK();
@@ -22,13 +22,7 @@ bool CastElimination::SatisfyCondition(const Graph& graph, const Node& node, con
  if (!graph_utils::CanRemoveNode(graph, node, logger)) {
    return false;
  }

  return true;
}

} // namespace onnxruntime
7 changes: 7 additions & 0 deletions onnxruntime/core/optimizer/cast_elimination.h
@@ -11,6 +11,13 @@ namespace onnxruntime {
@Class CastElimination

Rewrite rule that eliminates a Cast node when its 'to' attribute has the same data type as its input tensor.
Additionally, it searches for the longest chain of Cast nodes in which the last matching 'to' attribute restores the data type of the first Cast node's input, and collapses the chain up to that point.
E.g.
A ('float32') -> Cast (to='float16') -> Cast (to='int4') -> Cast (to='float32') -> Cast (to='float16') -> B
will reduce to
A ('float32') -> Cast (to='float16') -> B

Every Cast node along the path must have exactly one input and one output (no branching) to be eligible for this fusion.

The rule is attempted only on nodes with op type "Cast".
*/
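Note: to make the chain rule described in the header comment concrete, here is a minimal, self-contained sketch of the matching logic. It is illustrative only, not ONNX Runtime code: RedundantCastPrefix is a hypothetical name, and the integer ids stand in for ONNX TensorProto element types. Given the element type feeding the first Cast and the 'to' targets of a straight-line Cast chain, it returns how many leading Casts are redundant, i.e. everything up to and including the last Cast that restores the original input type.

#include <cstddef>
#include <iostream>
#include <vector>

// Returns the number of leading Casts that can be deleted from a straight-line
// chain: all nodes up to and including the last Cast whose 'to' target equals
// the element type flowing into the first Cast.
std::size_t RedundantCastPrefix(int input_type, const std::vector<int>& to_types) {
  std::size_t removable = 0;
  for (std::size_t i = 0; i < to_types.size(); ++i) {
    if (to_types[i] == input_type) {
      removable = i + 1;  // Casts [0, i] round-trip back to input_type.
    }
  }
  return removable;
}

int main() {
  // Illustrative ONNX TensorProto ids: FLOAT=1, FLOAT16=10, INT4=22.
  // A ('float32') -> Cast(fp16) -> Cast(int4) -> Cast(fp32) -> Cast(fp16) -> B
  // collapses to A -> Cast(fp16) -> B: the first three Casts are redundant.
  std::cout << RedundantCastPrefix(1, {10, 22, 1, 10}) << "\n";  // prints 3
}
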
21 changes: 20 additions & 1 deletion onnxruntime/test/optimizer/graph_transform_test.cc
@@ -4362,7 +4362,7 @@ TEST_F(GraphTransformationTests, ExpandElimination) {
  ASSERT_TRUE(op_to_count["Expand"] == 3);
}

TEST_F(GraphTransformationTests, CastEliminationSimple) {
  constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "cast_elimination.onnx";
  std::shared_ptr<Model> model;
  ASSERT_TRUE(Model::Load(model_uri, model, nullptr, *logger_).IsOK());
@@ -4380,6 +4380,25 @@ TEST_F(GraphTransformationTests, CastElimination) {
  ASSERT_TRUE(op_to_count["Cast"] == 4);
}

TEST_F(GraphTransformationTests, CastEliminationRepeatedPattern) {
  constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "cast_elimination_complex.onnx";

  std::shared_ptr<Model> model;
  ASSERT_TRUE(Model::Load(model_uri, model, nullptr, *logger_).IsOK());
  Graph& graph = model->MainGraph();
  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
  ASSERT_TRUE(op_to_count["Cast"] == 7);

  auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformer1");
  ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique<CastElimination>()));
  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
  ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1));
  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));

  op_to_count = CountOpsInGraph(graph);
  ASSERT_TRUE(op_to_count["Cast"] == 1);
}

TEST_F(GraphTransformationTests, PreShapeNodeElimination) {
  constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "pre_shape_node_elimination.onnx";
  std::shared_ptr<Model> model;
40 changes: 40 additions & 0 deletions onnxruntime/test/testdata/transform/cast_elimination_complex.onnx
@@ -0,0 +1,40 @@
(Binary ONNX protobuf, not human-readable in this view. From the embedded strings, the generated model `CastChainGraph` takes graph inputs X and Y; X flows through Cast_X_to_fp16 and then Cast_X_to_fp32, Y through Cast_Y_to_fp32; the two results are summed by an Add into t0_sum, which passes through the chain Cast_1 -> Cast_2 -> Cast_3 -> Cast_4 (producing t1_cast..t4_cast) and reaches output Z via an Identity node. That is seven Cast nodes in total, of which the new test expects exactly one to remain.)
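For reference, here is a hypothetical sketch of how an equivalent graph could be written with the ModelTestBuilder helpers used elsewhere in graph_transform_test.cc. It only illustrates the model's topology; the helper signatures and the exact 'to' targets of Cast_2 and Cast_3 are assumptions, and this is not the script that generated cast_elimination_complex.onnx.

// Hypothetical topology sketch of cast_elimination_complex.onnx
// (ModelTestBuilder usage assumed from neighboring optimizer tests).
auto build_cast_chain = [](ModelTestBuilder& builder) {
  auto* x = builder.MakeInput<float>({1, 4}, -1.0f, 1.0f);
  auto* y = builder.MakeInput<float>({1, 4}, -1.0f, 1.0f);
  auto* x_fp16 = builder.MakeIntermediate();
  auto* x_fp32 = builder.MakeIntermediate();
  auto* y_fp32 = builder.MakeIntermediate();
  auto* sum = builder.MakeIntermediate();
  auto* t1 = builder.MakeIntermediate();
  auto* t2 = builder.MakeIntermediate();
  auto* t3 = builder.MakeIntermediate();
  auto* t4 = builder.MakeIntermediate();
  auto* z = builder.MakeOutput();

  const int64_t fp16 = ONNX_NAMESPACE::TensorProto_DataType_FLOAT16;
  const int64_t fp32 = ONNX_NAMESPACE::TensorProto_DataType_FLOAT;
  const int64_t int4 = ONNX_NAMESPACE::TensorProto_DataType_INT4;

  // X round-trips fp32 -> fp16 -> fp32 (both Casts removable);
  // Y gets a same-type Cast (removable via the fast path).
  builder.AddNode("Cast", {x}, {x_fp16}).AddAttribute("to", fp16);
  builder.AddNode("Cast", {x_fp16}, {x_fp32}).AddAttribute("to", fp32);
  builder.AddNode("Cast", {y}, {y_fp32}).AddAttribute("to", fp32);
  builder.AddNode("Add", {x_fp32, y_fp32}, {sum});

  // Cast_1..Cast_4 mirror the chain from the header example; only the final
  // Cast (to fp16) should survive, leaving 1 of the graph's 7 Casts.
  builder.AddNode("Cast", {sum}, {t1}).AddAttribute("to", fp16);
  builder.AddNode("Cast", {t1}, {t2}).AddAttribute("to", int4);
  builder.AddNode("Cast", {t2}, {t3}).AddAttribute("to", fp32);
  builder.AddNode("Cast", {t3}, {t4}).AddAttribute("to", fp16);
  builder.AddNode("Identity", {t4}, {z});
};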