
Commit dd44918

Avik Pal (avik-pal) authored and committed
start a tutorial
1 parent 76b8f2d commit dd44918

File tree

4 files changed: +45 -12 lines changed


README.md

Lines changed: 1 addition & 1 deletion

````diff
@@ -45,7 +45,7 @@ y = rand(rng, Float32, 2, 3) |> gdev
 
 model(x, ps, st)
 
-gs = only(Zygote.gradient(p -> sum(abs2, first(first(model(x, p, st))) .- y), ps))
+gs = only(Zygote.gradient(p -> sum(abs2, first(model(x, p, st)) .- y), ps))
 ```
 
 ## Citation
````
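The one-line README fix above removes a doubled `first`: a Lux-style model call returns an `(output, state)` tuple, so a single `first` already yields the output array, while `first(first(...))` would reduce it to its first scalar entry. A minimal sketch of the distinction, using a hypothetical stand-in model in plain Base Julia (not the package's API):

```julia
# Hypothetical stand-in for a Lux-style model: returns (output, updated_state).
toy_model(x, ps, st) = (ps .* x, st)

x, ps, st = [1.0, 2.0], [3.0, 4.0], (;)

out = first(toy_model(x, ps, st))            # the full output array, [3.0, 8.0]
scalar = first(first(toy_model(x, ps, st)))  # only the first entry, 3.0 -- the old bug
```

With the extra `first`, the loss `sum(abs2, ... .- y)` would have been computed against a single scalar broadcast over `y` rather than the model's output array.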

docs/Project.toml

Lines changed: 1 addition & 1 deletion

```diff
@@ -13,4 +13,4 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 [compat]
 DeepEquilibriumNetworks = "2"
 Documenter = "1"
-DocumenterCitations = "0.2, 1"
+DocumenterCitations = "1"
```

docs/src/tutorials/basic_mnist_deq.md

Lines changed: 43 additions & 2 deletions

````diff
@@ -1,3 +1,44 @@
-# Training a Simple MNIST Classifier with DEQ
+# Training a Simple MNIST Classifier using Deep Equilibrium Models
 
-This Tutorial is currently under preparation. Check back soon.
+We will train a simple Deep Equilibrium Model on MNIST. First we load a few packages.
+
+```@example basic_mnist_deq
+using DeepEquilibriumNetworks, SciMLSensitivity, Lux, NonlinearSolve, OrdinaryDiffEq,
+    Statistics, Random, Optimization, OptimizationOptimisers
+using LuxCUDA
+using MLDatasets: MNIST
+using MLDataUtils: LabelEnc, convertlabel, stratifiedobs, batchview
+
+CUDA.allowscalar(false)
+ENV["DATADEPS_ALWAYS_ACCEPT"] = true
+```
+
+Set up the device functions from Lux. See
+[GPU Management](https://lux.csail.mit.edu/dev/manual/gpu_management) for more details.
+
+```@example basic_mnist_deq
+const cdev = cpu_device()
+const gdev = gpu_device()
+```
+
+We can now construct our dataloader.
+
+```@example basic_mnist_deq
+function onehot(labels_raw)
+    return convertlabel(LabelEnc.OneOfK, labels_raw, LabelEnc.NativeLabels(collect(0:9)))
+end
+
+function loadmnist(batchsize)
+    # Load MNIST
+    mnist = MNIST(; split=:train)
+    imgs, labels_raw = mnist.features, mnist.targets
+    # Process images into (H, W, C, BS) batches
+    x_train = Float32.(reshape(imgs, size(imgs, 1), size(imgs, 2), 1, size(imgs, 3))) |>
+        gdev
+    x_train = batchview(x_train, batchsize)
+    # One-hot encode and batch the labels
+    y_train = onehot(labels_raw) |> gdev
+    y_train = batchview(y_train, batchsize)
+    return x_train, y_train
+end
+```
````
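The dataloader in the tutorial diff above leans on two MLDataUtils helpers: one-of-K label encoding (`convertlabel` with `LabelEnc.OneOfK`) and last-dimension batching (`batchview`). A plain Base Julia sketch of both ideas, with hypothetical names (`onehot_plain`, `batchview_plain`) that are not the MLDataUtils implementations:

```julia
# One-of-K encoding for digit labels 0:9 -> a 10 x N Float32 matrix.
function onehot_plain(labels::AbstractVector{<:Integer})
    out = zeros(Float32, 10, length(labels))
    for (j, l) in enumerate(labels)
        out[l + 1, j] = 1f0   # labels are 0-based, matrix rows are 1-based
    end
    return out
end

# Split the last dimension of an array into mini-batches of size `batchsize`.
function batchview_plain(x::AbstractArray, batchsize::Int)
    n = size(x, ndims(x))
    return [selectdim(x, ndims(x), i:min(i + batchsize - 1, n)) for i in 1:batchsize:n]
end

y = onehot_plain([0, 3, 9])            # size (10, 3); y[4, 2] == 1f0
imgs = rand(Float32, 28, 28, 1, 10)    # 10 fake MNIST-shaped images
batches = batchview_plain(imgs, 4)     # 3 batches with last-dim sizes 4, 4, 2
```

Batching along the last dimension matches the `(H, W, C, BS)` convention the tutorial uses for image tensors.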

src/DeepEquilibriumNetworks.jl

Lines changed: 0 additions & 8 deletions

```diff
@@ -18,14 +18,6 @@ const DEQs = DeepEquilibriumNetworks
 include("layers.jl")
 include("utils.jl")
 
-## FIXME: Remove once Manifest is removed
-using SciMLBase, SciMLSensitivity
-
-@inline __default_sensealg(::SteadyStateProblem) = SteadyStateAdjoint(;
-    autojacvec=ZygoteVJP(), linsolve_kwargs=(; maxiters=10, abstol=1e-3, reltol=1e-3))
-@inline __default_sensealg(::ODEProblem) = GaussAdjoint(; autojacvec=ZygoteVJP())
-## FIXME: Remove once Manifest is removed
-
 # Exports
 export DEQs, DeepEquilibriumSolution, DeepEquilibriumNetwork, SkipDeepEquilibriumNetwork,
     MultiScaleDeepEquilibriumNetwork, MultiScaleSkipDeepEquilibriumNetwork,
```

0 commit comments
