Note that the differential equation solvers will run on the GPU if the initial
condition is a GPU array. Thus, for example, we can define a neural ODE manually
- that runs on the GPU (if no GPU is available, the calculation defaults back to the CPU):
+ that runs on the GPU (if no GPU is available, the calculation defaults back to the CPU).
+
+ For a detailed discussion on how GPUs need to be setup refer to
+ [Lux Docs](https://lux.csail.mit.edu/stable/manual/gpu_management).
```julia
- using DifferentialEquations, Lux, SciMLSensitivity, ComponentArrays
+ using DifferentialEquations, Lux, LuxCUDA, SciMLSensitivity, ComponentArrays
using Random
rng = Random.default_rng()

+ const cdev = cpu_device()
+ const gdev = gpu_device()
+
model = Chain(Dense(2, 50, tanh), Dense(50, 2))
ps, st = Lux.setup(rng, model)
- ps = ps |> ComponentArray |> gpu
- st = st |> gpu
+ ps = ps |> ComponentArray |> gdev
+ st = st |> gdev

dudt(u, p, t) = model(u, p, st)[1]

# Simulation interval and intermediary points
tspan = (0.0f0, 10.0f0)
tsteps = 0.0f0:1.0f-1:10.0f0

- u0 = Float32[2.0; 0.0] |> gpu
+ u0 = Float32[2.0; 0.0] |> gdev

prob_gpu = ODEProblem(dudt, u0, tspan, ps)

# Runs on a GPU
@@ -39,12 +45,10 @@ If one is using `Lux.Chain`, then the computation takes place on the GPU with
```julia
import Lux

- dudt2 = Lux.Chain(x -> x .^ 3,
-     Lux.Dense(2, 50, tanh),
-     Lux.Dense(50, 2))
+ dudt2 = Chain(x -> x .^ 3, Dense(2, 50, tanh), Dense(50, 2))

- u0 = Float32[2.0; 0.0] |> gpu
- p, st = Lux.setup(rng, dudt2) |> gpu
+ u0 = Float32[2.0; 0.0] |> gdev
+ p, st = Lux.setup(rng, dudt2) |> gdev

dudt2_(u, p, t) = dudt2(u, p, st)[1]
@@ -67,12 +71,12 @@ prob_neuralode_gpu(u0, p, st)
## Neural ODE Example

- Here is the full neural ODE example. Note that we use the `gpu` function so that the
- same code works on CPUs and GPUs, dependent on `using CUDA`.
+ Here is the full neural ODE example. Note that we use the `gpu_device` function so that the
+ same code works on CPUs and GPUs, dependent on `using LuxCUDA`.

```julia
using Lux, Optimization, OptimizationOptimisers, Zygote, OrdinaryDiffEq,
-     Plots, CUDA, SciMLSensitivity, Random, ComponentArrays
+     Plots, LuxCUDA, SciMLSensitivity, Random, ComponentArrays
import DiffEqFlux: NeuralODE

CUDA.allowscalar(false) # Makes sure no slow operations are occurring
@@ -90,18 +94,18 @@ function trueODEfunc(du, u, p, t)
end

prob_trueode = ODEProblem(trueODEfunc, u0, tspan)
# Make the data into a GPU-based array if the user has a GPU
- ode_data = gpu(solve(prob_trueode, Tsit5(); saveat = tsteps))
+ ode_data = gdev(solve(prob_trueode, Tsit5(); saveat = tsteps))

dudt2 = Chain(x -> x .^ 3, Dense(2, 50, tanh), Dense(50, 2))
- u0 = Float32[2.0; 0.0] |> gpu
+ u0 = Float32[2.0; 0.0] |> gdev
p, st = Lux.setup(rng, dudt2)
- p = p |> ComponentArray |> gpu
- st = st |> gpu
+ p = p |> ComponentArray |> gdev
+ st = st |> gdev

prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(); saveat = tsteps)

function predict_neuralode(p)
-     gpu(first(prob_neuralode(u0, p, st)))
+     gdev(first(prob_neuralode(u0, p, st)))
end

function loss_neuralode(p)
    pred = predict_neuralode(p)

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss_neuralode(x), adtype)
optprob = Optimization.OptimizationProblem(optf, p)
- result_neuralode = Optimization.solve(optprob,
-     Adam(0.05);
-     callback = callback,
-     maxiters = 300)
+ result_neuralode = Optimization.solve(optprob, Adam(0.05); callback, maxiters = 300)
```
0 commit comments