Skip to content

Commit 29951bd

Browse files
committed
Finish the basic tutorial
1 parent c4f0d10 commit 29951bd

File tree

4 files changed

+163
-11
lines changed

4 files changed

+163
-11
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ ConcreteStructs = "0.2"
3232
ConstructionBase = "1"
3333
DiffEqBase = "6.119"
3434
LinearAlgebra = "1"
35-
Lux = "0.5.7"
35+
Lux = "0.5.11"
3636
Random = "1"
3737
SciMLBase = "2"
3838
SciMLSensitivity = "7.43"

docs/Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@
22
DeepEquilibriumNetworks = "6748aba7-0e9b-415e-a410-ae3cc0ecb334"
33
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
44
DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"
5+
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
56
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
67
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
78
MLDataUtils = "cc2ba9b6-d476-5e6d-8eaf-a92d5412d41d"
89
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
910
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
10-
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
11-
OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
11+
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
1212
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
1313
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1414
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"

docs/src/tutorials/basic_mnist_deq.md

Lines changed: 156 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@ We will train a simple Deep Equilibrium Model on MNIST. First we load a few pack
44

55
```@example basic_mnist_deq
66
using DeepEquilibriumNetworks, SciMLSensitivity, Lux, NonlinearSolve, OrdinaryDiffEq,
7-
Statistics, Random, Optimization, OptimizationOptimisers, LuxCUDA
7+
Statistics, Random, Optimisers, LuxCUDA, Zygote, LinearSolve
88
using MLDatasets: MNIST
9-
using MLDataUtils: LabelEnc, convertlabel, stratifiedobs
9+
using MLDataUtils: LabelEnc, convertlabel, stratifiedobs, batchview
1010
1111
CUDA.allowscalar(false)
1212
ENV["DATADEPS_ALWAYS_ACCEPT"] = true
@@ -27,9 +27,9 @@ function onehot(labels_raw)
2727
return convertlabel(LabelEnc.OneOfK, labels_raw, LabelEnc.NativeLabels(collect(0:9)))
2828
end
2929
30-
function loadmnist(batchsize)
30+
function loadmnist(batchsize, split)
3131
# Load MNIST
32-
mnist = MNIST(; split=:train)
32+
mnist = MNIST(; split)
3333
imgs, labels_raw = mnist.features, mnist.targets
3434
# Process images into (H,W,C,BS) batches
3535
x_train = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3))) |>
@@ -40,4 +40,156 @@ function loadmnist(batchsize)
4040
y_train = batchview(y_train, batchsize)
4141
return x_train, y_train
4242
end
43+
44+
x_train, y_train = loadmnist(128, :train);
45+
x_test, y_test = loadmnist(128, :test);
4346
```
47+
48+
Construct the Lux Neural Network containing a DEQ layer.
49+
50+
```@example basic_mnist_deq
51+
function construct_model(solver; model_type::Symbol=:deq)
52+
down = Chain(Conv((3, 3), 1 => 64, gelu; stride=1), GroupNorm(64, 64),
53+
Conv((4, 4), 64 => 64; stride=2, pad=1))
54+
55+
# The input layer of the DEQ
56+
deq_model = Chain(Parallel(+,
57+
Conv((3, 3), 64 => 64, tanh; stride=1, pad=SamePad()),
58+
Conv((3, 3), 64 => 64, tanh; stride=1, pad=SamePad())),
59+
Conv((3, 3), 64 => 64, tanh; stride=1, pad=SamePad()))
60+
61+
if model_type === :skipdeq
62+
init = Conv((3, 3), 64 => 64, gelu; stride=1, pad=SamePad())
63+
elseif model_type === :regdeq
64+
init = nothing
65+
else
66+
init = missing
67+
end
68+
69+
deq = DeepEquilibriumNetwork(deq_model, solver; init, verbose=false,
70+
linsolve_kwargs=(; maxiters=10))
71+
72+
classifier = Chain(GroupNorm(64, 64, relu), GlobalMeanPool(), FlattenLayer(),
73+
Dense(64, 10))
74+
75+
model = Chain(; down, deq, classifier)
76+
77+
# For NVIDIA GPUs this directly generates the parameters on the GPU
78+
rng = Random.default_rng() |> gdev
79+
ps, st = Lux.setup(rng, model)
80+
81+
# Warmup the forward and backward passes
82+
x = randn(rng, Float32, 28, 28, 1, 128)
83+
y = onehot(rand(Random.default_rng(), 0:9, 128)) |> gdev
84+
85+
model_ = Lux.Experimental.StatefulLuxLayer(model, ps, st)
86+
@info "warming up forward pass"
87+
logitcrossentropy(model_, x, ps, y)
88+
@info "warming up backward pass"
89+
Zygote.gradient(logitcrossentropy, model_, x, ps, y)
90+
@info "warmup complete"
91+
92+
return model, ps, st
93+
end
94+
```
95+
96+
Define some helper functions to train the model.
97+
98+
```@example basic_mnist_deq
99+
logitcrossentropy(ŷ, y) = mean(-sum(y .* logsoftmax(ŷ; dims=1); dims=1))
100+
function logitcrossentropy(model, x, ps, y)
101+
l1 = logitcrossentropy(model(x, ps), y)
102+
# Add in some regularization
103+
l2 = mean(abs2, model.st.deq.solution.z_star .- model.st.deq.solution.u0)
104+
return l1 + 10.0 * l2
105+
end
106+
107+
# Return, for every column of `x`, the row index of its largest entry (the predicted
# class when each column holds the scores or one-hot labels of a single sample).
classify(x) = map(argmax, eachcol(x))
108+
109+
function accuracy(model, data, ps, st)
110+
total_correct, total = 0, 0
111+
st = Lux.testmode(st)
112+
model = Lux.Experimental.StatefulLuxLayer(model, ps, st)
113+
for (x, y) in data
114+
target_class = classify(cdev(y))
115+
predicted_class = classify(cdev(model(x)))
116+
total_correct += sum(target_class .== predicted_class)
117+
total += length(target_class)
118+
end
119+
return total_correct / total
120+
end
121+
122+
function train_model(solver, model_type; data_train=zip(x_train, y_train),
123+
data_test=zip(x_test, y_test))
124+
model, ps, st = construct_model(solver; model_type)
125+
model_st = Lux.Experimental.StatefulLuxLayer(model, nothing, st)
126+
127+
@info "Training Model: $(model_type) with Solver: $(nameof(typeof(solver)))"
128+
129+
opt_st = Optimisers.setup(Adam(0.001), ps)
130+
131+
acc = accuracy(model, data_test, ps, st) * 100
132+
@info "Starting Accuracy: $(acc)"
133+
134+
# = Uncomment these lines to enable pretraining. See what happens
135+
@info "Pretrain with unrolling to a depth of 5"
136+
st = Lux.update_state(st, :fixed_depth, Val(5))
137+
model_st = Lux.Experimental.StatefulLuxLayer(model, ps, st)
138+
139+
for (i, (x, y)) in enumerate(data_train)
140+
res = Zygote.withgradient(logitcrossentropy, model_st, x, ps, y)
141+
Optimisers.update!(opt_st, ps, res.grad[3])
142+
if i % 50 == 1
143+
@info "Pretraining Batch: [$(i)/$(length(data_train))] Loss: $(res.val)"
144+
end
145+
end
146+
147+
acc = accuracy(model, data_test, ps, model_st.st) * 100
148+
@info "Pretraining complete. Accuracy: $(acc)"
149+
# =#
150+
151+
st = Lux.update_state(st, :fixed_depth, Val(0))
152+
model_st = Lux.Experimental.StatefulLuxLayer(model, ps, st)
153+
154+
for epoch in 1:3
155+
for (i, (x, y)) in enumerate(data_train)
156+
res = Zygote.withgradient(logitcrossentropy, model_st, x, ps, y)
157+
Optimisers.update!(opt_st, ps, res.grad[3])
158+
if i % 50 == 1
159+
@info "Epoch: [$(epoch)/3] Batch: [$(i)/$(length(data_train))] Loss: $(res.val)"
160+
end
161+
end
162+
163+
acc = accuracy(model, data_test, ps, model_st.st) * 100
164+
@info "Epoch: [$(epoch)/3] Accuracy: $(acc)"
165+
end
166+
167+
@info "Training complete."
168+
println()
169+
170+
return model, ps, st
171+
end
172+
```
173+
174+
Now we can train our model. First we will train a Discrete DEQ, which effectively means
175+
we pass in a root-finding algorithm. Typically, most packages lack good nonlinear solvers,
176+
and end up using solvers like `Broyden`, but we can simply slap in any of the fancy solvers
177+
from NonlinearSolve.jl. Here we will use the Newton-Krylov method:
178+
179+
```@example basic_mnist_deq
180+
train_model(NewtonRaphson(; linsolve=KrylovJL_GMRES()), :regdeq)
181+
nothing # hide
182+
```
183+
184+
We can also train a continuous DEQ by passing in an ODE solver. Here we will use `VCAB3()`
185+
which tends to be quite fast for continuous Neural Network problems.
186+
187+
```@example basic_mnist_deq
188+
train_model(VCAB3(), :deq)
189+
nothing # hide
190+
```
191+
192+
This code is set up to allow playing around with different DEQ models. Try modifying the
193+
`model_type` argument to `train_model` to `:skipdeq` or `:deq` to see how the model
194+
behaves. You can also try different solvers from NonlinearSolve.jl and OrdinaryDiffEq.jl!
195+
Even third-party solvers from Sundials.jl will work; just remember to use the CPU for those.

src/layers.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ Stores the solution of a DeepEquilibriumNetwork and its variants.
1313
- `nfe`: Number of Function Evaluations
1414
- `original`: Original Internal Solution
1515
"""
16-
@concrete struct DeepEquilibriumSolution
16+
struct DeepEquilibriumSolution # This is intentionally left untyped to allow updating `st`
1717
z_star
1818
u0
1919
residual
@@ -85,7 +85,7 @@ function (deq::DEQ)(x, ps, st::NamedTuple, ::Val{true})
8585
model, ps.model, zˢᵗᵃʳ, x, rng)
8686

8787
solution = DeepEquilibriumSolution(zˢᵗᵃʳ, z, resid, zero(eltype(x)),
88-
_unwrap_val(st.fixed_depth), nothing)
88+
_unwrap_val(st.fixed_depth), jac_loss)
8989
res = __split_and_reshape(zˢᵗᵃʳ, __getproperty(deq.model, Val(:split_idxs)),
9090
__getproperty(deq.model, Val(:scales)))
9191

@@ -102,7 +102,7 @@ function (deq::DEQ{pType})(x, ps, st::NamedTuple, ::Val{false}) where {pType}
102102
prob = __construct_prob(pType, ODEFunction{false}(dudt), z, (; ps=ps.model, x))
103103
alg = __normalize_alg(deq)
104104
sol = solve(prob, alg; sensealg=__default_sensealg(prob), abstol=1e-3, reltol=1e-3,
105-
termination_condition=AbsNormTerminationMode(), maxiters=100, deq.kwargs...)
105+
termination_condition=AbsNormTerminationMode(), maxiters=32, deq.kwargs...)
106106
zˢᵗᵃʳ = __get_steady_state(sol)
107107

108108
rng = Lux.replicate(st.rng)
@@ -148,7 +148,7 @@ Deep Equilibrium Network as proposed in [baideep2019](@cite) and [pal2022mixing]
148148
julia> using DeepEquilibriumNetworks, Lux, Random, OrdinaryDiffEq
149149
150150
julia> model = DeepEquilibriumNetwork(Parallel(+, Dense(2, 2; use_bias=false),
151-
Dense(2, 2; use_bias=false)), VCABM3())
151+
Dense(2, 2; use_bias=false)), VCABM3(); verbose=false)
152152
DeepEquilibriumNetwork(
153153
model = Parallel(
154154
+

0 commit comments

Comments
 (0)