Batching at every gradient step #967

Open · mjyshin opened this issue Feb 18, 2025 · 5 comments
Labels: question (Further information is requested)

Comments


mjyshin commented Feb 18, 2025

I am new to neural differential equations and have been going through some tutorials to better understand them. I noticed that in Python's Diffrax tutorial, they use a batching scheme for training, where every gradient step seems to use 32 trajectories. That runs surprisingly fast, but when I tried to implement the same thing in Julia, either via Optimization (setting maxiters=1 in solve) or via Lux.Training directly, it takes forever.

Am I totally misunderstanding something from the tutorial, or is this just not something that the Julia packages built around DiffEqFlux are optimised for? Thank you in advance!

mjyshin added the question label on Feb 18, 2025
@ChrisRackauckas (Member)

Can you share your code?


mjyshin commented Feb 18, 2025

Create data

# Packages assumed by the snippets below (not listed in the original post)
using Random, Distributions, OrdinaryDiffEq, Lux, DiffEqFlux, ComponentArrays, Optimisers, Zygote

# Times
n = 100    # sample size: i ∈ 1, ..., n
tspan = (0.0f0, 10f0)
t = range(tspan[1], tspan[2], length=n)

# Initial conditions
Random.seed!(0)
m = 2    # data dimensionality j ∈ 1, ..., m
p = 256    # number of sequences k ∈ 1, ..., p
Y0 = Float32.(rand(Uniform(-0.6, 1), (m, p)))    # initial conditions

# Integrate true ODE
function truth!(du, u, p, t)
    z = u ./ (1 .+ u)
    du[1], du[2] = z[2], -z[1]
end
get_data(y0) = begin
    ode = ODEProblem(truth!, y0, tspan)
    y = solve(ode, Tsit5(), saveat=t)
    Y = Array(y)
end
Y = cat(get_data.(eachcol(Y0))..., dims=3)    # m × n × p

Create NODE

# Initial neural network
NN = Chain(Dense(m, 64, softplus), Dense(64, 64, softplus), Dense(64, m))    # (Lux) neural network NN: x ↦ ẋ
θ0, 𝒮 = Lux.setup(Xoshiro(0), NN)    # initialise parameters θ

# Instantiate NeuralODE model
function neural_ode(NN, t)
    node = NeuralODE(NN, extrema(t), Tsit5(), saveat=t, abstol=1e-9, reltol=1e-9)
end

Train

# Loss function
function L(NN, θ, 𝒮, (t, x0, y))    # Inputs: Lux model, params, state, data
    node = neural_ode(NN, t)
    x = cat(Array.(first.(node.(eachcol(x0), Ref(θ), Ref(𝒮))))..., dims=3)
    L = sum(abs2, x - y)
    L, 𝒮, NamedTuple()    # Outputs: loss, state, stats
end

# Initialise training state
opt = AdamW(5e-3)
train_state = Lux.Training.TrainState(NN, ComponentArray(θ0), 𝒮, opt)

# Train one step
traj_size = 10
idx_traj = 1:traj_size
batch_size = 32
idx_batch = randperm(p)[1:batch_size]
∇θ, loss, stats, train_state = Training.single_train_step!(
    AutoZygote(), L, (t[idx_traj], Y0[:, idx_batch], Y[:, idx_traj, idx_batch]), train_state
)

This runs (I'm not sure if it is correct), but even with only the first 10 time steps it takes ~5 seconds per step with a batch of 32. I don't have the code using Optimization any more, but I remember it also taking a long time because I rebuilt the optimisation problem inside the for loop (one problem per gradient step, each with a new random batch of trajectories).
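
For reference, the full loop I have in mind is roughly the sketch below, reusing the pieces above; the helper name train_loop and the step count are placeholders rather than code I have benchmarked.

# Sketch: keep one TrainState and draw a fresh random batch of trajectories at
# every gradient step, so nothing is rebuilt per step
function train_loop(train_state, t, Y0, Y; n_steps=1_000, batch_size=32, traj_size=10)
    idx_traj = 1:traj_size
    for _ in 1:n_steps
        idx_batch = randperm(size(Y0, 2))[1:batch_size]
        data = (t[idx_traj], Y0[:, idx_batch], Y[:, idx_traj, idx_batch])
        _, loss, _, train_state = Training.single_train_step!(AutoZygote(), L, data, train_state)
    end
    return train_state
end

train_state = train_loop(train_state, t, Y0, Y)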

@avik-pal (Member)

x = cat(Array.(first.(node.(eachcol(x0), Ref(θ), Ref(𝒮))))..., dims=3)

You are broadcasting over each element of the batch, which is expected to be slow. Instead, you can pass the whole batch into node, like node(x0, theta, st), and drop the cat operation.
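
A minimal sketch of that batched call, using the variable names from the loss function above (the output layout noted in the comment is an assumption about how the solution array comes back):

# Sketch: a single solve covers the whole batch when x0 is m × batch
sol, st = node(x0, θ, 𝒮)
x = Array(sol)    # expected layout m × batch × n_t; permute if m × n_t × batch is needed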


mjyshin commented Feb 18, 2025

x = cat(Array.(first.(node.(eachcol(x0), Ref(θ), Ref(𝒮))))..., dims=3)

You are broadcasting over each element of the batch, which is expected to be slow. Instead, you can pass the whole batch into node, like node(x0, theta, st), and drop the cat operation.

I changed the loss function to:

# Loss function
function L(NN, θ, 𝒮, (t, x0, y))    # Inputs: Lux model, params, state, data
    node = neural_ode(NN, t)
    x = permutedims(Array(node(x0, θ, 𝒮)[1]), (1, 3, 2))
    L = sum(abs2, x - y)
    L, 𝒮, NamedTuple()    # Outputs: loss, state, stats
end

and it decreased the time per gradient step to ~1 second, but that's still much slower than the Diffrax example (<0.01 seconds)... Do you reckon it would be better to use Optimization with an updated loss function (even though it would still mean building and solving the optimisation problem at each step)? I can put together a quick test example.

@avik-pal (Member)

    node = NeuralODE(NN, extrema(t), Tsit5(), saveat=t, abstol=1e-9, reltol=1e-9)

Diffrax is using much higher (i.e. looser) tolerances than the abstol=1e-9, reltol=1e-9 set here, so each solve is far cheaper.
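
For comparison, a loosened version of the constructor quoted above might look like this sketch; the tolerance values are placeholders, not the ones the Diffrax tutorial actually uses.

# Sketch: same NeuralODE constructor with looser (cheaper) tolerances; values are placeholders
function neural_ode(NN, t)
    NeuralODE(NN, extrema(t), Tsit5(), saveat=t, abstol=1f-6, reltol=1f-3)
end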
