Why use standard LayerNorm over RMSNorm in JAX/Flax? Does XLA fusion negate the algorithmic disadvantage? #5235
AnirudhSKrishnan asked this question in Q&A

I'm implementing normalization in Flax and trying to decide between standard LayerNorm and RMSNorm.

RMSNorm is usually preferred because it is cheaper: it skips the mean subtraction (centering) and normalizes by the root mean square alone. But under jax.jit, doesn't XLA fuse the standard LayerNorm math into a single kernel anyway?

Does the "two-pass" nature of standard LayerNorm (one reduction for the mean, another for the variance) actually matter once it's compiled, or is the performance difference negligible next to the memory-bandwidth savings that fusion already provides? I'm trying to figure out whether it's worth writing a custom RMSNorm or whether I should just stick with the standard nn.LayerNorm.

Thanks.

Replies: 1 comment

> Flax NNX also provides nnx.RMSNorm.
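For reference, the arithmetic difference the question describes can be sketched in plain jax.numpy. This is a minimal illustration, not Flax's actual implementation: LayerNorm needs two reduction statistics per row (mean and variance), while RMSNorm needs only one (the mean square). Both compile fine under jax.jit.

```python
import jax
import jax.numpy as jnp

def layer_norm(x, eps=1e-6):
    # Two statistics over the feature axis: mean (centering) and variance (scaling).
    mean = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.var(x, axis=-1, keepdims=True)
    return (x - mean) * jax.lax.rsqrt(var + eps)

def rms_norm(x, eps=1e-6):
    # One statistic: the mean square. No centering pass.
    ms = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    return x * jax.lax.rsqrt(ms + eps)

x = jax.random.normal(jax.random.PRNGKey(0), (4, 128))
ln = jax.jit(layer_norm)(x)
rn = jax.jit(rms_norm)(x)
```

Note that a hand-rolled version may be unnecessary anyway: besides the nnx.RMSNorm mentioned in the reply, recent Flax releases also ship a built-in RMSNorm in the linen API (flax.linen.RMSNorm).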
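As for whether XLA actually fuses the LayerNorm math on your backend, you don't have to guess: jax.jit lets you lower and compile a function without running it and dump the optimized HLO, where the fused kernels show up as `fusion` ops. A small sketch (the `layer_norm` here is an assumed stand-in for whatever normalization you are compiling):

```python
import jax
import jax.numpy as jnp

def layer_norm(x, eps=1e-6):
    mean = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.var(x, axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)

x = jnp.ones((4, 128), dtype=jnp.float32)

# Lower and compile without executing, then dump the backend-optimized HLO.
hlo_text = jax.jit(layer_norm).lower(x).compile().as_text()
```

Counting how many `fusion` computations appear in `hlo_text` (and what they contain) tells you how many kernels the normalization actually costs after compilation, which is a more reliable basis for the LayerNorm-vs-RMSNorm decision than the operation count in the source.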