Skip to content

Commit a039a66

Browse files
committed
Minor Cleanups
1 parent 903ee76 commit a039a66

File tree

5 files changed

+29
-76
lines changed

5 files changed

+29
-76
lines changed

.JuliaFormatter.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,5 @@ style = "sciml"
22
whitespace_in_kwargs = false
33
always_use_return = true
44
format_docstrings = true
5-
join_lines_based_on_source = false
65
separate_kwargs_with_semicolon = true
76
format_markdown = true

src/DeepEquilibriumNetworks.jl

Lines changed: 5 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,8 @@
11
module DeepEquilibriumNetworks
22

3-
using CUDA,
4-
DiffEqBase,
5-
LinearAlgebra,
6-
LinearSolve,
7-
Lux,
8-
MLUtils,
9-
OrdinaryDiffEq,
10-
Random,
11-
SciMLBase,
12-
SciMLSensitivity,
13-
Setfield,
14-
SimpleNonlinearSolve,
15-
Static,
16-
Statistics,
17-
SteadyStateDiffEq,
18-
Zygote,
19-
ZygoteRules
3+
using CUDA, DiffEqBase, LinearAlgebra, LinearSolve, Lux, MLUtils, OrdinaryDiffEq, Random,
4+
SciMLBase, SciMLSensitivity, Setfield, SimpleNonlinearSolve, Static, Statistics,
5+
SteadyStateDiffEq, Zygote, ZygoteRules
206

217
using DiffEqBase: AbstractSteadyStateProblem
228
using SciMLBase: AbstractNonlinearSolution, AbstractSteadyStateAlgorithm
@@ -48,11 +34,8 @@ export ContinuousDEQSolver, DiscreteDEQSolver
4834
export EquilibriumSolution, DeepEquilibriumSolution, estimate_jacobian_trace
4935

5036
# Networks
51-
export DeepEquilibriumNetwork,
52-
SkipDeepEquilibriumNetwork,
53-
MultiScaleInputLayer,
54-
MultiScaleNeuralODE,
55-
MultiScaleDeepEquilibriumNetwork,
37+
export DeepEquilibriumNetwork, SkipDeepEquilibriumNetwork, MultiScaleInputLayer,
38+
MultiScaleNeuralODE, MultiScaleDeepEquilibriumNetwork,
5639
MultiScaleSkipDeepEquilibriumNetwork
5740

5841
end

src/chainrules.jl

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,7 @@
1-
function CRC.rrule(::Type{<:DeepEquilibriumSolution},
2-
z_star::T,
3-
u0::T,
4-
residual::T,
5-
jacobian_loss::R,
6-
nfe::Int) where {T, R <: AbstractFloat}
1+
function CRC.rrule(::Type{<:DeepEquilibriumSolution}, z_star::T, u0::T, residual::T,
2+
jacobian_loss::R, nfe::Int) where {T, R <: AbstractFloat}
73
function deep_equilibrium_solution_pullback(dsol)
8-
return (CRC.NoTangent(),
9-
dsol.z_star,
10-
dsol.u0,
11-
dsol.residual,
12-
dsol.jacobian_loss,
4+
return (CRC.NoTangent(), dsol.z_star, dsol.u0, dsol.residual, dsol.jacobian_loss,
135
dsol.nfe)
146
end
157
return (DeepEquilibriumSolution(z_star, u0, residual, jacobian_loss, nfe),
@@ -32,9 +24,7 @@ function CRC.rrule(::Type{T}, args...) where {T <: NamedTuple}
3224
return y, nt_pullback
3325
end
3426

35-
function CRC.rrule(::typeof(Setfield.set),
36-
obj,
37-
l::Setfield.PropertyLens{field},
27+
function CRC.rrule(::typeof(Setfield.set), obj, l::Setfield.PropertyLens{field},
3828
val) where {field}
3929
res = Setfield.set(obj, l, val)
4030
function setfield_pullback(Δres)

src/layers/core.jl

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,8 @@ abstract type AbstractDeepEquilibriumNetwork <:
44
function Lux.initialstates(rng::AbstractRNG, deq::AbstractDeepEquilibriumNetwork)
55
rng = Lux.replicate(rng)
66
randn(rng, 1)
7-
return (;
8-
model=Lux.initialstates(rng, deq.model),
9-
fixed_depth=Val(0),
10-
solution=nothing,
11-
rng)
7+
return (; model=Lux.initialstates(rng, deq.model), fixed_depth=Val(0),
8+
solution=nothing, rng)
129
end
1310

1411
function Lux.initialparameters(rng::AbstractRNG, deq::AbstractDeepEquilibriumNetwork)
@@ -21,18 +18,13 @@ abstract type AbstractSkipDeepEquilibriumNetwork <:
2118
function Lux.initialstates(rng::AbstractRNG, deq::AbstractSkipDeepEquilibriumNetwork)
2219
rng = Lux.replicate(rng)
2320
randn(rng, 1)
24-
return (;
25-
model=Lux.initialstates(rng, deq.model),
26-
shortcut=Lux.initialstates(rng, deq.shortcut),
27-
fixed_depth=Val(0),
28-
solution=nothing,
21+
return (; model=Lux.initialstates(rng, deq.model),
22+
shortcut=Lux.initialstates(rng, deq.shortcut), fixed_depth=Val(0), solution=nothing,
2923
rng)
3024
end
3125

32-
const AbstractDEQs = Union{
33-
AbstractDeepEquilibriumNetwork,
34-
AbstractSkipDeepEquilibriumNetwork,
35-
}
26+
const AbstractDEQs = Union{AbstractDeepEquilibriumNetwork,
27+
AbstractSkipDeepEquilibriumNetwork}
3628

3729
function (deq::AbstractDEQs)(x::AbstractArray, ps, st::NamedTuple)
3830
return deq(x, ps, st, _check_unrolled_mode(st))

src/solve.jl

Lines changed: 14 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
abstract type AbstractDEQSolver <: AbstractSteadyStateAlgorithm end
22

33
"""
4-
ContinuousDEQSolver(alg=VCABM3(); mode=SteadyStateTerminationMode.RelSafeBest,
5-
abstol=1.0f-8, reltol=1.0f-6, abstol_termination=abstol,
6-
reltol_termination=reltol, tspan=Inf32, kwargs...)
4+
ContinuousDEQSolver(alg=VCAB3(); mode=NLSolveTerminationMode.RelSafeBest,
5+
abstol=1.0f-8, reltol=1.0f-6, abstol_termination=abstol, reltol_termination=reltol,
6+
tspan=Inf32, kwargs...)
77
88
Solver for Continuous DEQ Problem [pal2022mixing](@cite). Effectively a wrapper around
99
`DynamicSS` with more sensible defaults for DEQs.
@@ -29,25 +29,18 @@ struct ContinuousDEQSolver{A <: DynamicSS} <: AbstractDEQSolver
2929
alg::A
3030
end
3131

32-
function ContinuousDEQSolver(alg=VCAB3();
33-
mode=NLSolveTerminationMode.RelSafeBest,
34-
abstol=1.0f-8,
35-
reltol=1.0f-6,
36-
abstol_termination=abstol,
37-
reltol_termination=reltol,
38-
tspan=Inf32,
39-
kwargs...)
40-
termination_condition = NLSolveTerminationCondition(mode;
41-
abstol=abstol_termination,
42-
reltol=reltol_termination,
43-
kwargs...)
32+
function ContinuousDEQSolver(alg=VCAB3(); mode=NLSolveTerminationMode.RelSafeBest,
33+
abstol=1.0f-8, reltol=1.0f-6, abstol_termination=abstol, reltol_termination=reltol,
34+
tspan=Inf32, kwargs...)
35+
termination_condition = NLSolveTerminationCondition(mode; abstol=abstol_termination,
36+
reltol=reltol_termination, kwargs...)
4437
return ContinuousDEQSolver(DynamicSS(alg; abstol, reltol, tspan, termination_condition))
4538
end
4639

4740
"""
4841
DiscreteDEQSolver(alg = LBroyden(; batched=true,
49-
termination_condition=NLSolveTerminationCondition(NLSolveTerminationMode.RelSafe;
50-
abstol=1.0f-8, reltol=1.0f-6))
42+
termination_condition=NLSolveTerminationCondition(NLSolveTerminationMode.RelSafe;
43+
abstol=1.0f-8, reltol=1.0f-6))
5144
5245
Solver for Discrete DEQ Problem [baideep2019](@cite). Similar to `SSrootfind` but provides
5346
more flexibility needed for solving DEQ problems.
@@ -60,11 +53,9 @@ See also: [`ContinuousDEQSolver`](@ref)
6053
"""
6154
Base.@kwdef struct DiscreteDEQSolver{A <: AbstractSimpleNonlinearSolveAlgorithm} <:
6255
AbstractDEQSolver
63-
alg::A = LBroyden(;
64-
batched=true,
56+
alg::A = LBroyden(; batched=true,
6557
termination_condition=NLSolveTerminationCondition(NLSolveTerminationMode.RelSafe;
66-
abstol=1.0f-8,
67-
reltol=1.0f-6))
58+
abstol=1.0f-8, reltol=1.0f-6))
6859
end
6960

7061
"""
@@ -84,10 +75,8 @@ end
8475

8576
@truncate_stacktrace EquilibriumSolution 1 2
8677

87-
function DiffEqBase.__solve(prob::AbstractSteadyStateProblem,
88-
alg::AbstractDEQSolver,
89-
args...;
90-
kwargs...)
78+
function DiffEqBase.__solve(prob::AbstractSteadyStateProblem, alg::AbstractDEQSolver,
79+
args...; kwargs...)
9180
sol = solve(prob, alg.alg, args...; kwargs...)
9281

9382
# This is not necessarily true and might fail. But makes the code type stable

0 commit comments

Comments
 (0)