@@ -4,7 +4,7 @@ We will train a simple Deep Equilibrium Model on MNIST. First we load a few pack
``` @example basic_mnist_deq
using DeepEquilibriumNetworks, SciMLSensitivity, Lux, NonlinearSolve, OrdinaryDiffEq,
- Statistics, Random, Optimisers, LuxCUDA, Zygote, LinearSolve, LoggingExtras
+ Statistics, Random, Optimisers, LuxCUDA, Zygote, LinearSolve, Dates, Printf
using MLDatasets: MNIST
using MLDataUtils: LabelEnc, convertlabel, stratifiedobs, batchview
@@ -20,18 +20,6 @@ const cdev = cpu_device()
const gdev = gpu_device()
```

- SciMLBase introduced a warning instead of depwarn which pollutes the output. We can suppress
- it with the following logger
-
- ``` @example basic_mnist_deq
- function remove_syms_warning(log_args)
- return log_args.message !=
- "The use of keyword arguments `syms`, `paramsyms` and `indepsym` for `SciMLFunction`s is deprecated. Pass `sys = SymbolCache(syms, paramsyms, indepsym)` instead."
- end
-
- filtered_logger = ActiveFilteredLogger(remove_syms_warning, global_logger())
- ```
-
We can now construct our dataloader.
``` @example basic_mnist_deq
@@ -94,12 +82,12 @@ function construct_model(solver; model_type::Symbol=:deq)
x = randn(rng, Float32, 28, 28, 1, 128)
y = onehot(rand(Random.default_rng(), 0:9, 128)) |> gdev

- model_ = Lux.Experimental.StatefulLuxLayer(model, ps, st)
- @info " warming up forward pass"
+ model_ = StatefulLuxLayer(model, ps, st)
+ @printf "[%s] warming up forward pass\n" string(now())
logitcrossentropy(model_, x, ps, y)
- @info " warming up backward pass"
+ @printf "[%s] warming up backward pass\n" string(now())
Zygote.gradient(logitcrossentropy, model_, x, ps, y)
- @info " warmup complete"
+ @printf "[%s] warmup complete\n" string(now())

return model, ps, st
end
@@ -121,7 +109,7 @@ classify(x) = argmax.(eachcol(x))
function accuracy(model, data, ps, st)
total_correct, total = 0, 0
st = Lux.testmode(st)
- model = Lux.Experimental.StatefulLuxLayer(model, ps, st)
+ model = StatefulLuxLayer(model, ps, st)
for (x, y) in data
target_class = classify(cdev(y))
predicted_class = classify(cdev(model(x)))
@@ -134,48 +122,43 @@ end
function train_model(
solver, model_type; data_train=zip(x_train, y_train), data_test=zip(x_test, y_test))
model, ps, st = construct_model(solver; model_type)
- model_st = Lux.Experimental.StatefulLuxLayer(model, nothing, st)
+ model_st = StatefulLuxLayer(model, nothing, st)

- @info " Training Model: $(model_type) with Solver: $(nameof(typeof(solver)))"
+ @printf "[%s] Training Model: %s with Solver: %s\n" string(now()) model_type nameof(typeof(solver))

opt_st = Optimisers.setup(Adam(0.001), ps)

acc = accuracy(model, data_test, ps, st) * 100
- @info " Starting Accuracy: $(acc)"
+ @printf "[%s] Starting Accuracy: %.5f%%\n" string(now()) acc

- @info " Pretrain with unrolling to a depth of 5"
+ @printf "[%s] Pretrain with unrolling to a depth of 5\n" string(now())
st = Lux.update_state(st, :fixed_depth, Val(5))
- model_st = Lux.Experimental.StatefulLuxLayer(model, ps, st)
+ model_st = StatefulLuxLayer(model, ps, st)
for (i, (x, y)) in enumerate(data_train)
res = Zygote.withgradient(logitcrossentropy, model_st, x, ps, y)
Optimisers.update!(opt_st, ps, res.grad[3])
- if i % 50 == 1
- @info "Pretraining Batch: [$(i)/$(length(data_train))] Loss: $(res.val)"
- end
+ i % 50 == 1 && @printf "[%s] Pretraining Batch: [%4d/%4d] Loss: %.5f\n" string(now()) i length(data_train) res.val
end

acc = accuracy(model, data_test, ps, model_st.st) * 100
- @info " Pretraining complete. Accuracy: $(acc)"
+ @printf "[%s] Pretraining complete. Accuracy: %.5f%%\n" string(now()) acc

st = Lux.update_state(st, :fixed_depth, Val(0))
- model_st = Lux.Experimental.StatefulLuxLayer(model, ps, st)
+ model_st = StatefulLuxLayer(model, ps, st)
for epoch in 1:3
for (i, (x, y)) in enumerate(data_train)
res = Zygote.withgradient(logitcrossentropy, model_st, x, ps, y)
Optimisers.update!(opt_st, ps, res.grad[3])
- if i % 50 == 1
- @info "Epoch: [$(epoch)/3] Batch: [$(i)/$(length(data_train))] Loss: $(res.val)"
- end
+ i % 50 == 1 && @printf "[%s] Epoch: [%d/%d] Batch: [%4d/%4d] Loss: %.5f\n" string(now()) epoch 3 i length(data_train) res.val
end

acc = accuracy(model, data_test, ps, model_st.st) * 100
- @info " Epoch: [$(epoch)/3 ] Accuracy: $(acc)"
+ @printf "[%s] Epoch: [%d/%d ] Accuracy: %.5f%%\n" string(now()) epoch 3 acc
end

- @info "Training complete."
- println()
+ @printf "[%s] Training complete.\n" string(now())

return model, ps, st
end
@@ -187,19 +170,15 @@ and end up using solvers like `Broyden`, but we can simply slap in any of the fa
from NonlinearSolve.jl. Here we will use the Newton-Krylov method:

``` @example basic_mnist_deq
- with_logger(filtered_logger) do
- train_model(NewtonRaphson(; linsolve=KrylovJL_GMRES()), :regdeq)
- end
+ train_model(NewtonRaphson(; linsolve=KrylovJL_GMRES()), :regdeq);
nothing # hide
```

We can also train a continuous DEQ by passing in an ODE solver. Here we will use `VCAB3()`
which tends to be quite fast for continuous Neural Network problems.
``` @example basic_mnist_deq
- with_logger(filtered_logger) do
- train_model(VCAB3(), :deq)
- end
+ train_model(VCAB3(), :deq);
nothing # hide
```
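
If you want to inspect the trained network afterwards, capture the `(model, ps, st)` tuple returned by `train_model` instead of discarding it and pass it back to the `accuracy` helper defined above. A minimal sketch, not part of the rendered example, assuming the test batches `x_test`/`y_test` from the tutorial are still in scope:

```julia
# Hypothetical usage sketch: re-evaluate the returned model on the test set
# using the tutorial's own `accuracy` helper and timestamped logging style.
model, ps, st = train_model(VCAB3(), :deq)
test_acc = accuracy(model, zip(x_test, y_test), ps, st) * 100
@printf "[%s] Final Test Accuracy: %.5f%%\n" string(now()) test_acc
```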