Commit 908c224

Update the documentation

1 parent bb61c5f, commit 908c224

5 files changed (+44, -82 lines)

Manifest.toml

Lines changed: 3 additions & 3 deletions
```diff
@@ -497,7 +497,7 @@ uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
 
 [[deps.Lux]]
 deps = ["ADTypes", "Adapt", "ArrayInterface", "ChainRulesCore", "ConcreteStructs", "ConstructionBase", "FastClosures", "Functors", "GPUArraysCore", "LinearAlgebra", "LuxCore", "LuxDeviceUtils", "LuxLib", "MacroTools", "Markdown", "PrecompileTools", "Preferences", "Random", "Reexport", "Setfield", "Statistics", "WeightInitializers"]
-git-tree-sha1 = "295c76513705518749fd4e151d9de77c75049d43"
+git-tree-sha1 = "ae13ecbe29ee7432dfd477b233db43c462b6a4ff"
 repo-rev = "ap/nested_ad"
 repo-url = "https://github.com/LuxDL/Lux.jl.git"
 uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
@@ -574,9 +574,9 @@ version = "0.1.20"
 
 [[deps.LuxLib]]
 deps = ["ArrayInterface", "ChainRulesCore", "FastBroadcast", "FastClosures", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "LuxCore", "Markdown", "NNlib", "PrecompileTools", "Random", "Reexport", "Statistics", "Strided"]
-git-tree-sha1 = "7cb3cdf01835d508f2c81e09d2e93f309434b5d6"
+git-tree-sha1 = "edbf65f5ceb15ebbfad9d03c6a846d83b9a97baf"
 uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
-version = "0.3.15"
+version = "0.3.16"
 
 [deps.LuxLib.extensions]
 LuxLibAMDGPUExt = "AMDGPU"
```

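The Lux entry above is pinned to a development branch rather than a registered release. As a hedged illustration (not a command recorded in this commit), a pin like that can be created from the Julia REPL, which is what writes the `repo-url`, `repo-rev`, and `git-tree-sha1` fields into Manifest.toml:

```julia
using Pkg

# Illustrative only: track Lux.jl from the `ap/nested_ad` branch.
Pkg.add(url="https://github.com/LuxDL/Lux.jl.git", rev="ap/nested_ad")
```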
docs/Project.toml

Lines changed: 2 additions & 2 deletions
```diff
@@ -1,16 +1,17 @@
 [deps]
+Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
 DeepEquilibriumNetworks = "6748aba7-0e9b-415e-a410-ae3cc0ecb334"
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
 DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"
 LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
-LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36"
 Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
 LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
 MLDataUtils = "cc2ba9b6-d476-5e6d-8eaf-a92d5412d41d"
 MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
 NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
+Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
@@ -21,7 +22,6 @@ DeepEquilibriumNetworks = "2"
 Documenter = "1"
 DocumenterCitations = "1"
 LinearSolve = "2"
-LoggingExtras = "1"
 Lux = "0.5"
 LuxCUDA = "0.3"
 MLDataUtils = "0.5"
```

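The dependency swap above mirrors the tutorial edits below: LoggingExtras (used only to filter a SciMLBase deprecation warning) is dropped, and the standard libraries Dates and Printf are added for timestamped `@printf` logging. A hedged sketch of the equivalent REPL commands, assuming the docs environment lives in `docs/` and the commands are run from the repository root:

```julia
using Pkg

Pkg.activate("docs")          # assumption: docs environment in docs/
Pkg.rm("LoggingExtras")
Pkg.add(["Dates", "Printf"])  # standard libraries, so no [compat] pins are needed
```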
docs/src/tutorials/basic_mnist_deq.md

Lines changed: 19 additions & 40 deletions
````diff
@@ -4,7 +4,7 @@ We will train a simple Deep Equilibrium Model on MNIST. First we load a few pack
 
 ```@example basic_mnist_deq
 using DeepEquilibriumNetworks, SciMLSensitivity, Lux, NonlinearSolve, OrdinaryDiffEq,
-    Statistics, Random, Optimisers, LuxCUDA, Zygote, LinearSolve, LoggingExtras
+    Statistics, Random, Optimisers, LuxCUDA, Zygote, LinearSolve, Dates, Printf
 using MLDatasets: MNIST
 using MLDataUtils: LabelEnc, convertlabel, stratifiedobs, batchview
 
@@ -20,18 +20,6 @@ const cdev = cpu_device()
 const gdev = gpu_device()
 ```
 
-SciMLBase introduced a warning instead of depwarn which pollutes the output. We can suppress
-it with the following logger
-
-```@example basic_mnist_deq
-function remove_syms_warning(log_args)
-    return log_args.message !=
-           "The use of keyword arguments `syms`, `paramsyms` and `indepsym` for `SciMLFunction`s is deprecated. Pass `sys = SymbolCache(syms, paramsyms, indepsym)` instead."
-end
-
-filtered_logger = ActiveFilteredLogger(remove_syms_warning, global_logger())
-```
-
 We can now construct our dataloader.
 
 ```@example basic_mnist_deq
@@ -94,12 +82,12 @@ function construct_model(solver; model_type::Symbol=:deq)
     x = randn(rng, Float32, 28, 28, 1, 128)
     y = onehot(rand(Random.default_rng(), 0:9, 128)) |> gdev
 
-    model_ = Lux.Experimental.StatefulLuxLayer(model, ps, st)
-    @info "warming up forward pass"
+    model_ = StatefulLuxLayer(model, ps, st)
+    @printf "[%s] warming up forward pass\n" string(now())
     logitcrossentropy(model_, x, ps, y)
-    @info "warming up backward pass"
+    @printf "[%s] warming up backward pass\n" string(now())
     Zygote.gradient(logitcrossentropy, model_, x, ps, y)
-    @info "warmup complete"
+    @printf "[%s] warmup complete\n" string(now())
 
     return model, ps, st
 end
@@ -121,7 +109,7 @@ classify(x) = argmax.(eachcol(x))
 function accuracy(model, data, ps, st)
     total_correct, total = 0, 0
     st = Lux.testmode(st)
-    model = Lux.Experimental.StatefulLuxLayer(model, ps, st)
+    model = StatefulLuxLayer(model, ps, st)
     for (x, y) in data
         target_class = classify(cdev(y))
         predicted_class = classify(cdev(model(x)))
@@ -134,48 +122,43 @@ end
 function train_model(
         solver, model_type; data_train=zip(x_train, y_train), data_test=zip(x_test, y_test))
     model, ps, st = construct_model(solver; model_type)
-    model_st = Lux.Experimental.StatefulLuxLayer(model, nothing, st)
+    model_st = StatefulLuxLayer(model, nothing, st)
 
-    @info "Training Model: $(model_type) with Solver: $(nameof(typeof(solver)))"
+    @printf "[%s] Training Model: %s with Solver: %s\n" string(now()) model_type nameof(typeof(solver))
 
     opt_st = Optimisers.setup(Adam(0.001), ps)
 
     acc = accuracy(model, data_test, ps, st) * 100
-    @info "Starting Accuracy: $(acc)"
+    @printf "[%s] Starting Accuracy: %.5f%%\n" string(now()) acc
 
-    @info "Pretrain with unrolling to a depth of 5"
+    @printf "[%s] Pretrain with unrolling to a depth of 5\n" string(now())
     st = Lux.update_state(st, :fixed_depth, Val(5))
-    model_st = Lux.Experimental.StatefulLuxLayer(model, ps, st)
+    model_st = StatefulLuxLayer(model, ps, st)
 
     for (i, (x, y)) in enumerate(data_train)
         res = Zygote.withgradient(logitcrossentropy, model_st, x, ps, y)
         Optimisers.update!(opt_st, ps, res.grad[3])
-        if i % 50 == 1
-            @info "Pretraining Batch: [$(i)/$(length(data_train))] Loss: $(res.val)"
-        end
+        i % 50 == 1 && @printf "[%s] Pretraining Batch: [%4d/%4d] Loss: %.5f\n" string(now()) i length(data_train) res.val
     end
 
     acc = accuracy(model, data_test, ps, model_st.st) * 100
-    @info "Pretraining complete. Accuracy: $(acc)"
+    @printf "[%s] Pretraining complete. Accuracy: %.5f%%\n" string(now()) acc
 
     st = Lux.update_state(st, :fixed_depth, Val(0))
-    model_st = Lux.Experimental.StatefulLuxLayer(model, ps, st)
+    model_st = StatefulLuxLayer(model, ps, st)
 
     for epoch in 1:3
         for (i, (x, y)) in enumerate(data_train)
             res = Zygote.withgradient(logitcrossentropy, model_st, x, ps, y)
             Optimisers.update!(opt_st, ps, res.grad[3])
-            if i % 50 == 1
-                @info "Epoch: [$(epoch)/3] Batch: [$(i)/$(length(data_train))] Loss: $(res.val)"
-            end
+            i % 50 == 1 && @printf "[%s] Epoch: [%d/%d] Batch: [%4d/%4d] Loss: %.5f\n" string(now()) epoch 3 i length(data_train) res.val
         end
 
         acc = accuracy(model, data_test, ps, model_st.st) * 100
-        @info "Epoch: [$(epoch)/3] Accuracy: $(acc)"
+        @printf "[%s] Epoch: [%d/%d] Accuracy: %.5f%%\n" string(now()) epoch 3 acc
     end
 
-    @info "Training complete."
-    println()
+    @printf "[%s] Training complete.\n" string(now())
 
     return model, ps, st
 end
@@ -187,19 +170,15 @@ and end up using solvers like `Broyden`, but we can simply slap in any of the fa
 from NonlinearSolve.jl. Here we will use Newton-Krylov Method:
 
 ```@example basic_mnist_deq
-with_logger(filtered_logger) do
-    train_model(NewtonRaphson(; linsolve=KrylovJL_GMRES()), :regdeq)
-end
+train_model(NewtonRaphson(; linsolve=KrylovJL_GMRES()), :regdeq);
 nothing # hide
 ```
 
 We can also train a continuous DEQ by passing in an ODE solver. Here we will use `VCAB3()`
 which tend to be quite fast for continuous Neural Network problems.
 
 ```@example basic_mnist_deq
-with_logger(filtered_logger) do
-    train_model(VCAB3(), :deq)
-end
+train_model(VCAB3(), :deq);
 nothing # hide
 ```
 
````

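The tutorial diff above amounts to two changes: `StatefulLuxLayer` is now taken from Lux's top-level namespace instead of `Lux.Experimental`, and the `@info` progress messages are replaced by timestamped `@printf` output, which also makes the LoggingExtras warning filter unnecessary. A minimal sketch of the new logging pattern with a made-up value, assuming only Dates and Printf:

```julia
using Dates, Printf

acc = 97.53421  # hypothetical accuracy, just for illustration
@printf "[%s] Starting Accuracy: %.5f%%\n" string(now()) acc
# prints something like: [2024-04-07T12:34:56.789] Starting Accuracy: 97.53421%
```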
docs/src/tutorials/reduced_dim_deq.md

Lines changed: 19 additions & 35 deletions
````diff
@@ -6,7 +6,7 @@ same MNIST example as before, but this time we will use a reduced state size.
 
 ```@example reduced_dim_mnist
 using DeepEquilibriumNetworks, SciMLSensitivity, Lux, NonlinearSolve, OrdinaryDiffEq,
-    Statistics, Random, Optimisers, LuxCUDA, Zygote, LinearSolve, LoggingExtras
+    Statistics, Random, Optimisers, LuxCUDA, Zygote, LinearSolve, Dates, Printf
 using MLDatasets: MNIST
 using MLDataUtils: LabelEnc, convertlabel, stratifiedobs, batchview
 
@@ -16,13 +16,6 @@ ENV["DATADEPS_ALWAYS_ACCEPT"] = true
 const cdev = cpu_device()
 const gdev = gpu_device()
 
-function remove_syms_warning(log_args)
-    return log_args.message !=
-           "The use of keyword arguments `syms`, `paramsyms` and `indepsym` for `SciMLFunction`s is deprecated. Pass `sys = SymbolCache(syms, paramsyms, indepsym)` instead."
-end
-
-filtered_logger = ActiveFilteredLogger(remove_syms_warning, global_logger())
-
 function onehot(labels_raw)
     return convertlabel(LabelEnc.OneOfK, labels_raw, LabelEnc.NativeLabels(collect(0:9)))
 end
@@ -83,12 +76,12 @@ function construct_model(solver; model_type::Symbol=:regdeq)
     x = randn(rng, Float32, 28, 28, 1, 128)
     y = onehot(rand(Random.default_rng(), 0:9, 128)) |> gdev
 
-    model_ = Lux.Experimental.StatefulLuxLayer(model, ps, st)
-    @info "warming up forward pass"
+    model_ = StatefulLuxLayer(model, ps, st)
+    @printf "[%s] warming up forward pass\n" string(now())
     logitcrossentropy(model_, x, ps, y)
-    @info "warming up backward pass"
+    @printf "[%s] warming up backward pass\n" string(now())
     Zygote.gradient(logitcrossentropy, model_, x, ps, y)
-    @info "warmup complete"
+    @printf "[%s] warmup complete\n" string(now())
 
     return model, ps, st
 end
@@ -110,7 +103,7 @@ classify(x) = argmax.(eachcol(x))
 function accuracy(model, data, ps, st)
     total_correct, total = 0, 0
     st = Lux.testmode(st)
-    model = Lux.Experimental.StatefulLuxLayer(model, ps, st)
+    model = StatefulLuxLayer(model, ps, st)
     for (x, y) in data
         target_class = classify(cdev(y))
         predicted_class = classify(cdev(model(x)))
@@ -123,48 +116,43 @@ end
 function train_model(
         solver, model_type; data_train=zip(x_train, y_train), data_test=zip(x_test, y_test))
     model, ps, st = construct_model(solver; model_type)
-    model_st = Lux.Experimental.StatefulLuxLayer(model, nothing, st)
+    model_st = StatefulLuxLayer(model, nothing, st)
 
-    @info "Training Model: $(model_type) with Solver: $(nameof(typeof(solver)))"
+    @printf "[%s] Training Model: %s with Solver: %s\n" string(now()) model_type nameof(typeof(solver))
 
     opt_st = Optimisers.setup(Adam(0.001), ps)
 
     acc = accuracy(model, data_test, ps, st) * 100
-    @info "Starting Accuracy: $(acc)"
+    @printf "[%s] Starting Accuracy: %.5f%%\n" string(now()) acc
 
-    @info "Pretrain with unrolling to a depth of 5"
+    @printf "[%s] Pretrain with unrolling to a depth of 5\n" string(now())
     st = Lux.update_state(st, :fixed_depth, Val(5))
-    model_st = Lux.Experimental.StatefulLuxLayer(model, ps, st)
+    model_st = StatefulLuxLayer(model, ps, st)
 
     for (i, (x, y)) in enumerate(data_train)
         res = Zygote.withgradient(logitcrossentropy, model_st, x, ps, y)
        Optimisers.update!(opt_st, ps, res.grad[3])
-        if i % 50 == 1
-            @info "Pretraining Batch: [$(i)/$(length(data_train))] Loss: $(res.val)"
-        end
+        i % 50 == 1 && @printf "[%s] Pretraining Batch: [%4d/%4d] Loss: %.5f\n" string(now()) i length(data_train) res.val
     end
 
     acc = accuracy(model, data_test, ps, model_st.st) * 100
-    @info "Pretraining complete. Accuracy: $(acc)"
+    @printf "[%s] Pretraining complete. Accuracy: %.5f%%\n" string(now()) acc
 
     st = Lux.update_state(st, :fixed_depth, Val(0))
-    model_st = Lux.Experimental.StatefulLuxLayer(model, ps, st)
+    model_st = StatefulLuxLayer(model, ps, st)
 
     for epoch in 1:3
         for (i, (x, y)) in enumerate(data_train)
             res = Zygote.withgradient(logitcrossentropy, model_st, x, ps, y)
             Optimisers.update!(opt_st, ps, res.grad[3])
-            if i % 50 == 1
-                @info "Epoch: [$(epoch)/3] Batch: [$(i)/$(length(data_train))] Loss: $(res.val)"
-            end
+            i % 50 == 1 && @printf "[%s] Epoch: [%d/%d] Batch: [%4d/%4d] Loss: %.5f\n" string(now()) epoch 3 i length(data_train) res.val
         end
 
         acc = accuracy(model, data_test, ps, model_st.st) * 100
-        @info "Epoch: [$(epoch)/3] Accuracy: $(acc)"
+        @printf "[%s] Epoch: [%d/%d] Accuracy: %.5f%%\n" string(now()) epoch 3 acc
     end
 
-    @info "Training complete."
-    println()
+    @printf "[%s] Training complete.\n" string(now())
 
     return model, ps, st
 end
@@ -174,15 +162,11 @@ Now we can train our model. We can't use `:regdeq` here currently, but we will s
 in the future.
 
 ```@example reduced_dim_mnist
-with_logger(filtered_logger) do
-    train_model(NewtonRaphson(; linsolve=KrylovJL_GMRES()), :skipdeq)
-end
+train_model(NewtonRaphson(; linsolve=KrylovJL_GMRES()), :skipdeq)
 nothing # hide
 ```
 
 ```@example reduced_dim_mnist
-with_logger(filtered_logger) do
-    train_model(NewtonRaphson(; linsolve=KrylovJL_GMRES()), :deq)
-end
+train_model(NewtonRaphson(; linsolve=KrylovJL_GMRES()), :deq)
 nothing # hide
 ```
````

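This tutorial gets the same treatment as the basic MNIST one. The recurring pattern in both is the stateful wrapper, now constructed as `StatefulLuxLayer(model, ps, st)`. A small self-contained sketch of how such a wrapper is used; the layer and input sizes below are invented for illustration and are not taken from the tutorials:

```julia
using Lux, Random

model = Dense(2 => 2, tanh)
ps, st = Lux.setup(Random.default_rng(), model)

# The wrapper stores ps and st, so it can be called with only the input.
smodel = StatefulLuxLayer(model, ps, st)
y = smodel(randn(Float32, 2, 3))
```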
src/layers.jl

Lines changed: 1 addition & 2 deletions
```diff
@@ -316,8 +316,7 @@ julia> model(x, ps, st);
 """
 function MultiScaleDeepEquilibriumNetwork(
         main_layers::Tuple, mapping_layers::Matrix, post_fuse_layer::Union{Nothing, Tuple},
-        solver, scales; jacobian_regularization=nothing, kwargs...)
-    @assert jacobian_regularization===nothing "Jacobian Regularization is not supported yet for MultiScale Models."
+        solver, scales; kwargs...)
     l1 = Parallel(nothing, main_layers...)
     l2 = BranchLayer(Parallel.(+, map(x -> tuple(x...), eachrow(mapping_layers))...)...)
 
```

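The change above drops the unsupported `jacobian_regularization` keyword (and its assertion) from the multiscale constructor. A hedged sketch of a call matching the remaining signature; the layers, scales, and solver below are invented for illustration and are not part of this commit:

```julia
using DeepEquilibriumNetworks, Lux, NonlinearSolve

main_layers = (Dense(4 => 4, tanh), Dense(2 => 2, tanh))
mapping_layers = [NoOpLayer()    Dense(4 => 2);
                  Dense(2 => 4)  NoOpLayer()]

# No `jacobian_regularization` keyword anymore; only the positional
# arguments (main_layers, mapping_layers, post_fuse_layer, solver, scales) plus kwargs.
model = MultiScaleDeepEquilibriumNetwork(
    main_layers, mapping_layers, nothing, NewtonRaphson(), ((4,), (2,)))
```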