Skip to content

Split lowering rules out of jax_primitives.py#2753

Open
paul0403 wants to merge 7 commits intomainfrom
organize-lowering-rule
Open

Split lowering rules out of jax_primitives.py#2753
paul0403 wants to merge 7 commits intomainfrom
organize-lowering-rule

Conversation

@paul0403
Copy link
Copy Markdown
Member

@paul0403 paul0403 commented Apr 24, 2026

Context:
The jax_primitives.py file was getting too heavy.

Description of the Change:
Split the lowering rules out into their own file, primitive_lowering_rules.py.
Remove all NotImplementedError() on the primitives' unused def_impl.
Slightly improve import structure.
Delete a unused function catalyst.Pass.get_options().

Benefits:
Better organization.

@paul0403 paul0403 requested review from albi3ro and kipawaa April 24, 2026 20:19
@github-actions
Copy link
Copy Markdown
Contributor

Hello. You may have forgotten to update the changelog!
Please edit doc/releases/changelog-dev.md on your branch with:

  • A one-to-two sentence description of the change. You may include a small working example for new features.
  • A link back to this PR.
  • Your name (or GitHub username) in the contributors section.

@paul0403 paul0403 requested a review from a team April 24, 2026 20:22
return ctx.module, ctx.context


def get_mlir_attribute_from_pyval(value):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unused?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved to a more appropriate place (jax_primitives_utils.py).

I wrote this helper function a while ago, and placed it here in lowering.py somewhat arbitrarily. This lowering.py file needs the CUSTOM_LOWERING_RULES registry, so I was moving it to break out of circular imports.

# pylint: disable=unused-argument,too-many-lines,too-many-statements,protected-access


CUSTOM_LOWERING_RULES = ()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What would you think of having this be a dictionary and then just converting it via tuple(CUSTOM_LOWERING_RULES.items()) when needed? It would make it so we easily fetch the corresponding lowering rule when we need it for a given primitive.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah! Good point, I kept it as a tuple because the connection with jax, jax.interpreters.mlir.LoweringParameters(override_lowering_rules=CUSTOM_LOWERING_RULES)

But it completely didn't occur to me to just convert 😅 I'll add it

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could also get rid of the need for global keyword since the registry is mutable now.

@codecov
Copy link
Copy Markdown

codecov Bot commented Apr 24, 2026

Codecov Report

❌ Patch coverage is 97.33879% with 26 lines in your changes missing coverage. Please review.
✅ Project coverage is 97.00%. Comparing base (5dd28aa) to head (464e8ee).

Files with missing lines Patch % Lines
frontend/catalyst/primitive_lowering_rules.py 97.21% 16 Missing and 10 partials ⚠️
Additional details and impacted files
@@           Coverage Diff           @@
##             main    #2753   +/-   ##
=======================================
  Coverage   96.99%   97.00%           
=======================================
  Files         165      166    +1     
  Lines       18460    18517   +57     
  Branches     1783     1781    -2     
=======================================
+ Hits        17906    17962   +56     
  Misses        398      398           
- Partials      156      157    +1     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Copy link
Copy Markdown
Contributor

@dime10 dime10 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @paul0403! This is actually a great opportunity to move these files into the tracing module, which was intended to be sub-module for everything jax/tracing related.

The root-level files stem from when we first started the repo and didn't have any subdirectories. After a while we refactored things into proper sub-modules like api-extensions, passes, autograph, etc., but didn't get to refactor the jax code. As a result, it's split across a growing number of files/folders:

  • jax_tracer/primitives/primitives_utils.py
  • jax_extras (for patches/functionality considered "core jax")
  • the started but never completed tracing module
  • and maybe even some things in utils/ ?

If we're able move all of that into one sub-module that would be amazing in terms of code org 😍
As usual, a general reminder to avoid "utils"-style files and folders, in favour of meaningfully grouped code by purpose/functionality.

@paul0403
Copy link
Copy Markdown
Member Author

paul0403 commented Apr 24, 2026

Thanks @paul0403! This is actually a great opportunity to move these files into the tracing module, which was intended to be sub-module for everything jax/tracing related.

The root-level files stem from when we first started the repo and didn't have any subdirectories. After a while we refactored things into proper sub-modules like api-extensions, passes, autograph, etc., but didn't get to refactor the jax code. As a result, it's split across a growing number of files/folders:

  • jax_tracer/primitives/primitives_utils.py
  • jax_extras (for patches/functionality considered "core jax")
  • the started but never completed tracing module
  • and maybe even some things in utils/ ?

If we're able move all of that into one sub-module that would be amazing in terms of code org 😍 As usual, a general reminder to avoid "utils"-style files and folders, in favour of meaningfully grouped code by purpose/functionality.

Ah, nice! Them being at root-level and not having their own subdirectory (and hence not having their own __init__.py) was actually causing a lot of troubles for me and @kipawaa , it caused a bunch of circular imports. A reorganization would certainly be most wonderful!

I'll see what I can come up with 👍

@paul0403
Copy link
Copy Markdown
Member Author

Note that I can do the file reorganization after finishing reference semantics. In the meantime, this PR (which just splits jax_primitives.py) can probably go in first.

Copy link
Copy Markdown
Contributor

@kipawaa kipawaa left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is so great, thanks for doing this!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants