@@ -4,7 +4,7 @@ We will train a simple Deep Equilibrium Model on MNIST. First we load a few pack
```@example basic_mnist_deq
using DeepEquilibriumNetworks, SciMLSensitivity, Lux, NonlinearSolve, OrdinaryDiffEq,
- Statistics, Random, Optimisers, LuxCUDA, Zygote, LinearSolve, LoggingExtras
+ Statistics, Random, Optimisers, LuxCUDA, Zygote, LinearSolve, Dates, Printf
using MLDatasets: MNIST
using MLDataUtils: LabelEnc, convertlabel, stratifiedobs, batchview
@@ -20,18 +20,6 @@ const cdev = cpu_device()
const gdev = gpu_device()
```
- SciMLBase introduced a warning instead of depwarn which pollutes the output. We can suppress
- it with the following logger
-
- ```@example basic_mnist_deq
- function remove_syms_warning(log_args)
- return log_args.message !=
- "The use of keyword arguments `syms`, `paramsyms` and `indepsym` for `SciMLFunction`s is deprecated. Pass `sys = SymbolCache(syms, paramsyms, indepsym)` instead."
- end
-
- filtered_logger = ActiveFilteredLogger(remove_syms_warning, global_logger())
- ```
-
We can now construct our dataloader.
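The dataloader code itself is not touched by this diff. For context, a rough sketch of what it can look like is shown below; the `loadmnist` helper, the `0.9` split ratio, and the batch size of `128` are illustrative assumptions rather than the file's exact code.

```julia
# Rough sketch only (not the tutorial's exact code): stratified split on the raw
# labels, one-hot encoding, and device-resident mini-batch views.
onehot(labels) = convertlabel(LabelEnc.OneOfK, labels, LabelEnc.NativeLabels(collect(0:9)))

function loadmnist(batchsize, split_ratio)
    mnist = MNIST(; split=:train)
    x = reshape(Float32.(mnist.features), 28, 28, 1, :)
    # Stratified train/test split on the integer labels
    (xtr, ytr), (xte, yte) = stratifiedobs((x, mnist.targets); p=split_ratio)
    # Mini-batch views for images and (one-hot) labels, moved to the selected device,
    # so that `zip(x_train, y_train)` later iterates (x, y) pairs
    return gdev.(batchview(xtr; size=batchsize)),
           gdev.(onehot.(batchview(ytr; size=batchsize))),
           gdev.(batchview(xte; size=batchsize)),
           gdev.(onehot.(batchview(yte; size=batchsize)))
end

x_train, y_train, x_test, y_test = loadmnist(128, 0.9)
```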
```@example basic_mnist_deq
@@ -66,8 +54,7 @@ function construct_model(solver; model_type::Symbol=:deq)
# The input layer of the DEQ
deq_model = Chain(
- Parallel(+,
- Conv((3, 3), 64 => 64, tanh; stride=1, pad=SamePad()),
+ Parallel(+, Conv((3, 3), 64 => 64, tanh; stride=1, pad=SamePad()),
Conv((3, 3), 64 => 64, tanh; stride=1, pad=SamePad())),
Conv((3, 3), 64 => 64, tanh; stride=1, pad=SamePad()))
@@ -79,11 +66,11 @@ function construct_model(solver; model_type::Symbol=:deq)
init = missing
end
- deq = DeepEquilibriumNetwork(deq_model, solver; init, verbose=false,
- linsolve_kwargs=(; maxiters=10))
+ deq = DeepEquilibriumNetwork(
+ deq_model, solver; init, verbose=false, linsolve_kwargs=(; maxiters=10))
- classifier = Chain(GroupNorm(64, 64, relu), GlobalMeanPool(), FlattenLayer(),
- Dense(64, 10))
+ classifier = Chain(
+ GroupNorm(64, 64, relu), GlobalMeanPool(), FlattenLayer(), Dense(64, 10))
model = Chain(; down, deq, classifier)
@@ -95,12 +82,12 @@ function construct_model(solver; model_type::Symbol=:deq)
x = randn(rng, Float32, 28, 28, 1, 128)
y = onehot(rand(Random.default_rng(), 0:9, 128)) |> gdev
- model_ = Lux.Experimental.StatefulLuxLayer(model, ps, st)
- @info " warming up forward pass"
+ model_ = StatefulLuxLayer(model, ps, st)
+ @printf "[%s] warming up forward pass\n" string(now())
logitcrossentropy(model_, x, ps, y)
- @info " warming up backward pass"
+ @printf "[%s] warming up backward pass\n" string(now())
Zygote.gradient(logitcrossentropy, model_, x, ps, y)
- @info " warmup complete"
+ @printf "[%s] warmup complete\n" string(now())
return model, ps, st
end
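With `construct_model` defined, it can also be exercised on its own; a minimal usage sketch (the solver choice here is just one example, mirroring the calls further below):

```julia
# Sketch: build and warm up a skip-regularized DEQ with a Newton-Krylov solver
model, ps, st = construct_model(NewtonRaphson(; linsolve=KrylovJL_GMRES()); model_type=:regdeq)
```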
@@ -122,7 +109,7 @@ classify(x) = argmax.(eachcol(x))
function accuracy(model, data, ps, st)
total_correct, total = 0, 0
st = Lux.testmode(st)
- model = Lux.Experimental.StatefulLuxLayer(model, ps, st)
+ model = StatefulLuxLayer(model, ps, st)
for (x, y) in data
target_class = classify(cdev(y))
predicted_class = classify(cdev(model(x)))
@@ -132,51 +119,48 @@ function accuracy(model, data, ps, st)
return total_correct / total
end
- function train_model(solver, model_type; data_train=zip(x_train, y_train),
- data_test=zip(x_test, y_test))
+ function train_model(
+ solver, model_type; data_train=zip(x_train, y_train), data_test=zip(x_test, y_test))
model, ps, st = construct_model(solver; model_type)
- model_st = Lux.Experimental.StatefulLuxLayer(model, nothing, st)
+ model_st = StatefulLuxLayer(model, nothing, st)
- @info " Training Model: $(model_type) with Solver: $( nameof(typeof(solver)))"
+ @printf "[%s] Training Model: %s with Solver: %s\n" string(now()) model_type nameof(typeof(solver))
opt_st = Optimisers.setup(Adam(0.001), ps)
acc = accuracy(model, data_test, ps, st) * 100
- @info " Starting Accuracy: $(acc)"
+ @printf "[%s] Starting Accuracy: %.5f%%\n" string(now()) acc
- @info " Pretrain with unrolling to a depth of 5"
+ @printf "[%s] Pretrain with unrolling to a depth of 5\n" string(now())
st = Lux.update_state(st, :fixed_depth, Val(5))
- model_st = Lux.Experimental.StatefulLuxLayer(model, ps, st)
+ model_st = StatefulLuxLayer(model, ps, st)
for (i, (x, y)) in enumerate(data_train)
res = Zygote.withgradient(logitcrossentropy, model_st, x, ps, y)
Optimisers.update!(opt_st, ps, res.grad[3])
- if i % 50 == 1
- @info "Pretraining Batch: [$(i)/$(length(data_train))] Loss: $(res.val)"
- end
+ i % 50 == 1 &&
+ @printf "[%s] Pretraining Batch: [%4d/%4d] Loss: %.5f\n" string(now()) i length(data_train) res.val
end
acc = accuracy(model, data_test, ps, model_st.st) * 100
- @info " Pretraining complete. Accuracy: $(acc)"
+ @printf "[%s] Pretraining complete. Accuracy: %.5f%%\n" string(now()) acc
st = Lux.update_state(st, :fixed_depth, Val(0))
- model_st = Lux.Experimental.StatefulLuxLayer(model, ps, st)
+ model_st = StatefulLuxLayer(model, ps, st)
for epoch in 1:3
for (i, (x, y)) in enumerate(data_train)
res = Zygote.withgradient(logitcrossentropy, model_st, x, ps, y)
Optimisers.update!(opt_st, ps, res.grad[3])
- if i % 50 == 1
- @info "Epoch: [$(epoch)/3] Batch: [$(i)/$(length(data_train))] Loss: $(res.val)"
- end
+ i % 50 == 1 &&
+ @printf "[%s] Epoch: [%d/%d] Batch: [%4d/%4d] Loss: %.5f\n" string(now()) epoch 3 i length(data_train) res.val
end
acc = accuracy(model, data_test, ps, model_st.st) * 100
- @info " Epoch: [$(epoch)/3 ] Accuracy: $(acc)"
+ @printf "[%s] Epoch: [%d/%d ] Accuracy: %.5f%%\n" string(now()) epoch 3 acc
end
- @info "Training complete."
- println()
+ @printf "[%s] Training complete.\n" string(now())
return model, ps, st
end
@@ -188,19 +172,15 @@ and end up using solvers like `Broyden`, but we can simply slap in any of the fa
from NonlinearSolve.jl. Here we will use the Newton-Krylov method:
```@example basic_mnist_deq
- with_logger(filtered_logger) do
- train_model(NewtonRaphson(; linsolve=KrylovJL_GMRES()), :regdeq)
- end
+ train_model(NewtonRaphson(; linsolve=KrylovJL_GMRES()), :regdeq);
nothing # hide
```
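Any other solver from NonlinearSolve.jl can be dropped in the same way. For example, the quasi-Newton `Broyden()` mentioned above (a sketch, not part of the rendered tutorial):

```julia
# Sketch: the same training call with a Jacobian-free quasi-Newton solver
train_model(Broyden(), :regdeq);
```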
We can also train a continuous DEQ by passing in an ODE solver. Here we will use `VCAB3()`,
which tends to be quite fast for continuous Neural Network problems.
```@example basic_mnist_deq
- with_logger(filtered_logger) do
- train_model(VCAB3(), :deq)
- end
+ train_model(VCAB3(), :deq);
nothing # hide
```
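Since `train_model` returns the trained `model`, `ps`, and `st`, the held-out accuracy can be checked afterwards with the `accuracy` helper defined earlier; a hedged sketch:

```julia
# Sketch: capture the trained continuous DEQ and report its test accuracy
model, ps, st = train_model(VCAB3(), :deq)
test_acc = accuracy(model, zip(x_test, y_test), ps, st) * 100
@printf "[%s] Final Test Accuracy: %.5f%%\n" string(now()) test_acc
```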