Cast Nodes Fusion #24842

Open: nenad1002 wants to merge 9 commits into base: main

Conversation

nenad1002 (Contributor) commented May 22, 2025

Description

We might have a case where multiple Cast nodes in a chain cast back to the original type. This fusion removes the extra nodes.
E.g.
A ('float32') -> Cast (to='float16') -> Cast (to='int4') -> Cast (to='float32') -> Cast (to='float16') -> B
will reduce to
A ('float32') -> Cast (to='float16') -> B
Every Cast node along the path must have exactly one input and one output to be considered for the fusion.
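
The reduction in the example above can be sketched on a plain list of Cast target types. This is only one reading of the example, not the PR's implementation: `fuse_cast_chain` is a hypothetical helper that drops every Cast up to and including the last one that returns to the source type, and it assumes the chain is strictly linear (one input and one output per Cast).

```python
from typing import List


def fuse_cast_chain(source_type: str, cast_targets: List[str]) -> List[str]:
    """Return the Cast targets that remain after dropping the round-trip prefix.

    Hypothetical helper: everything up to and including the last Cast whose
    target equals `source_type` is treated as redundant.
    """
    last_roundtrip = -1
    for i, target in enumerate(cast_targets):
        if target == source_type:
            last_roundtrip = i
    return cast_targets[last_roundtrip + 1:]


# The chain from the description: float32 -> float16 -> int4 -> float32 -> float16
chain = ["float16", "int4", "float32", "float16"]
print(fuse_cast_chain("float32", chain))
```

Note that this sketch treats a round trip through a narrower type (here `int4`) as removable, which changes the numerical result compared to the unfused chain.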

Motivation and Context

Gemma3 ONNX models used to have double casting, and many new models created by the model builder might have it as well. Extra Casts might reduce accuracy and increase inference time.
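
As a small illustration of why a Cast pair is not numerically free, the sketch below round-trips a value through IEEE 754 half precision using only the standard library. It is an assumption-laden stand-in: Python floats are 64-bit rather than float32, and `roundtrip_float16` is a hypothetical helper, not part of onnxruntime.

```python
import struct


def roundtrip_float16(x: float) -> float:
    """Cast a Python float to IEEE 754 half precision and back."""
    # The "e" format character packs a value as a 16-bit float.
    return struct.unpack("<e", struct.pack("<e", x))[0]


x = 0.1
y = roundtrip_float16(x)
# The round-tripped value differs slightly from the original.
print(x, y, abs(x - y))
```

Values that are exactly representable in half precision, such as 0.5, survive the round trip unchanged; most others do not.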

@nenad1002 nenad1002 changed the title [DO NOT REVIEW YET] Fusion Cast [DO NOT REVIEW YET] Draft PR - Fusion Cast May 22, 2025
@nenad1002 nenad1002 changed the title [DO NOT REVIEW YET] Draft PR - Fusion Cast Draft PR - Fusion Cast May 22, 2025
@github-actions github-actions bot left a comment

You can commit the suggested changes from lintrunner.

@nenad1002 nenad1002 changed the title Draft PR - Fusion Cast Cast Nodes Fusion May 23, 2025
@nenad1002 nenad1002 marked this pull request as ready for review May 23, 2025 15:04
@nenad1002 nenad1002 requested a review from tianleiwu May 23, 2025 16:08
tianleiwu (Contributor) commented May 23, 2025

We used to have similar Cast-removal logic, but it caused a lot of accuracy issues. That was reversed in #17953 to be more conservative. I suggest adding an option in onnxruntime_session_options_config_keys.h. The default would be off, and users can turn it on if needed.

It is also easy to process the model offline, for example with

def remove_cascaded_cast_nodes(self):

So another way to avoid the issue is to use such post-processing in the model builder.
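
A rough sketch of what such an opt-in, offline pass could look like is shown below. It works on a plain-Python stand-in for an ONNX graph, so `Node`, `remove_roundtrip_casts`, and the `enabled` flag are all hypothetical names for illustration; this is not the `remove_cascaded_cast_nodes` method referenced above and not an onnxruntime API. Intermediate Casts are removed only if their output has exactly one consumer and is not a graph output.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class Node:
    op_type: str
    inputs: List[str]
    outputs: List[str]
    to: Optional[str] = None  # target type; only set for Cast nodes


def remove_roundtrip_casts(nodes: List[Node], value_types: Dict[str, str],
                           graph_outputs: List[str],
                           enabled: bool = False) -> List[Node]:
    """Remove Cast chains that end at a type they already started from."""
    # Work on copies so the caller's nodes are never mutated.
    nodes = [Node(n.op_type, list(n.inputs), list(n.outputs), n.to) for n in nodes]
    if not enabled:  # off by default, mirroring the conservative suggestion
        return nodes
    changed = True
    while changed:
        changed = False
        producer = {out: n for n in nodes for out in n.outputs}
        consumers: Dict[str, int] = {}
        for n in nodes:
            for name in n.inputs:
                consumers[name] = consumers.get(name, 0) + 1
        for end in nodes:
            if end.op_type != "Cast" or end.outputs[0] in graph_outputs:
                continue
            chain, source, matched = [end], None, []
            value = end.inputs[0]
            while True:
                if value_types.get(value) == end.to:
                    source, matched = value, list(chain)  # keep farthest match
                prev = producer.get(value)
                if (prev is None or prev.op_type != "Cast"
                        or consumers.get(value, 0) != 1
                        or value in graph_outputs):
                    break
                chain.append(prev)
                value = prev.inputs[0]
            if source is None:
                continue
            removed = {id(n) for n in matched}
            nodes = [n for n in nodes if id(n) not in removed]
            for n in nodes:  # rewire consumers of the removed chain's output
                n.inputs = [source if name == end.outputs[0] else name
                            for name in n.inputs]
            changed = True
            break
    return nodes


# The chain from the PR description, followed by a placeholder consumer B.
nodes = [
    Node("Cast", ["a"], ["c1"], "float16"),
    Node("Cast", ["c1"], ["c2"], "int4"),
    Node("Cast", ["c2"], ["c3"], "float32"),
    Node("Cast", ["c3"], ["c4"], "float16"),
    Node("Relu", ["c4"], ["b"]),
]
value_types = {"a": "float32", "c1": "float16", "c2": "int4",
               "c3": "float32", "c4": "float16"}
fused = remove_roundtrip_casts(nodes, value_types, ["b"], enabled=True)
print([(n.op_type, n.inputs, n.to) for n in fused])
```

Keeping the pass behind a flag that defaults to off means the graph is returned unchanged unless the user explicitly accepts the possible change in numerics.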

nenad1002 (Contributor, Author) commented

> We used to have similar Cast-removal logic, but it caused a lot of accuracy issues. That was reversed in #17953 to be more conservative. I suggest adding an option in onnxruntime_session_options_config_keys.h. The default would be off, and users can turn it on if needed.
>
> It is also easy to process the model offline, for example with
>
> def remove_cascaded_cast_nodes(self):
>
> So another way to avoid the issue is to use such post-processing in the model builder.

OK, interesting, I did not know this could cause an accuracy drop. I can go either way: add a feature filter option or just rely on offline model processing. It would be nice to have onnxruntime remove the extra Casts for us automatically, though, since we often end up with multiple casts when experimenting.
