Skip to content

Commit f2326ec

Browse files
committed
Cleanup codebase
1 parent a039a66 commit f2326ec

17 files changed

+188
-454
lines changed

.JuliaFormatter.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,5 @@ whitespace_in_kwargs = false
33
always_use_return = true
44
format_docstrings = true
55
separate_kwargs_with_semicolon = true
6-
format_markdown = true
6+
format_markdown = true
7+
annotate_untyped_fields_with_any = false

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,5 @@ profs
99
logs
1010
benchmarking
1111
*/tensorflow_datasets/
12-
checkpoints
12+
checkpoints
13+
wip

Project.toml

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,20 @@ authors = ["Avik Pal <avikpal@mit.edu>"]
44
version = "1.3.0"
55

66
[deps]
7-
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
87
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
8+
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
99
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
1010
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1111
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
1212
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
1313
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
14+
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
1415
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
1516
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
17+
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1618
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
1719
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
1820
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
19-
SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
2021
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
2122
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2223
SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f"
@@ -25,17 +26,18 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2526
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
2627

2728
[compat]
28-
CUDA = "3, 4, 5"
2929
ChainRulesCore = "1"
30+
ConcreteStructs = "0.2"
3031
DiffEqBase = "6.119"
3132
LinearSolve = "1, 2"
3233
Lux = "0.4, 0.5"
3334
MLUtils = "0.2, 0.3, 0.4"
35+
NonlinearSolve = "2"
3436
OrdinaryDiffEq = "6"
35-
SciMLBase = "1.19, 2"
37+
Reexport = "1"
38+
SciMLBase = "2"
3639
SciMLSensitivity = "7"
3740
Setfield = "1"
38-
SimpleNonlinearSolve = "0.1.14"
3941
Static = "0.6, 0.7, 0.8"
4042
SteadyStateDiffEq = "1.16"
4143
TruncatedStacktraces = "1.1"

src/DeepEquilibriumNetworks.jl

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,27 @@
11
module DeepEquilibriumNetworks
22

3-
using CUDA, DiffEqBase, LinearAlgebra, LinearSolve, Lux, MLUtils, OrdinaryDiffEq, Random,
4-
SciMLBase, SciMLSensitivity, Setfield, SimpleNonlinearSolve, Static, Statistics,
5-
SteadyStateDiffEq, Zygote, ZygoteRules
3+
import Reexport: @reexport
64

7-
using DiffEqBase: AbstractSteadyStateProblem
8-
using SciMLBase: AbstractNonlinearSolution, AbstractSteadyStateAlgorithm
9-
using SimpleNonlinearSolve: AbstractSimpleNonlinearSolveAlgorithm
10-
using TruncatedStacktraces: @truncate_stacktrace
5+
@reexport using Lux, NonlinearSolve, OrdinaryDiffEq, SciMLSensitivity
6+
7+
using DiffEqBase, LinearAlgebra, LinearSolve, MLUtils, Random, SciMLBase, SciMLSensitivity,
8+
Setfield, Static, Statistics, SteadyStateDiffEq, Zygote, ZygoteRules
9+
10+
import DiffEqBase: AbstractSteadyStateProblem
11+
import SciMLBase: AbstractNonlinearSolution, AbstractSteadyStateAlgorithm
12+
import NonlinearSolve: AbstractNonlinearSolveAlgorithm
13+
import TruncatedStacktraces: @truncate_stacktrace
1114

1215
import ChainRulesCore as CRC
16+
import ConcreteStructs: @concrete
1317

1418
const DEQs = DeepEquilibriumNetworks
1519

20+
## FIXME: Uses of `nothing` were removed in Lux 0.5 with a deprecation. It was not updated
21+
## here
22+
Lux.parameterlength(::Nothing) = 0
23+
Lux.statelength(::Nothing) = 0
24+
1625
include("solve.jl")
1726
include("utils.jl")
1827

src/layers/core.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,8 @@ abstract type AbstractSkipDeepEquilibriumNetwork <:
1818
function Lux.initialstates(rng::AbstractRNG, deq::AbstractSkipDeepEquilibriumNetwork)
1919
rng = Lux.replicate(rng)
2020
randn(rng, 1)
21-
return (; model=Lux.initialstates(rng, deq.model),
22-
shortcut=Lux.initialstates(rng, deq.shortcut), fixed_depth=Val(0), solution=nothing,
23-
rng)
21+
return (; model=Lux.initialstates(rng, deq.model), rng,
22+
shortcut=Lux.initialstates(rng, deq.shortcut), fixed_depth=Val(0), solution=nothing)
2423
end
2524

2625
const AbstractDEQs = Union{AbstractDeepEquilibriumNetwork,
@@ -54,10 +53,10 @@ Stores the solution of a DeepEquilibriumNetwork and its variants.
5453
can be computed).
5554
- `nfe`: Number of Function Evaluations
5655
"""
57-
struct DeepEquilibriumSolution{T, R <: AbstractFloat, TRes}
56+
@concrete struct DeepEquilibriumSolution{T, R <: AbstractFloat}
5857
z_star::T
5958
u0::T
60-
residual::TRes
59+
residual
6160
jacobian_loss::R
62-
nfe::Int
61+
nfe
6362
end

src/layers/deq.jl

Lines changed: 22 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,7 @@ using DeepEquilibriumNetworks, Lux, Random, OrdinaryDiffEq
2222
model = DeepEquilibriumNetwork(Parallel(+,
2323
Dense(2, 2; use_bias=false),
2424
Dense(2, 2; use_bias=false)),
25-
ContinuousDEQSolver(VCABM3(); abstol=0.01f0, reltol=0.01f0);
26-
save_everystep=true)
25+
ContinuousDEQSolver(VCABM3(); abstol=0.01f0, reltol=0.01f0); save_everystep=true)
2726
2827
rng = Random.default_rng()
2928
ps, st = Lux.setup(rng, model)
@@ -34,30 +33,18 @@ model(rand(rng, Float32, 2, 1), ps, st)
3433
See also: [`SkipDeepEquilibriumNetwork`](@ref), [`MultiScaleDeepEquilibriumNetwork`](@ref),
3534
[`MultiScaleSkipDeepEquilibriumNetwork`](@ref).
3635
"""
37-
struct DeepEquilibriumNetwork{J, M, A, S, K} <: AbstractDeepEquilibriumNetwork
38-
model::M
39-
solver::A
40-
sensealg::S
41-
kwargs::K
36+
@concrete struct DeepEquilibriumNetwork{J} <: AbstractDeepEquilibriumNetwork
37+
model
38+
solver
39+
sensealg
40+
kwargs
4241
end
4342

4443
@truncate_stacktrace DeepEquilibriumNetwork 1 2
4544

46-
function DeepEquilibriumNetwork(model,
47-
solver;
48-
jacobian_regularization::Bool=false,
49-
sensealg=SteadyStateAdjoint(),
50-
kwargs...)
51-
return DeepEquilibriumNetwork{
52-
jacobian_regularization,
53-
typeof(model),
54-
typeof(solver),
55-
typeof(sensealg),
56-
typeof(kwargs),
57-
}(model,
58-
solver,
59-
sensealg,
60-
kwargs)
45+
function DeepEquilibriumNetwork(model, solver; jacobian_regularization::Bool=false,
46+
sensealg=SteadyStateAdjoint(), kwargs...)
47+
return DeepEquilibriumNetwork{jacobian_regularization}(model, solver, sensealg, kwargs)
6148
end
6249

6350
_jacobian_regularization(::DeepEquilibriumNetwork{J}) where {J} = J
@@ -91,8 +78,7 @@ model = SkipDeepEquilibriumNetwork(Parallel(+,
9178
Dense(2, 2; use_bias=false),
9279
Dense(2, 2; use_bias=false)),
9380
Dense(2, 2),
94-
ContinuousDEQSolver(VCABM3(); abstol=0.01f0, reltol=0.01f0);
95-
save_everystep=true)
81+
ContinuousDEQSolver(VCABM3(); abstol=0.01f0, reltol=0.01f0); save_everystep=true)
9682
9783
rng = Random.default_rng()
9884
ps, st = Lux.setup(rng, model)
@@ -104,8 +90,7 @@ model = SkipDeepEquilibriumNetwork(Parallel(+,
10490
Dense(2, 2; use_bias=false),
10591
Dense(2, 2; use_bias=false)),
10692
nothing,
107-
ContinuousDEQSolver(VCABM3(); abstol=0.01f0, reltol=0.01f0);
108-
save_everystep=true)
93+
ContinuousDEQSolver(VCABM3(); abstol=0.01f0, reltol=0.01f0); save_everystep=true)
10994
11095
rng = Random.default_rng()
11196
ps, st = Lux.setup(rng, model)
@@ -116,41 +101,25 @@ model(rand(rng, Float32, 2, 1), ps, st)
116101
See also: [`DeepEquilibriumNetwork`](@ref), [`MultiScaleDeepEquilibriumNetwork`](@ref),
117102
[`MultiScaleSkipDeepEquilibriumNetwork`](@ref)
118103
"""
119-
struct SkipDeepEquilibriumNetwork{J, M, Sh, A, S, K} <: AbstractSkipDeepEquilibriumNetwork
120-
model::M
121-
shortcut::Sh
122-
solver::A
123-
sensealg::S
124-
kwargs::K
104+
@concrete struct SkipDeepEquilibriumNetwork{J} <: AbstractSkipDeepEquilibriumNetwork
105+
model
106+
shortcut
107+
solver
108+
sensealg
109+
kwargs
125110
end
126111

127112
@truncate_stacktrace SkipDeepEquilibriumNetwork 1 2 3
128113

129-
function SkipDeepEquilibriumNetwork(model,
130-
shortcut,
131-
solver;
132-
sensealg=SteadyStateAdjoint(),
133-
jacobian_regularization::Bool=false,
134-
kwargs...)
135-
return SkipDeepEquilibriumNetwork{
136-
jacobian_regularization,
137-
typeof(model),
138-
typeof(shortcut),
139-
typeof(solver),
140-
typeof(sensealg),
141-
typeof(kwargs),
142-
}(model,
143-
shortcut,
144-
solver,
145-
sensealg,
146-
kwargs)
114+
function SkipDeepEquilibriumNetwork(model, shortcut, solver; sensealg=SteadyStateAdjoint(),
115+
jacobian_regularization::Bool=false, kwargs...)
116+
return SkipDeepEquilibriumNetwork{jacobian_regularization}(model, shortcut, solver,
117+
sensealg, kwargs)
147118
end
148119

149120
_jacobian_regularization(::SkipDeepEquilibriumNetwork{J}) where {J} = J
150121

151-
function _get_initial_condition(deq::SkipDeepEquilibriumNetwork{J, M, Nothing},
152-
x,
153-
ps,
122+
function _get_initial_condition(deq::SkipDeepEquilibriumNetwork{J, M, Nothing}, x, ps,
154123
st) where {J, M}
155124
z, st_ = deq.model((zero(x), x), ps.model, st.model)
156125
@set! st.model = st_

src/layers/evaluate.jl

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,4 @@
1-
@generated function _evaluate_unrolled_model(::AbstractDEQs,
2-
model,
3-
z_star,
4-
x,
5-
ps,
6-
st,
1+
@generated function _evaluate_unrolled_model(::AbstractDEQs, model, z_star, x, ps, st,
72
::Val{d}) where {d}
83
calls = [:((z_star, st) = model((z_star, x), ps, st)) for _ in 1:d]
94
push!(calls, :(return z_star, st))
@@ -29,12 +24,7 @@ function (deq::AbstractDEQs)(x::AbstractArray{T}, ps, st::NamedTuple, ::Val{true
2924
z, st = _get_initial_condition(deq, x, ps, st)
3025
depth = _get_unrolled_depth(st)
3126

32-
z_star, st_ = _evaluate_unrolled_model(deq,
33-
deq.model,
34-
z,
35-
x,
36-
ps.model,
37-
st.model,
27+
z_star, st_ = _evaluate_unrolled_model(deq, deq.model, z, x, ps.model, st.model,
3828
st.fixed_depth)
3929

4030
@set! st.model = st_
@@ -61,13 +51,8 @@ function (deq::AbstractDEQs)(x::AbstractArray, ps, st::NamedTuple, ::Val{false})
6151

6252
if _jacobian_regularization(deq)
6353
rng = Lux.replicate(st.rng)
64-
jac_loss = estimate_jacobian_trace(Val(:finite_diff),
65-
deq.model,
66-
ps.model,
67-
st.model,
68-
z_star,
69-
x,
70-
rng)
54+
jac_loss = estimate_jacobian_trace(Val(:finite_diff), deq.model, ps.model, st.model,
55+
z_star, x, rng)
7156
else
7257
rng = st.rng
7358
jac_loss = T(0)

src/layers/jacobian_stabilization.jl

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
1-
_gaussian_like(rng::AbstractRNG, x::AbstractArray) = randn(rng, eltype(x), size(x))
2-
_gaussian_like(rng::AbstractRNG, x::CuArray) = CUDA.randn(rng, eltype(x), size(x))
1+
function _gaussian_like(rng::AbstractRNG, x)
2+
y = similar(x)
3+
randn!(rng, y)
4+
return y
5+
end
6+
7+
CRC.@non_differentiable _gaussian_like(::Any...)
38

49
"""
510
estimate_jacobian_trace(::Val{mode}, model::Lux.AbstractExplicitLayer, ps,
@@ -23,25 +28,15 @@ Estimates the trace of the jacobian matrix wrt `z`.
2328
2429
Stochastic Estimate of the trace of the Jacobian.
2530
"""
26-
function estimate_jacobian_trace(::Val{:reverse},
27-
model::Lux.AbstractExplicitLayer,
28-
ps,
29-
st::NamedTuple,
30-
z::AbstractArray,
31-
x::AbstractArray,
32-
rng::AbstractRNG)
31+
function estimate_jacobian_trace(::Val{:reverse}, model::Lux.AbstractExplicitLayer,
32+
ps, st::NamedTuple, z::AbstractArray, x::AbstractArray, rng::AbstractRNG)
3333
_, back = Zygote.pullback(u -> model((u, x), ps, st)[1], z)
3434
vjp_z = back(_gaussian_like(rng, x))[1]
3535
return mean(abs2, vjp_z)
3636
end
3737

38-
function estimate_jacobian_trace(::Val{:finite_diff},
39-
model::Lux.AbstractExplicitLayer,
40-
ps,
41-
st::NamedTuple,
42-
z::AbstractArray,
43-
x::AbstractArray,
44-
rng::AbstractRNG)
38+
function estimate_jacobian_trace(::Val{:finite_diff}, model::Lux.AbstractExplicitLayer,
39+
ps, st::NamedTuple, z::AbstractArray, x::AbstractArray, rng::AbstractRNG)
4540
f = u -> model((u, x), ps, st)[1]
4641
res = convert(eltype(z), 0)
4742
epsilon = cbrt(eps(typeof(res)))

0 commit comments

Comments
 (0)