Correct usage of look-ahead optimizer #1429

@SNMS95

I want to use the look-ahead optimizer but it does not seem to fit the optax pattern exactly.

import optax
import jax
import jax.numpy as jnp

def fn_to_optimize(x):
    x = x.fast
    return jnp.sum((x) ** 2)

params = jnp.array([2.0, 2.0])
fast_optimizer = optax.adam(1e-1)
solver = optax.lookahead(fast_optimizer, sync_period=5, slow_step_size=0.5)
# params = optax.LookaheadParams.init_synced(params)
state = solver.init(params)

for step in range(100):
    loss, grads = jax.value_and_grad(fn_to_optimize)(params)
    updates, state = solver.update(grads, state, params)
    params = optax.apply_updates(params, updates)
    if step % 10 == 0:
        print(f"Step {step}, Loss: {loss}, Params: {params}")

I would have expected it to be a drop-in replacement like the other optimizers, but that does not seem to be the case.

When I look at the source code, the init function is very clear:

def init_fn(params: base.Params) -> LookaheadState:

However, the update function is confusing: it expects params: LookaheadParams, unlike init.
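For context, here is a rough pure-Python sketch of the Lookahead rule as I understand it, which may explain why the update step needs both the fast and the slow parameters. It is only an illustration under assumptions: it uses plain SGD as the fast optimizer instead of Adam, a single scalar parameter, and the names Params and lookahead_step are made up here and are not part of optax.

```python
from dataclasses import dataclass


@dataclass
class Params:
    """Hypothetical stand-in for a fast/slow parameter pair."""
    fast: float
    slow: float


def grad(x):
    # Gradient of the toy objective f(x) = x ** 2.
    return 2.0 * x


def lookahead_step(params, steps_since_sync, lr=0.1, sync_period=5,
                   slow_step_size=0.5):
    # The fast optimizer (plain SGD here) only ever sees the fast parameters.
    fast = params.fast - lr * grad(params.fast)
    if steps_since_sync == sync_period - 1:
        # Sync step: move the slow parameters part of the way towards the
        # fast ones, then restart the fast parameters from the new slow value.
        slow = params.slow + slow_step_size * (fast - params.slow)
        return Params(fast=slow, slow=slow), 0
    # Regular step: the slow parameters stay where they are.
    return Params(fast=fast, slow=params.slow), steps_since_sync + 1


params = Params(fast=2.0, slow=2.0)  # both copies start out synced
count = 0
history = []
for step in range(10):
    params, count = lookahead_step(params, count)
    history.append(params)
```

Since the sync step reads the current slow value and writes to both copies, a plain parameter array would not carry enough information for the update, which would be consistent with update asking for a fast/slow pair.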

This seems to be the right way to do it, but it is not very intuitive.

import optax
import jax
import jax.numpy as jnp

def fn_to_optimize(x):
    return jnp.sum((x) ** 2)

params = jnp.array([2.0, 2.0])
fast_optimizer = optax.adam(1e-1)
solver = optax.lookahead(fast_optimizer, sync_period=5, slow_step_size=0.5)
params = optax.LookaheadParams.init_synced(params)
state = solver.init(params)

for step in range(100):
    loss, grads = jax.value_and_grad(fn_to_optimize)(params.fast)
    updates, state = solver.update(grads, state, params)
    params = optax.apply_updates(params, updates)
    if step % 1 == 0:
        print(f"Step {step}, Loss: {loss}, Params: {params}")
