
[Feature Request] Adding the MARS Optimizer (Variance Reduction) from Hu et al. 2024 #1561

@KumarADITHYA123

Hey! I’ve been reading the recent paper "MARS: Unleashing the Power of Variance Reduction for Training Large Models" (Hu et al., 2024). It proposes a corrected momentum estimator that reduces the stochastic-gradient variance we usually see when training LLMs. Optax doesn't implement it yet, so I built a JAX-compliant prototype myself, and I think it would be a good addition to `optax.contrib`.

Here is how I’ve set it up so far:

- State Management: a `MarsState` NamedTuple tracks `prev_grad` ($g_{t-1}$), which the correction term needs.
- The Logic: I implemented the variance-reduction correction $\gamma \frac{\beta_1}{1 - \beta_1} (g_t - g_{t-1})$, which is added to the current gradient to form the corrected estimator $c_t = g_t + \gamma \frac{\beta_1}{1 - \beta_1} (g_t - g_{t-1})$.
- Stability: following Section 3.2 of the paper, I applied global norm clipping to the correction term; this seems critical for keeping training stable.
- Edge Cases: the $t=0$ step is handled implicitly, so the correction term vanishes at the first step and causes no initial spike.
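For discussion, here is a minimal plain-JAX sketch of a MARS-style AdamW direction (an illustration, not the prototype itself). The names `scale_by_mars` and `clip_correction_only` and the default hyperparameters are hypothetical choices for this example. It assumes the approximate variant where $g_{t-1}$ is the raw gradient from the previous step, substitutes $g_0$ for the missing previous gradient at $t=0$ so the correction is exactly zero, and by default clips the full corrected gradient $c_t$ to unit global L2 norm (Algorithm 1 of the paper, as I read it). Setting `clip_correction_only=True` clips only the correction term, as described above. The `(init_fn, update_fn)` pair follows optax's `GradientTransformation` convention but does not import optax.

```python
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp


class MarsState(NamedTuple):
    """Optimizer state for the hypothetical `scale_by_mars` sketch."""
    count: jnp.ndarray  # number of updates applied so far (int32 scalar)
    prev_grad: Any      # g_{t-1}: raw gradient from the previous step
    mu: Any             # first-moment EMA of the corrected gradient
    nu: Any             # second-moment EMA of the corrected gradient


def _clip_to_unit_norm(tree):
    """Rescale a pytree so that its global L2 norm is at most 1."""
    leaves = jax.tree_util.tree_leaves(tree)
    norm = jnp.sqrt(sum(jnp.sum(jnp.square(x)) for x in leaves))
    factor = 1.0 / jnp.maximum(norm, 1.0)  # 1 if norm <= 1, else 1 / norm
    return jax.tree_util.tree_map(lambda x: x * factor, tree)


def scale_by_mars(b1=0.95, b2=0.99, gamma=0.025, eps=1e-8,
                  clip_correction_only=False):
    """Hypothetical MARS-AdamW direction (approximate variant), optax-style.

    Like optax's Adam-style transforms, the returned update is the unscaled
    direction; the caller applies the (negative) learning rate.
    """
    scale = gamma * b1 / (1.0 - b1)  # gamma * beta1 / (1 - beta1)

    def init_fn(params):
        zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
        return MarsState(jnp.zeros((), jnp.int32), zeros, zeros, zeros)

    def update_fn(grads, state, params=None):
        del params  # unused; kept to match the optax update signature
        # At t = 0 there is no g_{-1}: substitute g_0 so (g_t - g_{t-1}) = 0.
        first = state.count == 0
        prev = jax.tree_util.tree_map(
            lambda g, p: jnp.where(first, g, p), grads, state.prev_grad)
        corr = jax.tree_util.tree_map(lambda g, p: scale * (g - p), grads, prev)
        if clip_correction_only:  # variant described in this issue
            c = jax.tree_util.tree_map(jnp.add, grads, _clip_to_unit_norm(corr))
        else:  # clip the full corrected gradient c_t
            c = _clip_to_unit_norm(jax.tree_util.tree_map(jnp.add, grads, corr))
        count = state.count + 1
        mu = jax.tree_util.tree_map(lambda m, x: b1 * m + (1 - b1) * x, state.mu, c)
        nu = jax.tree_util.tree_map(lambda v, x: b2 * v + (1 - b2) * x * x, state.nu, c)
        bc1 = 1 - b1 ** count  # Adam-style bias corrections
        bc2 = 1 - b2 ** count
        updates = jax.tree_util.tree_map(
            lambda m, v: (m / bc1) / (jnp.sqrt(v / bc2) + eps), mu, nu)
        # Store the raw gradient (not c_t) as g_{t-1} for the next step.
        return updates, MarsState(count, grads, mu, nu)

    return init_fn, update_fn
```

Inside Optax this pair could be wrapped as `optax.GradientTransformation(init_fn, update_fn)` and chained with `optax.add_decayed_weights` and `optax.scale_by_learning_rate` to get the AdamW-style step; where exactly the weight decay goes is a design choice for the PR.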

I’ve got the code ready to go—would you be open to a PR for this?

Reference: https://arxiv.org/abs/2411.10438
