-
Notifications
You must be signed in to change notification settings - Fork 3.2k
Cast Nodes Fusion #24842
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Cast Nodes Fusion #24842
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can commit the suggested changes from lintrunner.
We used to have similar Cast removing logic but it causes a lot of accuracy issues. That was reversed in #17953 to be more conservative. I suggest to add an option in onnxruntime_session_options_config_keys.h. The default is off, and user can turn it on if needed. It is easy to process model offline like
So another way to avoid the issue to use such post-processing in model builder. |
Ok, interesting, did not know this could cause accuracy drop. I can go either way, add a feature filter option or just rely on offline model processing. It is nice to have onnxruntime remove it for us automatically though since we often end up with multiple casts when experimenting. |
Description
We might have a case where multiple Cast nodes in the chain cast back to the original type. This fusion will remove extra nodes.
E.g.
A ('float32') -> Cast (to='float16') -> Cast (to='int4') -> Cast (to='float32') -> Cast (to='float16') -> B
will reduce to
A ('float32') -> Cast (to='float16') -> B
All the Cast nodes throughout the path need to have one input and one output to be considered for the fusion.
Motivation and Context
Gemma3 ONNX models used to have double casting, and many new models created by the model builder might have as well. Extra Casts might reduce accuracy and increase inference time.