Skip to content

Commit ddc5efb

Browse files
authored
Merge pull request #93 from SciML/ap/ssadjoint_better
Better Steady State Adjoint
2 parents 903ee76 + 34932dc commit ddc5efb

40 files changed

+372
-52135
lines changed

.JuliaFormatter.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,6 @@ style = "sciml"
22
whitespace_in_kwargs = false
33
always_use_return = true
44
format_docstrings = true
5-
join_lines_based_on_source = false
65
separate_kwargs_with_semicolon = true
7-
format_markdown = true
6+
format_markdown = true
7+
annotate_untyped_fields_with_any = false

.github/workflows/CI.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ jobs:
2727
- ADJOINT
2828
version:
2929
- '1'
30-
- '1.6'
3130
steps:
3231
- uses: actions/checkout@v4
3332
- uses: julia-actions/setup-julia@v1

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,5 @@ profs
99
logs
1010
benchmarking
1111
*/tensorflow_datasets/
12-
checkpoints
12+
checkpoints
13+
wip

Project.toml

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,40 @@
11
name = "DeepEquilibriumNetworks"
22
uuid = "6748aba7-0e9b-415e-a410-ae3cc0ecb334"
33
authors = ["Avik Pal <avikpal@mit.edu>"]
4-
version = "1.3.0"
4+
version = "1.4.0"
55

66
[deps]
7-
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
87
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
8+
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
99
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
1010
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1111
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
1212
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
13-
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
13+
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
1414
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
1515
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
16+
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1617
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
1718
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
1819
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
19-
SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
20-
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
2120
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2221
SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f"
2322
TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"
2423
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
25-
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
2624

2725
[compat]
28-
CUDA = "3, 4, 5"
2926
ChainRulesCore = "1"
27+
ConcreteStructs = "0.2"
3028
DiffEqBase = "6.119"
3129
LinearSolve = "1, 2"
32-
Lux = "0.4, 0.5"
33-
MLUtils = "0.2, 0.3, 0.4"
30+
Lux = "0.5.7"
31+
NonlinearSolve = "2"
3432
OrdinaryDiffEq = "6"
35-
SciMLBase = "1.19, 2"
36-
SciMLSensitivity = "7"
33+
Reexport = "1"
34+
SciMLBase = "2"
35+
SciMLSensitivity = "7.43"
3736
Setfield = "1"
38-
SimpleNonlinearSolve = "0.1.14"
39-
Static = "0.6, 0.7, 0.8"
4037
SteadyStateDiffEq = "1.16"
4138
TruncatedStacktraces = "1.1"
4239
Zygote = "0.6.34"
43-
ZygoteRules = "0.2"
44-
julia = "1.6"
40+
julia = "1.9"

README.md

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -24,30 +24,31 @@ Pkg.add("DeepEquilibriumNetworks")
2424
## Quickstart
2525

2626
```julia
27-
import DeepEquilibriumNetworks as DEQs
28-
import Lux
29-
import Random
30-
import Zygote
27+
using DeepEquilibriumNetworks, Lux, Random, Zygote
28+
# using LuxCUDA, LuxAMDGPU ## Install and Load for GPU Support
3129

3230
seed = 0
3331
rng = Random.default_rng()
3432
Random.seed!(rng, seed)
3533

36-
model = Lux.Chain(Lux.Dense(2, 2),
37-
DEQs.DeepEquilibriumNetwork(Lux.Parallel(+,
38-
Lux.Dense(2, 2; use_bias=false),
39-
Lux.Dense(2, 2; use_bias=false)),
40-
DEQs.ContinuousDEQSolver(;
41-
abstol=0.1f0,
42-
reltol=0.1f0,
43-
abstol_termination=0.1f0,
44-
reltol_termination=0.1f0)))
45-
46-
ps, st = gpu.(Lux.setup(rng, model))
47-
x = gpu(rand(rng, Float32, 2, 1))
48-
y = gpu(rand(rng, Float32, 2, 1))
49-
50-
gs = Zygote.gradient(p -> sum(abs2, model(x, p, st)[1][1] .- y), ps)[1]
34+
model = Chain(Dense(2 => 2),
35+
DeepEquilibriumNetwork(Parallel(+,
36+
Dense(2 => 2; use_bias=false),
37+
Dense(2 => 2; use_bias=false)),
38+
ContinuousDEQSolver(; abstol=0.1f0, reltol=0.1f0, abstol_termination=0.1f0,
39+
reltol_termination=0.1f0);
40+
save_everystep=true))
41+
42+
gdev = gpu_device()
43+
cdev = cpu_device()
44+
45+
ps, st = Lux.setup(rng, model) |> gdev
46+
x = rand(rng, Float32, 2, 1) |> gdev
47+
y = rand(rng, Float32, 2, 1) |> gdev
48+
49+
model(x, ps, st)
50+
51+
gs = only(Zygote.gradient(p -> sum(abs2, first(first(model(x, p, st))) .- y), ps))
5152
```
5253

5354
## Citation

docs/src/index.md

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
DeepEquilibriumNetworks.jl is a framework built on top of
44
[DifferentialEquations.jl](https://docs.sciml.ai/DiffEqDocs/stable/) and
5-
[Lux.jl](https://docs.sciml.ai/Lux/stable/), enabling the efficient training and inference for
5+
[Lux.jl](https://lux.csail.mit.edu/), enabling the efficient training and inference for
66
Deep Equilibrium Networks (Infinitely Deep Neural Networks).
77

88
## Installation
@@ -17,30 +17,30 @@ Pkg.add("DeepEquilibriumNetworks")
1717
## Quick-start
1818

1919
```julia
20-
import DeepEquilibriumNetworks as DEQs
21-
import Lux
22-
import Random
23-
import Zygote
20+
using DeepEquilibriumNetworks, Lux, Random, Zygote
21+
# using LuxCUDA, LuxAMDGPU ## Install and Load for GPU Support
2422

2523
seed = 0
2624
rng = Random.default_rng()
2725
Random.seed!(rng, seed)
26+
model = Chain(Dense(2 => 2),
27+
DeepEquilibriumNetwork(Parallel(+,
28+
Dense(2 => 2; use_bias=false),
29+
Dense(2 => 2; use_bias=false)),
30+
ContinuousDEQSolver(; abstol=0.1f0, reltol=0.1f0, abstol_termination=0.1f0,
31+
reltol_termination=0.1f0);
32+
save_everystep=true))
2833

29-
model = Lux.Chain(Lux.Dense(2, 2),
30-
DEQs.DeepEquilibriumNetwork(Lux.Parallel(+,
31-
Lux.Dense(2, 2; use_bias=false),
32-
Lux.Dense(2, 2; use_bias=false)),
33-
DEQs.ContinuousDEQSolver(;
34-
abstol=0.1f0,
35-
reltol=0.1f0,
36-
abstol_termination=0.1f0,
37-
reltol_termination=0.1f0)))
38-
39-
ps, st = gpu.(Lux.setup(rng, model))
40-
x = gpu(rand(rng, Float32, 2, 1))
41-
y = gpu(rand(rng, Float32, 2, 1))
42-
43-
gs = Zygote.gradient(p -> sum(abs2, model(x, p, st)[1][1] .- y), ps)[1]
34+
gdev = gpu_device()
35+
cdev = cpu_device()
36+
37+
ps, st = Lux.setup(rng, model) |> gdev
38+
x = rand(rng, Float32, 2, 1) |> gdev
39+
y = rand(rng, Float32, 2, 1) |> gdev
40+
41+
model(x, ps, st)
42+
43+
gs = only(Zygote.gradient(p -> sum(abs2, first(first(model(x, p, st))) .- y), ps))
4444
```
4545

4646
## Citation

docs/src/manual/misc.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,5 @@
44
DeepEquilibriumSolution
55
EquilibriumSolution
66
DeepEquilibriumNetworks.split_and_reshape
7-
DeepEquilibriumNetworks.init_identity_matrix
87
DeepEquilibriumNetworks.estimate_jacobian_trace
98
```

experiments/Project.toml

Lines changed: 0 additions & 44 deletions
This file was deleted.

experiments/cifar10/large.yml

Lines changed: 0 additions & 74 deletions
This file was deleted.

0 commit comments

Comments
 (0)