Skip to content

Commit 4d6c1ed

Browse files
committed
Fix tests
1 parent 1903cfa commit 4d6c1ed

File tree

12 files changed

+108
-114
lines changed

12 files changed

+108
-114
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DeepEquilibriumNetworks"
22
uuid = "6748aba7-0e9b-415e-a410-ae3cc0ecb334"
33
authors = ["Avik Pal <avikpal@mit.edu>"]
4-
version = "1.3.0"
4+
version = "1.4.0"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -36,7 +36,7 @@ NonlinearSolve = "2"
3636
OrdinaryDiffEq = "6"
3737
Reexport = "1"
3838
SciMLBase = "2"
39-
SciMLSensitivity = "7"
39+
SciMLSensitivity = "7.43"
4040
Setfield = "1"
4141
Static = "0.6, 0.7, 0.8"
4242
SteadyStateDiffEq = "1.16"

README.md

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -24,30 +24,31 @@ Pkg.add("DeepEquilibriumNetworks")
2424
## Quickstart
2525

2626
```julia
27-
import DeepEquilibriumNetworks as DEQs
28-
import Lux
29-
import Random
30-
import Zygote
27+
using DeepEquilibriumNetworks, Lux, Random, Zygote
28+
# using LuxCUDA, LuxAMDGPU ## Install and Load for GPU Support
3129

3230
seed = 0
3331
rng = Random.default_rng()
3432
Random.seed!(rng, seed)
3533

36-
model = Lux.Chain(Lux.Dense(2, 2),
37-
DEQs.DeepEquilibriumNetwork(Lux.Parallel(+,
38-
Lux.Dense(2, 2; use_bias=false),
39-
Lux.Dense(2, 2; use_bias=false)),
40-
DEQs.ContinuousDEQSolver(;
41-
abstol=0.1f0,
42-
reltol=0.1f0,
43-
abstol_termination=0.1f0,
44-
reltol_termination=0.1f0)))
45-
46-
ps, st = gpu.(Lux.setup(rng, model))
47-
x = gpu(rand(rng, Float32, 2, 1))
48-
y = gpu(rand(rng, Float32, 2, 1))
49-
50-
gs = Zygote.gradient(p -> sum(abs2, model(x, p, st)[1][1] .- y), ps)[1]
34+
model = Chain(Dense(2 => 2),
35+
DeepEquilibriumNetwork(Parallel(+,
36+
Dense(2 => 2; use_bias=false),
37+
Dense(2 => 2; use_bias=false)),
38+
ContinuousDEQSolver(; abstol=0.1f0, reltol=0.1f0, abstol_termination=0.1f0,
39+
reltol_termination=0.1f0);
40+
save_everystep=true))
41+
42+
gdev = gpu_device()
43+
cdev = cpu_device()
44+
45+
ps, st = Lux.setup(rng, model) |> gdev
46+
x = rand(rng, Float32, 2, 1) |> gdev
47+
y = rand(rng, Float32, 2, 1) |> gdev
48+
49+
model(x, ps, st)
50+
51+
gs = only(Zygote.gradient(p -> sum(abs2, first(first(model(x, p, st))) .- y), ps))
5152
```
5253

5354
## Citation

docs/src/index.md

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
DeepEquilibriumNetworks.jl is a framework built on top of
44
[DifferentialEquations.jl](https://docs.sciml.ai/DiffEqDocs/stable/) and
5-
[Lux.jl](https://docs.sciml.ai/Lux/stable/), enabling the efficient training and inference for
5+
[Lux.jl](https://lux.csail.mit.edu/), enabling the efficient training and inference for
66
Deep Equilibrium Networks (Infinitely Deep Neural Networks).
77

88
## Installation
@@ -17,30 +17,30 @@ Pkg.add("DeepEquilibriumNetworks")
1717
## Quick-start
1818

1919
```julia
20-
import DeepEquilibriumNetworks as DEQs
21-
import Lux
22-
import Random
23-
import Zygote
20+
using DeepEquilibriumNetworks, Lux, Random, Zygote
21+
# using LuxCUDA, LuxAMDGPU ## Install and Load for GPU Support
2422

2523
seed = 0
2624
rng = Random.default_rng()
2725
Random.seed!(rng, seed)
26+
model = Chain(Dense(2 => 2),
27+
DeepEquilibriumNetwork(Parallel(+,
28+
Dense(2 => 2; use_bias=false),
29+
Dense(2 => 2; use_bias=false)),
30+
ContinuousDEQSolver(; abstol=0.1f0, reltol=0.1f0, abstol_termination=0.1f0,
31+
reltol_termination=0.1f0);
32+
save_everystep=true))
2833

29-
model = Lux.Chain(Lux.Dense(2, 2),
30-
DEQs.DeepEquilibriumNetwork(Lux.Parallel(+,
31-
Lux.Dense(2, 2; use_bias=false),
32-
Lux.Dense(2, 2; use_bias=false)),
33-
DEQs.ContinuousDEQSolver(;
34-
abstol=0.1f0,
35-
reltol=0.1f0,
36-
abstol_termination=0.1f0,
37-
reltol_termination=0.1f0)))
38-
39-
ps, st = gpu.(Lux.setup(rng, model))
40-
x = gpu(rand(rng, Float32, 2, 1))
41-
y = gpu(rand(rng, Float32, 2, 1))
42-
43-
gs = Zygote.gradient(p -> sum(abs2, model(x, p, st)[1][1] .- y), ps)[1]
34+
gdev = gpu_device()
35+
cdev = cpu_device()
36+
37+
ps, st = Lux.setup(rng, model) |> gdev
38+
x = rand(rng, Float32, 2, 1) |> gdev
39+
y = rand(rng, Float32, 2, 1) |> gdev
40+
41+
model(x, ps, st)
42+
43+
gs = only(Zygote.gradient(p -> sum(abs2, first(first(model(x, p, st))) .- y), ps))
4444
```
4545

4646
## Citation

src/DeepEquilibriumNetworks.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import ChainRulesCore as CRC
1616
import ConcreteStructs: @concrete
1717

1818
const DEQs = DeepEquilibriumNetworks
19+
const ∂∅ = CRC.NoTangent()
1920

2021
## FIXME: Uses of nothing was removed in Lux 0.5 with a deprecation. It was not updated
2122
## here
@@ -33,6 +34,13 @@ include("layers/evaluate.jl")
3334

3435
include("chainrules.jl")
3536

37+
# Start of Weird Patches
38+
# Honestly no clue why this is needed! -- probably a whacky fix which shouldn't be ever
39+
# needed.
40+
ZygoteRules.gradtuple1(::NamedTuple{()}) = (nothing, nothing, nothing, nothing, nothing)
41+
ZygoteRules.gradtuple1(x::NamedTuple) = collect(values(x))
42+
# End of Weird Patches
43+
3644
# Useful Shorthand
3745
export DEQs
3846

src/chainrules.jl

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,27 @@
1+
__backing::CRC.Tangent) = __backing(CRC.backing(Δ))
2+
__backing::Tuple) = __backing.(Δ)
3+
__backing::NamedTuple{F}) where {F} = NamedTuple{F}(__backing(values(Δ)))
4+
__backing(Δ) = Δ
5+
16
function CRC.rrule(::Type{<:DeepEquilibriumSolution}, z_star::T, u0::T, residual::T,
27
jacobian_loss::R, nfe::Int) where {T, R <: AbstractFloat}
38
function deep_equilibrium_solution_pullback(dsol)
4-
return (CRC.NoTangent(), dsol.z_star, dsol.u0, dsol.residual, dsol.jacobian_loss,
5-
dsol.nfe)
9+
return (∂∅, dsol.z_star, dsol.u0, dsol.residual, dsol.jacobian_loss, dsol.nfe)
610
end
711
return (DeepEquilibriumSolution(z_star, u0, residual, jacobian_loss, nfe),
812
deep_equilibrium_solution_pullback)
913
end
1014

1115
function _safe_getfield(x::NamedTuple{fields}, field) where {fields}
12-
return field fields ? getfield(x, field) : CRC.NoTangent()
16+
return field fields ? getfield(x, field) : ∂∅
1317
end
1418

1519
function CRC.rrule(::Type{T}, args...) where {T <: NamedTuple}
1620
y = T(args...)
1721
function nt_pullback(dy)
1822
fields = fieldnames(T)
19-
if dy isa CRC.Tangent
20-
dy = CRC.backing(dy)
21-
end
22-
return (CRC.NoTangent(), _safe_getfield.((dy,), fields)...)
23+
dy isa CRC.Tangent && (dy = CRC.backing(dy))
24+
return (∂∅, _safe_getfield.((dy,), fields)...)
2325
end
2426
return y, nt_pullback
2527
end
@@ -28,20 +30,20 @@ function CRC.rrule(::typeof(Setfield.set), obj, l::Setfield.PropertyLens{field},
2830
val) where {field}
2931
res = Setfield.set(obj, l, val)
3032
function setfield_pullback(Δres)
31-
if Δres isa CRC.Tangent
32-
Δres = CRC.backing(Δres)
33-
end
34-
Δobj = Setfield.set(obj, l, CRC.NoTangent())
35-
return (CRC.NoTangent(), Δobj, CRC.NoTangent(), getfield(Δres, field))
33+
Δres = __backing(Δres)
34+
Δobj = Setfield.set(obj, l, ∂∅)
35+
return (∂∅, Δobj, ∂∅, getfield(Δres, field))
3636
end
3737
return res, setfield_pullback
3838
end
3939

40-
function CRC.rrule(::typeof(_construct_problem), deq::AbstractDEQs, dudt, z, ps, x)
40+
function CRC.rrule(::typeof(_construct_problem), deq::AbstractDEQs, dudt, z,
41+
ps::NamedTuple{F}, x) where {F}
4142
prob = _construct_problem(deq, dudt, z, ps, x)
4243
function ∇_construct_problem(Δ)
43-
return (CRC.NoTangent(), CRC.NoTangent(), CRC.NoTangent(), Δ.u0,
44-
(; model = Δ.p.ps), Δ.p.x)
44+
Δ = __backing(Δ)
45+
nograds = NamedTuple{F}(ntuple(i -> ∂∅, length(F)))
46+
return (∂∅, ∂∅, ∂∅, Δ.u0, merge(nograds, (; model=Δ.p.ps)), Δ.p.x)
4547
end
4648
return prob, ∇_construct_problem
4749
end

src/layers/evaluate.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,7 @@ end
1414
@inline _postprocess_output(_, z_star) = z_star
1515

1616
@inline function _construct_problem(::AbstractDEQs, dudt, z, ps, x)
17-
return SteadyStateProblem(ODEFunction{false}(dudt), z,
18-
NamedTuple{(:ps, :x)}((ps.model, x)))
17+
return SteadyStateProblem(ODEFunction{false}(dudt), z, (; ps=ps.model, x))
1918
end
2019

2120
@inline _fix_solution_output(_, x) = x
@@ -48,7 +47,9 @@ function (deq::AbstractDEQs)(x::AbstractArray, ps, st::NamedTuple, ::Val{false})
4847

4948
prob = _construct_problem(deq, dudt, z, ps, x)
5049
sol = solve(prob, deq.solver; deq.sensealg, deq.kwargs...)
51-
z_star = sol.u
50+
_z_star = sol.u
51+
# Handle Neural ODEs
52+
z_star = _z_star isa Vector{<:AbstractArray} ? last(_z_star) : _z_star
5253

5354
if _jacobian_regularization(deq)
5455
rng = Lux.replicate(st.rng)

src/layers/mdeq.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -282,8 +282,7 @@ end
282282
"""
283283
MultiScaleNeuralODE(main_layers::Tuple, mapping_layers::Matrix,
284284
post_fuse_layer::Union{Nothing,Tuple}, solver, scales;
285-
sensealg=InterpolatingAdjoint(; autojacvec=ZygoteVJP()),
286-
kwargs...)
285+
sensealg=GaussAdjoint(; autojacvec=ZygoteVJP()), kwargs...)
287286
288287
Multiscale Neural ODE with Input Injection.
289288
@@ -334,7 +333,7 @@ See also: [`DeepEquilibriumNetwork`](@ref), [`SkipDeepEquilibriumNetwork`](@ref)
334333
"""
335334
function MultiScaleNeuralODE(main_layers::Tuple, mapping_layers::Matrix,
336335
post_fuse_layer::Union{Nothing, Tuple}, solver, scales::NTuple{N, NTuple{L, Int64}};
337-
sensealg=InterpolatingAdjoint(; autojacvec=ZygoteVJP()), kwargs...) where {N, L}
336+
sensealg=GaussAdjoint(; autojacvec=ZygoteVJP()), kwargs...) where {N, L}
338337
l1 = Parallel(nothing, main_layers...)
339338
l2 = BranchLayer(Parallel.(+, map(x -> tuple(x...), eachrow(mapping_layers))...)...)
340339

@@ -357,7 +356,7 @@ function _get_initial_condition(deq::MultiScaleNeuralODE, x, ps, st)
357356
end
358357

359358
@inline function _construct_problem(::MultiScaleNeuralODE, dudt, z, ps, x)
360-
return ODEProblem(ODEFunction{false}(dudt), z, (0.0f0, 1.0f0), ps.model)
359+
return ODEProblem(ODEFunction{false}(dudt), z, (0.0f0, 1.0f0), (; ps=ps.model, x))
361360
end
362361

363362
@inline _fix_solution_output(::MultiScaleNeuralODE, x) = x[end]

0 commit comments

Comments
 (0)