Commit b67f684

Finish the tutorials
1 parent 29951bd commit b67f684

File tree: 3 files changed, +178 −7 lines

docs/src/tutorials/basic_mnist_deq.md

Lines changed: 0 additions & 2 deletions
@@ -131,7 +131,6 @@ function train_model(solver, model_type; data_train=zip(x_train, y_train),
     acc = accuracy(model, data_test, ps, st) * 100
     @info "Starting Accuracy: $(acc)"
 
-    # = Uncomment these lines to enavle pretraining. See what happens
     @info "Pretrain with unrolling to a depth of 5"
     st = Lux.update_state(st, :fixed_depth, Val(5))
     model_st = Lux.Experimental.StatefulLuxLayer(model, ps, st)
@@ -146,7 +145,6 @@ function train_model(solver, model_type; data_train=zip(x_train, y_train),
 
     acc = accuracy(model, data_test, ps, model_st.st) * 100
     @info "Pretraining complete. Accuracy: $(acc)"
-    # =#
 
     st = Lux.update_state(st, :fixed_depth, Val(0))
     model_st = Lux.Experimental.StatefulLuxLayer(model, ps, st)
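
The change above deletes the comment markers so the pretraining block always runs. For context, a nonzero `fixed_depth` makes the DEQ unroll for that many steps instead of calling the nonlinear solver. A minimal sketch of the toggle, assuming `model`, `ps`, `st` came from `Lux.setup`:

```julia
# Sketch only: Val(5) switches the DEQ forward pass to a 5-step unrolled
# network (cheap and stable early in training) ...
st = Lux.update_state(st, :fixed_depth, Val(5))
# ... and Val(0) restores the implicit root-finding forward pass.
st = Lux.update_state(st, :fixed_depth, Val(0))
```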

docs/src/tutorials/reduced_dim_deq.md

Lines changed: 176 additions & 1 deletion
@@ -1,3 +1,178 @@
 # Modelling Equilibrium Models with Reduced State Size
 
-This Tutorial is currently under preparation. Check back soon.
+Sometimes we don't want to solve a root finding problem with the full state size. This
+will often be faster, since the size of the root finding problem is reduced. We will use
+the same MNIST example as before, but this time we will use a reduced state size.
+
+```@example reduced_dim_mnist
+using DeepEquilibriumNetworks, SciMLSensitivity, Lux, NonlinearSolve, OrdinaryDiffEq,
+    Statistics, Random, Optimisers, LuxCUDA, Zygote, LinearSolve
+using MLDatasets: MNIST
+using MLDataUtils: LabelEnc, convertlabel, stratifiedobs, batchview
+
+CUDA.allowscalar(false)
+ENV["DATADEPS_ALWAYS_ACCEPT"] = true
+
+const cdev = cpu_device()
+const gdev = gpu_device()
+
+function onehot(labels_raw)
+    return convertlabel(LabelEnc.OneOfK, labels_raw, LabelEnc.NativeLabels(collect(0:9)))
+end
+
+function loadmnist(batchsize, split)
+    # Load MNIST
+    mnist = MNIST(; split)
+    imgs, labels_raw = mnist.features, mnist.targets
+    # Process images into (H,W,C,BS) batches
+    x_train = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3))) |>
+              gdev
+    x_train = batchview(x_train, batchsize)
+    # Onehot and batch the labels
+    y_train = onehot(labels_raw) |> gdev
+    y_train = batchview(y_train, batchsize)
+    return x_train, y_train
+end
+
+x_train, y_train = loadmnist(128, :train);
+x_test, y_test = loadmnist(128, :test);
+```
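
As a quick sanity check on the data pipeline (our own assertion, not part of the tutorial), `batchview` should yield 28×28×1×128 image batches and 10×128 one-hot label batches:

```julia
# Hypothetical check: each batch is (H, W, C, BS) = (28, 28, 1, 128), and the
# labels are one-hot over the 10 digit classes.
@assert size(first(x_train)) == (28, 28, 1, 128)
@assert size(first(y_train)) == (10, 128)
```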
40+
41+
Now we will define the construct model function. Here we will use Dense Layers and
42+
downsample the features using the `init` kwarg.
43+
44+
```@example reduced_dim_mnist
45+
function construct_model(solver; model_type::Symbol=:regdeq)
46+
down = Chain(FlattenLayer(), Dense(784 => 512, gelu))
47+
48+
# The input layer of the DEQ
49+
deq_model = Chain(Parallel(+,
50+
Dense(128 => 64, tanh), # Reduced dim of `128`
51+
Dense(512 => 64, tanh)), # Original dim of `512`
52+
Dense(64 => 64, tanh), Dense(64 => 128)) # Return the reduced dim of `128`
53+
54+
if model_type === :skipdeq
55+
init = Dense(512 => 128, tanh)
56+
elseif model_type === :regdeq
57+
error(":regdeq is not supported for reduced dim models")
58+
else
59+
# This should preferably done via `ChainRulesCore.@ignore_derivatives`. But here
60+
# we are only using Zygote so this is fine.
61+
init = WrappedFunction(x -> Zygote.@ignore(fill!(similar(x, 128, size(x, 2)),
62+
false)))
63+
end
64+
65+
deq = DeepEquilibriumNetwork(deq_model, solver; init, verbose=false,
66+
linsolve_kwargs=(; maxiters=10))
67+
68+
classifier = Chain(Dense(128 => 128, gelu), Dense(128, 10))
69+
70+
model = Chain(; down, deq, classifier)
71+
72+
# For NVIDIA GPUs this directly generates the parameters on the GPU
73+
rng = Random.default_rng() |> gdev
74+
ps, st = Lux.setup(rng, model)
75+
76+
# Warmup the forward and backward passes
77+
x = randn(rng, Float32, 28, 28, 1, 128)
78+
y = onehot(rand(Random.default_rng(), 0:9, 128)) |> gdev
79+
80+
model_ = Lux.Experimental.StatefulLuxLayer(model, ps, st)
81+
@info "warming up forward pass"
82+
logitcrossentropy(model_, x, ps, y)
83+
@info "warming up backward pass"
84+
Zygote.gradient(logitcrossentropy, model_, x, ps, y)
85+
@info "warmup complete"
86+
87+
return model, ps, st
88+
end
89+
```
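
To make the reduced-state idea concrete, here is a plain-array sketch of the fixed-point map that `deq_model` represents (made-up weight names, biases omitted). The root-finding problem is solved over the 128-dimensional state `z`, while the 512-dimensional features `x` from `down` are only injected:

```julia
using Random

rng = Random.Xoshiro(0)
W_z, W_x = randn(rng, Float32, 64, 128), randn(rng, Float32, 64, 512)
W_h, W_o = randn(rng, Float32, 64, 64), randn(rng, Float32, 128, 64)

# Parallel(+) sums the two branches, then two Dense layers map back to 128.
f(z, x) = W_o * tanh.(W_h * (tanh.(W_z * z) .+ tanh.(W_x * x)))

z0 = zeros(Float32, 128, 4)        # reduced-size state for 4 samples
x = randn(rng, Float32, 512, 4)    # stand-in for the features from `down`
@assert size(f(z0, x)) == size(z0) # the map stays in the 128-dim state space
```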
+
+Define some helper functions to train the model.
+
+```@example reduced_dim_mnist
+logitcrossentropy(ŷ, y) = mean(-sum(y .* logsoftmax(ŷ; dims=1); dims=1))
+function logitcrossentropy(model, x, ps, y)
+    l1 = logitcrossentropy(model(x, ps), y)
+    # Add in some regularization
+    l2 = mean(abs2, model.st.deq.solution.z_star .- model.st.deq.solution.u0)
+    return l1 + 0.1f0 * l2
+end
+
+classify(x) = argmax.(eachcol(x))
+
+function accuracy(model, data, ps, st)
+    total_correct, total = 0, 0
+    st = Lux.testmode(st)
+    model = Lux.Experimental.StatefulLuxLayer(model, ps, st)
+    for (x, y) in data
+        target_class = classify(cdev(y))
+        predicted_class = classify(cdev(model(x)))
+        total_correct += sum(target_class .== predicted_class)
+        total += length(target_class)
+    end
+    return total_correct / total
+end
+
+function train_model(solver, model_type; data_train=zip(x_train, y_train),
+        data_test=zip(x_test, y_test))
+    model, ps, st = construct_model(solver; model_type)
+    model_st = Lux.Experimental.StatefulLuxLayer(model, nothing, st)
+
+    @info "Training Model: $(model_type) with Solver: $(nameof(typeof(solver)))"
+
+    opt_st = Optimisers.setup(Adam(0.001), ps)
+
+    acc = accuracy(model, data_test, ps, st) * 100
+    @info "Starting Accuracy: $(acc)"
+
+    @info "Pretrain with unrolling to a depth of 5"
+    st = Lux.update_state(st, :fixed_depth, Val(5))
+    model_st = Lux.Experimental.StatefulLuxLayer(model, ps, st)
+
+    for (i, (x, y)) in enumerate(data_train)
+        res = Zygote.withgradient(logitcrossentropy, model_st, x, ps, y)
+        Optimisers.update!(opt_st, ps, res.grad[3])
+        if i % 50 == 1
+            @info "Pretraining Batch: [$(i)/$(length(data_train))] Loss: $(res.val)"
+        end
+    end
+
+    acc = accuracy(model, data_test, ps, model_st.st) * 100
+    @info "Pretraining complete. Accuracy: $(acc)"
+
+    st = Lux.update_state(st, :fixed_depth, Val(0))
+    model_st = Lux.Experimental.StatefulLuxLayer(model, ps, st)
+
+    for epoch in 1:3
+        for (i, (x, y)) in enumerate(data_train)
+            res = Zygote.withgradient(logitcrossentropy, model_st, x, ps, y)
+            Optimisers.update!(opt_st, ps, res.grad[3])
+            if i % 50 == 1
+                @info "Epoch: [$(epoch)/3] Batch: [$(i)/$(length(data_train))] Loss: $(res.val)"
+            end
+        end
+
+        acc = accuracy(model, data_test, ps, model_st.st) * 100
+        @info "Epoch: [$(epoch)/3] Accuracy: $(acc)"
+    end
+
+    @info "Training complete."
+    println()
+
+    return model, ps, st
+end
+```
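
The loss helper above adds a skip-style penalty to the cross-entropy: it pulls the initial guess `u0` produced by `init` toward the equilibrium `z_star`, so the solver starts close to the fixed point. In symbols (our notation, not the tutorial's), with `mean(abs2, ...)` averaging over the `n` entries:

```math
\mathcal{L}(\hat{y}, y) = \underbrace{\mathrm{CE}(\hat{y}, y)}_{l_1}
  + 0.1 \, \underbrace{\frac{1}{n} \lVert z^{*} - u_0 \rVert_2^2}_{l_2}
```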
+
+Now we can train our model. We can't use `:regdeq` here currently, but we will support
+this in the future.
+
+```@example reduced_dim_mnist
+train_model(NewtonRaphson(; linsolve=KrylovJL_GMRES()), :skipdeq)
+nothing # hide
+```
+
+```@example reduced_dim_mnist
+train_model(NewtonRaphson(; linsolve=KrylovJL_GMRES()), :deq)
+nothing # hide
+```
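
Other NonlinearSolve.jl solvers should drop into the same call; for instance, the test suite touched by this commit also exercises `SimpleLimitedMemoryBroyden`, a Jacobian-free option. A hedged variation (not part of the commit):

```julia
# Assumed variation, not in the tutorial: a limited-memory Broyden solver
# avoids the Krylov linear solves entirely.
train_model(SimpleLimitedMemoryBroyden(), :skipdeq)
nothing # hide
```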

test/layers.jl

Lines changed: 2 additions & 4 deletions
@@ -49,8 +49,7 @@ end
     z, st = model(x, ps, st)
 
     opt_broken = solver isa NewtonRaphson ||
-                 solver isa SimpleLimitedMemoryBroyden ||
-                 jacobian_regularization isa AutoZygote
+                 solver isa SimpleLimitedMemoryBroyden
     @jet model(x, ps, st) opt_broken=opt_broken # Broken due to nfe dynamic dispatch
 
     @test all(isfinite, z)
@@ -142,8 +141,7 @@ end
     z_ = DEQs.__flatten_vcat(z)
 
     opt_broken = solver isa NewtonRaphson ||
-                 solver isa SimpleLimitedMemoryBroyden ||
-                 jacobian_regularization isa AutoZygote
+                 solver isa SimpleLimitedMemoryBroyden
     @jet model(x, ps, st) opt_broken=opt_broken # Broken due to nfe dynamic dispatch
 
     @test all(isfinite, z_)
