Batching at every gradient step #967
I am new to neural differential equations and have been going through some tutorials to better understand them. I noticed that in Python's Diffrax tutorial, they use a batching scheme for training, where every gradient step seems to use 32 trajectories. This runs surprisingly fast, but when I tried to implement it in Julia, either via Optimization (setting maxiters=1 in solve) or via Lux.Training directly, it takes forever. Am I totally misunderstanding something from the tutorial, or is this not a feature that is optimised for in any of the Julia packages that use DiffEqFlux? Thank you in advance!
Comments
Can you share your code?
Create data:

# Packages assumed from context: Random, Distributions, OrdinaryDiffEq,
# DiffEqFlux, Lux, ComponentArrays, Optimisers, and Zygote
using Random, Distributions, OrdinaryDiffEq, DiffEqFlux, Lux
using ComponentArrays, Optimisers, Zygote

# Times
n = 100 # sample size: i ∈ 1, ..., n
tspan = (0.0f0, 10f0)
t = range(tspan[1], tspan[2], length=n)
# Initial conditions
Random.seed!(0)
m = 2 # data dimensionality j ∈ 1, ..., m
p = 256 # number of sequences k ∈ 1, ..., p
Y0 = Float32.(rand(Uniform(-0.6, 1), (m, p))) # initial conditions
# Integrate true ODE
function truth!(du, u, p, t)
z = u ./ (1 .+ u)
du[1], du[2] = z[2], -z[1]
end
get_data(y0) = begin
ode = ODEProblem(truth!, y0, tspan)
y = solve(ode, Tsit5(), saveat=t)
Y = Array(y)
end
Y = cat(get_data.(eachcol(Y0))..., dims=3) # m × n × p

Create NODE:

# Initial neural network
NN = Chain(Dense(m, 64, softplus), Dense(64, 64, softplus), Dense(64, m)) # (Lux) neural network NN: x ↦ ẋ
θ0, 𝒮 = Lux.setup(Xoshiro(0), NN) # initialise parameters θ
# Instantiate NeuralODE model
function neural_ode(NN, t)
node = NeuralODE(NN, extrema(t), Tsit5(), saveat=t, abstol=1e-9, reltol=1e-9)
end

Train:

# Loss function
function L(NN, θ, 𝒮, (t, x0, y)) # Inputs: Lux model, params, state, data
node = neural_ode(NN, t)
x = cat(Array.(first.(node.(eachcol(x0), Ref(θ), Ref(𝒮))))..., dims=3)
L = sum(abs2, x - y)
L, 𝒮, NamedTuple() # Outputs: loss, state, stats
end
# Initialise training state
opt = AdamW(5e-3)
train_state = Lux.Training.TrainState(NN, ComponentArray(θ0), 𝒮, opt)
# Train one step
traj_size = 10
idx_traj = 1:traj_size
batch_size = 32
idx_batch = randperm(p)[1:batch_size]
∇θ, loss, stats, train_state = Training.single_train_step!(
AutoZygote(), L, (t[idx_traj], Y0[:, idx_batch], Y[:, idx_traj, idx_batch]), train_state
)

This runs (not sure if it is correct), but even with only the first 10 time steps, it takes ~5 seconds each time with a batch of 32. I don't have the code using Optimization any more, but I remember it taking a long time because I rebuilt an optimisation problem inside the for loop (over each gradient step, using a new random batch of trajectories).
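One way to separate one-off compilation from the real per-step cost is to time a single training step after a warm-up call (a minimal sketch; BenchmarkTools is an assumed extra dependency, and data is a hypothetical name for the batch tuple):

using BenchmarkTools  # assumed extra dependency

data = (t[idx_traj], Y0[:, idx_batch], Y[:, idx_traj, idx_batch])
# @btime runs the call repeatedly after a warm-up, so Julia's one-time
# compilation cost is excluded from the reported time
@btime Training.single_train_step!(AutoZygote(), L, $data, $train_state)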
You are broadcasting over each batch, which is expected to be slow. Instead you can pass the whole batch into the NeuralODE.
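For instance, since Lux layers treat the last dimension as the batch dimension, the whole m × batch matrix of initial conditions can go through the NeuralODE in one call (a minimal sketch using the definitions above):

node = neural_ode(NN, t)
# x0 as an m × batch matrix: the solver integrates all columns at once
sol, _ = node(Y0[:, idx_batch], ComponentArray(θ0), 𝒮)
x = Array(sol)  # size m × batch × length(t)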
I changed the loss function to:

# Loss function
function L(NN, θ, 𝒮, (t, x0, y)) # Inputs: Lux model, params, state, data
node = neural_ode(NN, t)
x = permutedims(Array(node(x0, θ, 𝒮)[1]), (1, 3, 2))
L = sum(abs2, x - y)
L, 𝒮, NamedTuple() # Outputs: loss, state, stats
end

and it decreased the training time to ~1 second, but that's still much slower than the Diffrax example (<0.01 seconds)... Do you reckon it would be better to use Optimization with an updated loss function (but still building and solving the optimisation problem at each step)? I could make a quick test example.
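For reference, with the Lux.Training API there should be no need to rebuild anything between steps: the same TrainState can be threaded through a loop that draws a fresh random batch each iteration (a minimal sketch assuming the definitions above; train! and the step count are hypothetical):

function train!(train_state, n_steps)
    for _ in 1:n_steps
        idx_batch = randperm(p)[1:batch_size]  # fresh random batch each step
        data = (t[idx_traj], Y0[:, idx_batch], Y[:, idx_traj, idx_batch])
        # the TrainState is reused, so no problem is rebuilt between steps
        _, loss, _, train_state = Training.single_train_step!(
            AutoZygote(), L, data, train_state
        )
    end
    return train_state
end

train_state = train!(train_state, 500)  # 500 steps, chosen arbitrarily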
Diffrax is using much higher tolerances.
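For a like-for-like timing comparison, the tolerances in neural_ode above (abstol=1e-9, reltol=1e-9) would need to be loosened to something closer to what the Diffrax tutorial uses; the values below are illustrative, not copied from the tutorial:

function neural_ode(NN, t)
    # Much looser tolerances; the exact Diffrax settings should be checked
    NeuralODE(NN, extrema(t), Tsit5(), saveat=t, abstol=1e-6, reltol=1e-3)
end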