Labels: type:support (Further information is requested)
Description
I want to use the look-ahead optimizer but it does not seem to fit the optax pattern exactly.
```python
import optax
import jax
import jax.numpy as jnp

def fn_to_optimize(x):
    return jnp.sum(x ** 2)

params = jnp.array([2.0, 2.0])
fast_optimizer = optax.adam(1e-1)
solver = optax.lookahead(fast_optimizer, sync_period=5, slow_step_size=0.5)
# params = optax.LookaheadParams.init_synced(params)
state = solver.init(params)

for step in range(100):
    loss, grads = jax.value_and_grad(fn_to_optimize)(params)
    # Fails here: lookahead's update expects params to be a LookaheadParams
    # instance, not a plain array.
    updates, state = solver.update(grads, state, params)
    params = optax.apply_updates(params, updates)
    if step % 10 == 0:
        print(f"Step {step}, Loss: {loss}, Params: {params}")
```

I would have expected it to be a drop-in replacement like the other optimizers, but that does not seem to be the case.
When I look at the source code (line 99 in 5bd9095), `init_fn` is annotated to accept plain params:

```python
def init_fn(params: base.Params) -> LookaheadState:
```
However, the update function is confusing: unlike `init`, it expects `params: LookaheadParams`.

The following seems to be the right way to do it, but it is not very intuitive.
```python
import optax
import jax
import jax.numpy as jnp

def fn_to_optimize(x):
    return jnp.sum(x ** 2)

params = jnp.array([2.0, 2.0])
fast_optimizer = optax.adam(1e-1)
solver = optax.lookahead(fast_optimizer, sync_period=5, slow_step_size=0.5)
# Wrap the parameters in a LookaheadParams container (fast and slow copies).
params = optax.LookaheadParams.init_synced(params)
state = solver.init(params)

for step in range(100):
    # Differentiate with respect to the fast weights only.
    loss, grads = jax.value_and_grad(fn_to_optimize)(params.fast)
    updates, state = solver.update(grads, state, params)
    params = optax.apply_updates(params, updates)
    if step % 10 == 0:
        print(f"Step {step}, Loss: {loss}, Params: {params}")
```