Skip to content

Commit 1a0de74

Browse files
committed
Format
1 parent cfffce6 commit 1a0de74

24 files changed

+1133
-519
lines changed

README.md

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,14 @@ rng = Random.default_rng()
3434
Random.seed!(rng, seed)
3535

3636
model = Lux.Chain(Lux.Dense(2, 2),
37-
DEQs.DeepEquilibriumNetwork(Lux.Parallel(+, Lux.Dense(2, 2; bias=false),
38-
Lux.Dense(2, 2; bias=false)),
39-
DEQs.ContinuousDEQSolver(; abstol=0.1f0,
40-
reltol=0.1f0,
41-
abstol_termination=0.1f0,
42-
reltol_termination=0.1f0)))
37+
DEQs.DeepEquilibriumNetwork(Lux.Parallel(+,
38+
Lux.Dense(2, 2; bias=false),
39+
Lux.Dense(2, 2; bias=false)),
40+
DEQs.ContinuousDEQSolver(;
41+
abstol=0.1f0,
42+
reltol=0.1f0,
43+
abstol_termination=0.1f0,
44+
reltol_termination=0.1f0)))
4345

4446
ps, st = gpu.(Lux.setup(rng, model))
4547
x = gpu(rand(rng, Float32, 2, 1))

docs/make.jl

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,25 @@ bib = CitationBibliography(joinpath(@__DIR__, "ref.bib"); sorting=:nyt)
77

88
include("pages.jl")
99

10-
makedocs(bib; sitename="Fast Deep Equilibrium Networks", authors="Avik Pal et al.",
11-
clean=true, doctest=false, modules=[DeepEquilibriumNetworks],
12-
strict=[
13-
:doctest,
14-
:linkcheck,
15-
:parse_error,
16-
:example_block,
17-
# Other available options are
18-
# :autodocs_block, :cross_references, :docs_block, :eval_block, :example_block,
19-
# :footnote, :meta_block, :missing_docs, :setup_block
20-
], checkdocs=:all,
21-
format=Documenter.HTML(; assets=["assets/favicon.ico"],
22-
canonical="https://docs.sciml.ai/DeepEquilibriumNetworks/stable/"),
23-
pages=pages)
10+
makedocs(bib;
11+
sitename="Fast Deep Equilibrium Networks",
12+
authors="Avik Pal et al.",
13+
clean=true,
14+
doctest=false,
15+
modules=[DeepEquilibriumNetworks],
16+
strict=[
17+
:doctest,
18+
:linkcheck,
19+
:parse_error,
20+
:example_block,
21+
# Other available options are
22+
# :autodocs_block, :cross_references, :docs_block, :eval_block, :example_block,
23+
# :footnote, :meta_block, :missing_docs, :setup_block
24+
],
25+
checkdocs=:all,
26+
format=Documenter.HTML(;
27+
assets=["assets/favicon.ico"],
28+
canonical="https://docs.sciml.ai/DeepEquilibriumNetworks/stable/"),
29+
pages=pages)
2430

2531
deploydocs(; repo="github.com/SciML/DeepEquilibriumNetworks.jl.git", push_preview=true)

docs/src/index.md

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,14 @@ rng = Random.default_rng()
2727
Random.seed!(rng, seed)
2828

2929
model = Lux.Chain(Lux.Dense(2, 2),
30-
DEQs.DeepEquilibriumNetwork(Lux.Parallel(+, Lux.Dense(2, 2; bias=false),
31-
Lux.Dense(2, 2; bias=false)),
32-
DEQs.ContinuousDEQSolver(; abstol=0.1f0,
33-
reltol=0.1f0,
34-
abstol_termination=0.1f0,
35-
reltol_termination=0.1f0)))
30+
DEQs.DeepEquilibriumNetwork(Lux.Parallel(+,
31+
Lux.Dense(2, 2; bias=false),
32+
Lux.Dense(2, 2; bias=false)),
33+
DEQs.ContinuousDEQSolver(;
34+
abstol=0.1f0,
35+
reltol=0.1f0,
36+
abstol_termination=0.1f0,
37+
reltol_termination=0.1f0)))
3638

3739
ps, st = gpu.(Lux.setup(rng, model))
3840
x = gpu(rand(rng, Float32, 2, 1))

experiments/cifar10/main.jl

Lines changed: 43 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,16 @@
1-
import CUDA, DEQExperiments, FluxMPI, Logging, Lux, OneHotArrays, Optimisers, PyCall,
2-
Random, Setfield, SimpleConfig, Statistics, Wandb
1+
import CUDA,
2+
DEQExperiments,
3+
FluxMPI,
4+
Logging,
5+
Lux,
6+
OneHotArrays,
7+
Optimisers,
8+
PyCall,
9+
Random,
10+
Setfield,
11+
SimpleConfig,
12+
Statistics,
13+
Wandb
314
import Lux.Training
415
import ComponentArrays as CA
516

@@ -10,8 +21,10 @@ function get_dataloaders(; augment, data_root, eval_batchsize, train_batchsize)
1021

1122
tf.config.set_visible_devices([], "GPU")
1223

13-
ds_train, ds_test = tfds.load("cifar10"; split=["train", "test"], as_supervised=true,
14-
data_dir=data_root)
24+
ds_train, ds_test = tfds.load("cifar10";
25+
split=["train", "test"],
26+
as_supervised=true,
27+
data_dir=data_root)
1528

1629
image_mean = tf.constant([[[0.4914f0, 0.4822f0, 0.4465f0]]])
1730
image_std = tf.constant([[[0.2023f0, 0.1994f0, 0.2010f0]]])
@@ -50,12 +63,12 @@ function get_dataloaders(; augment, data_root, eval_batchsize, train_batchsize)
5063
ds_test = ds_test.prefetch(tf.data.AUTOTUNE).repeat(1)
5164

5265
return (tfds.as_numpy(ds_train.batch(train_batchsize)),
53-
tfds.as_numpy(ds_test.batch(eval_batchsize)))
66+
tfds.as_numpy(ds_test.batch(eval_batchsize)))
5467
end
5568

5669
function _data_postprocess(image, label)
5770
return (Lux.gpu(permutedims(image, (3, 2, 4, 1))),
58-
Lux.gpu(OneHotArrays.onehotbatch(label, 0:9)))
71+
Lux.gpu(OneHotArrays.onehotbatch(label, 0:9)))
5972
end
6073

6174
function main(filename, args)
@@ -85,12 +98,18 @@ function main(config_name::String, cfg::DEQExperiments.ExperimentConfig)
8598
end
8699
vjp_rule = Training.ZygoteVJP()
87100

88-
DEQExperiments.warmup_model(loss_function, model, tstate.parameters, tstate.states, cfg;
89-
transform_input=Lux.gpu)
90-
91-
ds_train, ds_test = get_dataloaders(; cfg.dataset.augment, cfg.dataset.data_root,
92-
cfg.dataset.eval_batchsize,
93-
cfg.dataset.train_batchsize)
101+
DEQExperiments.warmup_model(loss_function,
102+
model,
103+
tstate.parameters,
104+
tstate.states,
105+
cfg;
106+
transform_input=Lux.gpu)
107+
108+
ds_train, ds_test = get_dataloaders(;
109+
cfg.dataset.augment,
110+
cfg.dataset.data_root,
111+
cfg.dataset.eval_batchsize,
112+
cfg.dataset.train_batchsize)
94113
_, ds_train_iter = iterate(ds_train)
95114

96115
# Setup
@@ -123,9 +142,11 @@ function main(config_name::String, cfg::DEQExperiments.ExperimentConfig)
123142
end
124143

125144
# Setup Logging
126-
loggers = DEQExperiments.create_logger(log_dir, cfg.train.total_steps - initial_step,
127-
cfg.train.total_steps - initial_step, expt_name,
128-
SimpleConfig.flatten_configuration(cfg))
145+
loggers = DEQExperiments.create_logger(log_dir,
146+
cfg.train.total_steps - initial_step,
147+
cfg.train.total_steps - initial_step,
148+
expt_name,
149+
SimpleConfig.flatten_configuration(cfg))
129150

130151
best_test_accuracy = 0
131152

@@ -145,7 +166,7 @@ function main(config_name::String, cfg::DEQExperiments.ExperimentConfig)
145166
# LR Update
146167
lr_new = sched(step + 1)
147168
Setfield.@set! tstate.optimizer_state = Optimisers.adjust(tstate.optimizer_state,
148-
lr_new)
169+
lr_new)
149170

150171
accuracy = DEQExperiments.accuracy(Lux.cpu(stats.y_pred), Lux.cpu(y))
151172
residual = abs(Statistics.mean(stats.residual))
@@ -154,7 +175,8 @@ function main(config_name::String, cfg::DEQExperiments.ExperimentConfig)
154175
loggers.progress_loggers.train.avg_meters.batch_time(data_time +
155176
step_stats.fwd_time +
156177
step_stats.bwd_time +
157-
step_stats.opt_time, bsize)
178+
step_stats.opt_time,
179+
bsize)
158180
loggers.progress_loggers.train.avg_meters.data_time(data_time, bsize)
159181
loggers.progress_loggers.train.avg_meters.fwd_time(step_stats.fwd_time, bsize)
160182
loggers.progress_loggers.train.avg_meters.bwd_time(step_stats.bwd_time, bsize)
@@ -208,7 +230,7 @@ function main(config_name::String, cfg::DEQExperiments.ExperimentConfig)
208230
loggers.progress_loggers.eval.avg_meters.ce_loss(stats.ce_loss, bsize)
209231
loggers.progress_loggers.eval.avg_meters.skip_loss(stats.skip_loss, bsize)
210232
loggers.progress_loggers.eval.avg_meters.residual(abs(Statistics.mean(stats.residual)),
211-
bsize)
233+
bsize)
212234
loggers.progress_loggers.eval.avg_meters.top1(acc, bsize)
213235
loggers.progress_loggers.eval.avg_meters.top5(-1, bsize)
214236
loggers.progress_loggers.eval.avg_meters.nfe(stats.nfe, bsize)
@@ -233,8 +255,9 @@ function main(config_name::String, cfg::DEQExperiments.ExperimentConfig)
233255
end
234256

235257
ckpt = (tstate=tstate, step=initial_step)
236-
DEQExperiments.save_checkpoint(ckpt; is_best,
237-
filename=joinpath(ckpt_dir, "model_$(step).jlso"))
258+
DEQExperiments.save_checkpoint(ckpt;
259+
is_best,
260+
filename=joinpath(ckpt_dir, "model_$(step).jlso"))
238261
end
239262
end
240263

0 commit comments

Comments
 (0)