22
33Note that the differential equation solvers will run on the GPU if the initial
44condition is a GPU array. Thus, for example, we can define a neural ODE manually
5- that runs on the GPU (if no GPU is available, the calculation defaults back to the CPU):
5+ that runs on the GPU (if no GPU is available, the calculation defaults back to the CPU).
6+
7+ For a detailed discussion on how GPUs need to be setup refer to
8+ [ Lux Docs] ( https://lux.csail.mit.edu/stable/manual/gpu_management ) .
69
710``` julia
8- using DifferentialEquations, Lux, SciMLSensitivity, ComponentArrays
11+ using DifferentialEquations, Lux, LuxCUDA, SciMLSensitivity, ComponentArrays
912using Random
1013rng = Random. default_rng()
1114
15+ const cdev = cpu_device()
16+ const gdev = gpu_device()
17+
1218model = Chain(Dense(2 , 50 , tanh), Dense(50 , 2 ))
1319ps, st = Lux. setup(rng, model)
14- ps = ps |> ComponentArray |> gpu
15- st = st |> gpu
20+ ps = ps |> ComponentArray |> gdev
21+ st = st |> gdev
1622dudt(u, p, t) = model(u, p, st)[1 ]
1723
1824# Simulation interval and intermediary points
1925tspan = (0.0f0 , 10.0f0 )
2026tsteps = 0.0f0 : 1.0f-1 : 10.0f0
2127
22- u0 = Float32[2.0 ; 0.0 ] |> gpu
28+ u0 = Float32[2.0 ; 0.0 ] |> gdev
2329prob_gpu = ODEProblem(dudt, u0, tspan, ps)
2430
2531# Runs on a GPU
@@ -39,12 +45,10 @@ If one is using `Lux.Chain`, then the computation takes place on the GPU with
3945``` julia
4046import Lux
4147
42- dudt2 = Lux. Chain(x -> x .^ 3 ,
43- Lux. Dense(2 , 50 , tanh),
44- Lux. Dense(50 , 2 ))
48+ dudt2 = Chain(x -> x .^ 3 , Dense(2 , 50 , tanh), Dense(50 , 2 ))
4549
46- u0 = Float32[2.0 ; 0.0 ] |> gpu
47- p, st = Lux. setup(rng, dudt2) |> gpu
50+ u0 = Float32[2.0 ; 0.0 ] |> gdev
51+ p, st = Lux. setup(rng, dudt2) |> gdev
4852
4953dudt2_(u, p, t) = dudt2(u, p, st)[1 ]
5054
@@ -67,12 +71,12 @@ prob_neuralode_gpu(u0, p, st)
6771
6872## Neural ODE Example
6973
70- Here is the full neural ODE example. Note that we use the ` gpu ` function so that the
71- same code works on CPUs and GPUs, dependent on ` using CUDA ` .
74+ Here is the full neural ODE example. Note that we use the ` gpu_device ` function so that the
75+ same code works on CPUs and GPUs, dependent on ` using LuxCUDA ` .
7276
7377``` julia
7478using Lux, Optimization, OptimizationOptimisers, Zygote, OrdinaryDiffEq,
75- Plots, CUDA , SciMLSensitivity, Random, ComponentArrays
79+ Plots, LuxCUDA , SciMLSensitivity, Random, ComponentArrays
7680import DiffEqFlux: NeuralODE
7781
7882CUDA. allowscalar(false ) # Makes sure no slow operations are occuring
@@ -90,18 +94,18 @@ function trueODEfunc(du, u, p, t)
9094end
9195prob_trueode = ODEProblem(trueODEfunc, u0, tspan)
9296# Make the data into a GPU-based array if the user has a GPU
93- ode_data = gpu (solve(prob_trueode, Tsit5(); saveat = tsteps))
97+ ode_data = gdev (solve(prob_trueode, Tsit5(); saveat = tsteps))
9498
9599dudt2 = Chain(x -> x .^ 3 , Dense(2 , 50 , tanh), Dense(50 , 2 ))
96- u0 = Float32[2.0 ; 0.0 ] |> gpu
100+ u0 = Float32[2.0 ; 0.0 ] |> gdev
97101p, st = Lux. setup(rng, dudt2)
98- p = p |> ComponentArray |> gpu
99- st = st |> gpu
102+ p = p |> ComponentArray |> gdev
103+ st = st |> gdev
100104
101105prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(); saveat = tsteps)
102106
103107function predict_neuralode(p)
104- gpu (first(prob_neuralode(u0, p, st)))
108+ gdev (first(prob_neuralode(u0, p, st)))
105109end
106110function loss_neuralode(p)
107111 pred = predict_neuralode(p)
131135adtype = Optimization. AutoZygote()
132136optf = Optimization. OptimizationFunction((x, p) -> loss_neuralode(x), adtype)
133137optprob = Optimization. OptimizationProblem(optf, p)
134- result_neuralode = Optimization. solve(optprob,
135- Adam(0.05 );
136- callback = callback,
137- maxiters = 300 )
138+ result_neuralode = Optimization. solve(optprob, Adam(0.05 ); callback, maxiters = 300 )
138139```
0 commit comments