Skip to content

Commit a072fd8

Browse files
committed
Fix remaining docs
1 parent 29ccaf7 commit a072fd8

9 files changed

+443
-684
lines changed

Manifest.toml

+3-5
Original file line numberDiff line numberDiff line change
@@ -172,8 +172,6 @@ version = "1.0.5+1"
172172
[[deps.ComponentArrays]]
173173
deps = ["ArrayInterface", "ChainRulesCore", "ForwardDiff", "Functors", "LinearAlgebra", "PackageExtensionCompat", "StaticArrayInterface", "StaticArraysCore"]
174174
git-tree-sha1 = "d30eb4d89c791a64e698546c1e0e0e488cd99da5"
175-
repo-rev = "ap/patch_diffeqflux"
176-
repo-url = "https://github.com/avik-pal/ComponentArrays.jl"
177175
uuid = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
178176
version = "0.15.5"
179177
weakdeps = ["Adapt", "ConstructionBase", "GPUArrays", "RecursiveArrayTools", "ReverseDiff", "SciMLBase", "Tracker", "Zygote"]
@@ -613,9 +611,9 @@ version = "0.4.1"
613611

614612
[[deps.KernelAbstractions]]
615613
deps = ["Adapt", "Atomix", "InteractiveUtils", "LinearAlgebra", "MacroTools", "PrecompileTools", "Requires", "SparseArrays", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"]
616-
git-tree-sha1 = "95063c5bc98ba0c47e75e05ae71f1fed4deac6f6"
614+
git-tree-sha1 = "b0737cbbe1c8da6f1139d1c23e35e7cea129c0af"
617615
uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
618-
version = "0.9.12"
616+
version = "0.9.13"
619617
weakdeps = ["EnzymeCore"]
620618

621619
[deps.KernelAbstractions.extensions]
@@ -1484,7 +1482,7 @@ version = "0.5.23"
14841482

14851483
[[deps.Tracker]]
14861484
deps = ["Adapt", "DiffRules", "ForwardDiff", "Functors", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NNlib", "NaNMath", "Optimisers", "Printf", "Random", "Requires", "SpecialFunctions", "Statistics"]
1487-
git-tree-sha1 = "752daa5bbd9721b0566e39cdc75cffdc3ef5593d"
1485+
git-tree-sha1 = "1ae02dde414ae2edbafbab600f23aad212f2f098"
14881486
repo-rev = "ap/ambiguous"
14891487
repo-url = "https://github.com/avik-pal/Tracker.jl"
14901488
uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"

docs/Project.toml

+6
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
[deps]
22
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
3+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
34
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
45
DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe"
56
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
@@ -13,21 +14,26 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1314
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
1415
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1516
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
17+
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
1618
MLDataUtils = "cc2ba9b6-d476-5e6d-8eaf-a92d5412d41d"
1719
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
20+
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
21+
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
1822
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
1923
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
2024
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
2125
OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
2226
OptimizationPolyalgorithms = "500b13db-7e66-49ce-bda4-eed966be6282"
2327
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
2428
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
29+
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
2530
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2631
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
2732
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
2833
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
2934
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
3035
StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
36+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3137
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
3238

3339
[compat]

docs/src/examples/GPUs.md

+23-22
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,30 @@
22

33
Note that the differential equation solvers will run on the GPU if the initial
44
condition is a GPU array. Thus, for example, we can define a neural ODE manually
5-
that runs on the GPU (if no GPU is available, the calculation defaults back to the CPU):
5+
that runs on the GPU (if no GPU is available, the calculation defaults back to the CPU).
6+
7+
For a detailed discussion on how GPUs need to be setup refer to
8+
[Lux Docs](https://lux.csail.mit.edu/stable/manual/gpu_management).
69

710
```julia
8-
using DifferentialEquations, Lux, SciMLSensitivity, ComponentArrays
11+
using DifferentialEquations, Lux, LuxCUDA, SciMLSensitivity, ComponentArrays
912
using Random
1013
rng = Random.default_rng()
1114

15+
const cdev = cpu_device()
16+
const gdev = gpu_device()
17+
1218
model = Chain(Dense(2, 50, tanh), Dense(50, 2))
1319
ps, st = Lux.setup(rng, model)
14-
ps = ps |> ComponentArray |> gpu
15-
st = st |> gpu
20+
ps = ps |> ComponentArray |> gdev
21+
st = st |> gdev
1622
dudt(u, p, t) = model(u, p, st)[1]
1723

1824
# Simulation interval and intermediary points
1925
tspan = (0.0f0, 10.0f0)
2026
tsteps = 0.0f0:1.0f-1:10.0f0
2127

22-
u0 = Float32[2.0; 0.0] |> gpu
28+
u0 = Float32[2.0; 0.0] |> gdev
2329
prob_gpu = ODEProblem(dudt, u0, tspan, ps)
2430

2531
# Runs on a GPU
@@ -39,12 +45,10 @@ If one is using `Lux.Chain`, then the computation takes place on the GPU with
3945
```julia
4046
import Lux
4147

42-
dudt2 = Lux.Chain(x -> x .^ 3,
43-
Lux.Dense(2, 50, tanh),
44-
Lux.Dense(50, 2))
48+
dudt2 = Chain(x -> x .^ 3, Dense(2, 50, tanh), Dense(50, 2))
4549

46-
u0 = Float32[2.0; 0.0] |> gpu
47-
p, st = Lux.setup(rng, dudt2) |> gpu
50+
u0 = Float32[2.0; 0.0] |> gdev
51+
p, st = Lux.setup(rng, dudt2) |> gdev
4852

4953
dudt2_(u, p, t) = dudt2(u, p, st)[1]
5054

@@ -67,12 +71,12 @@ prob_neuralode_gpu(u0, p, st)
6771

6872
## Neural ODE Example
6973

70-
Here is the full neural ODE example. Note that we use the `gpu` function so that the
71-
same code works on CPUs and GPUs, dependent on `using CUDA`.
74+
Here is the full neural ODE example. Note that we use the `gpu_device` function so that the
75+
same code works on CPUs and GPUs, dependent on `using LuxCUDA`.
7276

7377
```julia
7478
using Lux, Optimization, OptimizationOptimisers, Zygote, OrdinaryDiffEq,
75-
Plots, CUDA, SciMLSensitivity, Random, ComponentArrays
79+
Plots, LuxCUDA, SciMLSensitivity, Random, ComponentArrays
7680
import DiffEqFlux: NeuralODE
7781

7882
CUDA.allowscalar(false) # Makes sure no slow operations are occuring
@@ -90,18 +94,18 @@ function trueODEfunc(du, u, p, t)
9094
end
9195
prob_trueode = ODEProblem(trueODEfunc, u0, tspan)
9296
# Make the data into a GPU-based array if the user has a GPU
93-
ode_data = gpu(solve(prob_trueode, Tsit5(); saveat = tsteps))
97+
ode_data = gdev(solve(prob_trueode, Tsit5(); saveat = tsteps))
9498

9599
dudt2 = Chain(x -> x .^ 3, Dense(2, 50, tanh), Dense(50, 2))
96-
u0 = Float32[2.0; 0.0] |> gpu
100+
u0 = Float32[2.0; 0.0] |> gdev
97101
p, st = Lux.setup(rng, dudt2)
98-
p = p |> ComponentArray |> gpu
99-
st = st |> gpu
102+
p = p |> ComponentArray |> gdev
103+
st = st |> gdev
100104

101105
prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(); saveat = tsteps)
102106

103107
function predict_neuralode(p)
104-
gpu(first(prob_neuralode(u0, p, st)))
108+
gdev(first(prob_neuralode(u0, p, st)))
105109
end
106110
function loss_neuralode(p)
107111
pred = predict_neuralode(p)
@@ -131,8 +135,5 @@ end
131135
adtype = Optimization.AutoZygote()
132136
optf = Optimization.OptimizationFunction((x, p) -> loss_neuralode(x), adtype)
133137
optprob = Optimization.OptimizationProblem(optf, p)
134-
result_neuralode = Optimization.solve(optprob,
135-
Adam(0.05);
136-
callback = callback,
137-
maxiters = 300)
138+
result_neuralode = Optimization.solve(optprob, Adam(0.05); callback, maxiters = 300)
138139
```

0 commit comments

Comments
 (0)