Description
Hey! I’ve been reading through the recent paper "MARS: Unleashing the Power of Variance Reduction for Training Large Models" (Yuan et al., 2024). It proposes a corrected momentum estimator that targets the high gradient variance typically seen when training LLMs.

I noticed Optax doesn't have this implemented yet, so I took a crack at designing a JAX-compliant prototype myself. I think it would be a great addition to optax.contrib.

Here is how I’ve set it up so far:

- State Management: I used a MarsState NamedTuple to track the prev_grad needed for the correction term.
- The Logic: I implemented the paper's variance reduction term, which corrects the current gradient using the previous gradient stored in the state.
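To make the proposal concrete, here is a minimal sketch of what such a transformation could look like. It is not the prototype mentioned above. It assumes the approximate variant of the correction, `c_t = g_t + gamma * b1 / (1 - b1) * (g_t - g_{t-1})`, followed by per-tensor norm clipping and Adam-style moments. The function names `mars_init` and `mars_update`, the extra state fields `count`, `mu`, and `nu`, the default hyperparameters, and the choice to skip the correction on the first step are all assumptions made for illustration.

```python
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp


class MarsState(NamedTuple):
    """Sketch state. Only prev_grad is named in the issue; the rest is assumed."""
    count: jax.Array  # number of updates applied so far
    prev_grad: Any    # gradient pytree from the previous step
    mu: Any           # first moment of the corrected gradient
    nu: Any           # second moment of the corrected gradient


def mars_init(params):
    """Builds a zero-initialised state with the same tree structure as params."""
    zeros = lambda: jax.tree_util.tree_map(jnp.zeros_like, params)
    return MarsState(jnp.zeros([], jnp.int32), zeros(), zeros(), zeros())


def mars_update(grads, state, b1=0.95, b2=0.99, gamma=0.025, eps=1e-8):
    """Returns (updates, new_state); hyperparameter defaults are assumptions."""
    tmap = jax.tree_util.tree_map
    count = state.count + 1

    # Assumption: no previous gradient exists on the first step, so skip the
    # correction there instead of correcting against zeros.
    scale = jnp.where(state.count == 0, 0.0, gamma * b1 / (1.0 - b1))
    corrected = tmap(lambda g, pg: g + scale * (g - pg), grads, state.prev_grad)

    # Assumption: clip each tensor separately to at most unit L2 norm.
    def clip(x):
        norm = jnp.sqrt(jnp.sum(jnp.square(x)))
        return x * jnp.minimum(1.0, 1.0 / jnp.maximum(norm, eps))

    corrected = tmap(clip, corrected)

    # Adam-style moment estimates on the corrected gradient, with bias correction.
    mu = tmap(lambda m, c: b1 * m + (1.0 - b1) * c, state.mu, corrected)
    nu = tmap(lambda v, c: b2 * v + (1.0 - b2) * jnp.square(c), state.nu, corrected)
    mu_hat = tmap(lambda m: m / (1.0 - b1 ** count), mu)
    nu_hat = tmap(lambda v: v / (1.0 - b2 ** count), nu)
    updates = tmap(lambda m, v: m / (jnp.sqrt(v) + eps), mu_hat, nu_hat)

    # Store the raw (uncorrected) gradient for the next step's correction.
    return updates, MarsState(count, grads, mu, nu)
```

The returned `updates` are an unscaled descent direction, so a caller would still multiply by a negative learning rate before adding them to the parameters. In Optax, an init/update pair like this would typically be wrapped in `optax.GradientTransformation` and chained with a learning-rate scaling step.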
I’ve got the code ready to go—would you be open to a PR for this?
Reference link: https://arxiv.org/abs/2411.10438