@@ -4,9 +4,9 @@ We will train a simple Deep Equilibrium Model on MNIST. First we load a few pack

```@example basic_mnist_deq
using DeepEquilibriumNetworks, SciMLSensitivity, Lux, NonlinearSolve, OrdinaryDiffEq,
-     Statistics, Random, Optimization, OptimizationOptimisers, LuxCUDA
+     Statistics, Random, Optimisers, LuxCUDA, Zygote, LinearSolve
using MLDatasets: MNIST
- using MLDataUtils: LabelEnc, convertlabel, stratifiedobs
+ using MLDataUtils: LabelEnc, convertlabel, stratifiedobs, batchview

CUDA.allowscalar(false)
ENV["DATADEPS_ALWAYS_ACCEPT"] = true
@@ -27,9 +27,9 @@ function onehot(labels_raw)
    return convertlabel(LabelEnc.OneOfK, labels_raw, LabelEnc.NativeLabels(collect(0:9)))
end

- function loadmnist(batchsize)
+ function loadmnist(batchsize, split)
    # Load MNIST
-     mnist = MNIST(; split=:train)
+     mnist = MNIST(; split)
    imgs, labels_raw = mnist.features, mnist.targets
    # Process images into (H,W,C,BS) batches
    x_train = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3))) |>
@@ -40,4 +40,156 @@ function loadmnist(batchsize)
    y_train = batchview(y_train, batchsize)
    return x_train, y_train
end
+
+ x_train, y_train = loadmnist(128, :train);
+ x_test, y_test = loadmnist(128, :test);
```
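+
+ As a quick sanity check on the loader (an illustrative snippet; it only uses the
+ values defined above), each training batch should be a 28×28×1×128 image array with
+ a matching 10×128 one-hot label matrix:
+
+ ```julia
+ @assert size(first(x_train)) == (28, 28, 1, 128)  # (H, W, C, BS) image batch
+ @assert size(first(y_train)) == (10, 128)         # one-hot labels for 10 classes
+ ```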
+
+ Construct the Lux Neural Network containing a DEQ layer.
+
+ ```@example basic_mnist_deq
+ function construct_model(solver; model_type::Symbol=:deq)
+     down = Chain(Conv((3, 3), 1 => 64, gelu; stride=1), GroupNorm(64, 64),
+         Conv((4, 4), 64 => 64; stride=2, pad=1))
+
+     # The core model of the DEQ: the first branch of `Parallel` transforms the
+     # fixed-point iterate, the second injects the input features
+     deq_model = Chain(Parallel(+,
+             Conv((3, 3), 64 => 64, tanh; stride=1, pad=SamePad()),
+             Conv((3, 3), 64 => 64, tanh; stride=1, pad=SamePad())),
+         Conv((3, 3), 64 => 64, tanh; stride=1, pad=SamePad()))
+
+     # Select how the fixed-point iteration is initialized for each model variant
+     if model_type === :skipdeq
+         init = Conv((3, 3), 64 => 64, gelu; stride=1, pad=SamePad())
+     elseif model_type === :regdeq
+         init = nothing
+     else
+         init = missing
+     end
+
+     deq = DeepEquilibriumNetwork(deq_model, solver; init, verbose=false,
+         linsolve_kwargs=(; maxiters=10))
+
+     classifier = Chain(GroupNorm(64, 64, relu), GlobalMeanPool(), FlattenLayer(),
+         Dense(64, 10))
+
+     model = Chain(; down, deq, classifier)
+
+     # For NVIDIA GPUs this directly generates the parameters on the GPU
+     rng = Random.default_rng() |> gdev
+     ps, st = Lux.setup(rng, model)
+
+     # Warmup the forward and backward passes
+     x = randn(rng, Float32, 28, 28, 1, 128)
+     y = onehot(rand(Random.default_rng(), 0:9, 128)) |> gdev
+
+     model_ = Lux.Experimental.StatefulLuxLayer(model, ps, st)
+     @info "warming up forward pass"
+     logitcrossentropy(model_, x, ps, y)
+     @info "warming up backward pass"
+     Zygote.gradient(logitcrossentropy, model_, x, ps, y)
+     @info "warmup complete"
+
+     return model, ps, st
+ end
+ ```
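+
+ To build intuition for what the DEQ layer does, here is a minimal CPU-only sketch,
+ assuming the same calling convention as above: the layer solves for an equilibrium
+ `z` satisfying `z = deq_model((z, x))` and returns it.
+
+ ```julia
+ # Hypothetical toy example: a dense DEQ on random data. The first branch of
+ # `Parallel` acts on the iterate `z`, the second on the input `x`.
+ toy = DeepEquilibriumNetwork(Parallel(+, Dense(4 => 4, tanh), Dense(4 => 4)),
+     NewtonRaphson())
+ ps_toy, st_toy = Lux.setup(Random.default_rng(), toy)
+ z_star, st_toy = toy(randn(Float32, 4, 3), ps_toy, st_toy)  # equilibrium features
+ ```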
+
+ Define some helper functions to train the model.
+
+ ```@example basic_mnist_deq
+ logitcrossentropy(ŷ, y) = mean(-sum(y .* logsoftmax(ŷ; dims=1); dims=1))
+ function logitcrossentropy(model, x, ps, y)
+     l1 = logitcrossentropy(model(x, ps), y)
+     # Add in some regularization: penalize the distance between the initial guess
+     # of the fixed-point solve and the equilibrium it converged to
+     l2 = mean(abs2, model.st.deq.solution.z_star .- model.st.deq.solution.u0)
+     return l1 + 10.0 * l2
+ end
+
+ classify(x) = argmax.(eachcol(x))
+
+ function accuracy(model, data, ps, st)
+     total_correct, total = 0, 0
+     st = Lux.testmode(st)
+     model = Lux.Experimental.StatefulLuxLayer(model, ps, st)
+     for (x, y) in data
+         target_class = classify(cdev(y))
+         predicted_class = classify(cdev(model(x)))
+         total_correct += sum(target_class .== predicted_class)
+         total += length(target_class)
+     end
+     return total_correct / total
+ end
+
+ function train_model(solver, model_type; data_train=zip(x_train, y_train),
+         data_test=zip(x_test, y_test))
+     model, ps, st = construct_model(solver; model_type)
+     model_st = Lux.Experimental.StatefulLuxLayer(model, nothing, st)
+
+     @info "Training Model: $(model_type) with Solver: $(nameof(typeof(solver)))"
+
+     opt_st = Optimisers.setup(Adam(0.001), ps)
+
+     acc = accuracy(model, data_test, ps, st) * 100
+     @info "Starting Accuracy: $(acc)"
+
+     # = Uncomment these lines to enable pretraining. See what happens
+     @info "Pretrain with unrolling to a depth of 5"
+     # Unroll the DEQ to a fixed depth of 5 (a weight-tied network) to stabilize
+     # early training
+     st = Lux.update_state(st, :fixed_depth, Val(5))
+     model_st = Lux.Experimental.StatefulLuxLayer(model, ps, st)
+
+     for (i, (x, y)) in enumerate(data_train)
+         res = Zygote.withgradient(logitcrossentropy, model_st, x, ps, y)
+         Optimisers.update!(opt_st, ps, res.grad[3])
+         if i % 50 == 1
+             @info "Pretraining Batch: [$(i)/$(length(data_train))] Loss: $(res.val)"
+         end
+     end
+
+     acc = accuracy(model, data_test, ps, model_st.st) * 100
+     @info "Pretraining complete. Accuracy: $(acc)"
+     # =#
+
+     # Switch back to a full implicit solve (a fixed depth of 0)
+     st = Lux.update_state(st, :fixed_depth, Val(0))
+     model_st = Lux.Experimental.StatefulLuxLayer(model, ps, st)
+
+     for epoch in 1:3
+         for (i, (x, y)) in enumerate(data_train)
+             res = Zygote.withgradient(logitcrossentropy, model_st, x, ps, y)
+             Optimisers.update!(opt_st, ps, res.grad[3])
+             if i % 50 == 1
+                 @info "Epoch: [$(epoch)/3] Batch: [$(i)/$(length(data_train))] Loss: $(res.val)"
+             end
+         end
+
+         acc = accuracy(model, data_test, ps, model_st.st) * 100
+         @info "Epoch: [$(epoch)/3] Accuracy: $(acc)"
+     end
+
+     @info "Training complete."
+     println()
+
+     return model, ps, st
+ end
+ ```
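+
+ These helpers assume the class dimension comes first and the batch dimension second.
+ A quick illustrative check with random logits (hypothetical values):
+
+ ```julia
+ ŷ = randn(Float32, 10, 4)  # logits: 10 classes × batch of 4
+ y = onehot(rand(0:9, 4))   # matching one-hot targets
+ logitcrossentropy(ŷ, y)    # mean cross-entropy over the batch
+ classify(ŷ)                # 4 predicted class indices (1-based)
+ ```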
+
+ Now we can train our model. First we will train a discrete DEQ, which effectively
+ means passing in a root-finding algorithm. Typically, packages lack good nonlinear
+ solvers and end up using simple methods like `Broyden`, but here we can simply plug
+ in any of the fancy solvers from NonlinearSolve.jl. We will use a Newton-Krylov
+ method:
+
+ ```@example basic_mnist_deq
+ train_model(NewtonRaphson(; linsolve=KrylovJL_GMRES()), :regdeq)
+ nothing # hide
+ ```
+
+ We can also train a continuous DEQ by passing in an ODE solver. Here we will use
+ `VCAB3()`, which tends to be quite fast for continuous neural network problems.
+
+ ```@example basic_mnist_deq
+ train_model(VCAB3(), :deq)
+ nothing # hide
+ ```
+
+ This code is set up to allow playing around with different DEQ models. Try modifying
+ the `model_type` argument of `train_model` to `:skipdeq` or `:deq` to see how the
+ model behaves. You can also try different solvers from NonlinearSolve.jl and
+ OrdinaryDiffEq.jl, as sketched below! Even third-party solvers from Sundials.jl will
+ work; just remember to use the CPU for those.
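+
+ For instance, a sketch of such variations (assuming `Broyden` from NonlinearSolve.jl
+ and `Tsit5` from OrdinaryDiffEq.jl, both exported by the packages loaded above):
+
+ ```julia
+ # A quasi-Newton discrete DEQ with a learned initial guess
+ train_model(Broyden(), :skipdeq)
+
+ # A continuous skip DEQ with an explicit Runge-Kutta ODE solver
+ train_model(Tsit5(), :skipdeq)
+ ```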