
Commit 8e7f132

Add adjoint tests

1 parent c93bf03 commit 8e7f132

9 files changed: +77 additions, -50 deletions

CITATION.bib

Lines changed: 5 additions & 7 deletions

```diff
@@ -1,8 +1,6 @@
-@misc{pal2022mixing,
-    title={Mixing Implicit and Explicit Deep Learning with Skip DEQs and Infinite Time Neural ODEs (Continuous DEQs)},
-    author={Avik Pal and Alan Edelman and Christopher Rackauckas},
-    year={2022},
-    eprint={2201.12240},
-    archivePrefix={arXiv},
-    primaryClass={cs.LG}
+@article{pal2022continuous,
+    title={Continuous Deep Equilibrium Models: Training Neural ODEs Faster by Integrating Them to Infinity},
+    author={Pal, Avik and Edelman, Alan and Rackauckas, Christopher},
+    booktitle={2023 IEEE High Performance Extreme Computing Conference (HPEC)},
+    year={2023}
 }
```

README.md

Lines changed: 5 additions & 9 deletions

````diff
@@ -24,27 +24,23 @@ Pkg.add("DeepEquilibriumNetworks")
 ## Quickstart
 
 ```julia
-using DeepEquilibriumNetworks, Lux, Random, Zygote
+using DeepEquilibriumNetworks, Lux, Random, NonlinearSolve, Zygote
 # using LuxCUDA, LuxAMDGPU ## Install and Load for GPU Support
 
 seed = 0
 rng = Random.default_rng()
 Random.seed!(rng, seed)
 
 model = Chain(Dense(2 => 2),
-    DeepEquilibriumNetwork(Parallel(+,
-            Dense(2 => 2; use_bias=false),
-            Dense(2 => 2; use_bias=false)),
-        ContinuousDEQSolver(; abstol=0.1f0, reltol=0.1f0, abstol_termination=0.1f0,
-            reltol_termination=0.1f0);
-        save_everystep=true))
+    DeepEquilibriumNetwork(Parallel(+, Dense(2 => 2; use_bias=false),
+            Dense(2 => 2; use_bias=false)), NewtonRaphson()))
 
 gdev = gpu_device()
 cdev = cpu_device()
 
 ps, st = Lux.setup(rng, model) |> gdev
-x = rand(rng, Float32, 2, 1) |> gdev
-y = rand(rng, Float32, 2, 1) |> gdev
+x = rand(rng, Float32, 2, 3) |> gdev
+y = rand(rng, Float32, 2, 3) |> gdev
 
 model(x, ps, st)
````
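Editor's note on the new quickstart: the custom `ContinuousDEQSolver` is replaced by a stock `NonlinearSolve.NewtonRaphson()` root finder, and the batch size grows from 1 to 3. As a rough illustration of how `y` would be used downstream, here is a minimal single-step training sketch; it assumes the quickstart definitions above are in scope, and the `mse_loss` closure plus the use of Optimisers.jl are illustrative choices, not APIs exercised by this commit.

```julia
# Illustrative only: one gradient-descent step on the quickstart model.
using Optimisers

function mse_loss(ps)
    ŷ, _ = model(x, ps, st)                # forward pass solves the fixed point
    return sum(abs2, ŷ .- y) / size(y, 2)  # mean squared error over the batch
end

gs = only(Zygote.gradient(mse_loss, ps))   # adjoint through the equilibrium
opt = Optimisers.setup(Optimisers.Descent(0.01f0), ps)
opt, ps = Optimisers.update(opt, ps, gs)
```

The docs/src/index.md hunk below applies the same quickstart change.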
docs/src/index.md

Lines changed: 6 additions & 9 deletions

````diff
@@ -17,26 +17,23 @@ Pkg.add("DeepEquilibriumNetworks")
 ## Quick-start
 
 ```julia
-using DeepEquilibriumNetworks, Lux, Random, Zygote
+using DeepEquilibriumNetworks, Lux, Random, NonlinearSolve, Zygote
 # using LuxCUDA, LuxAMDGPU ## Install and Load for GPU Support
 
 seed = 0
 rng = Random.default_rng()
 Random.seed!(rng, seed)
+
 model = Chain(Dense(2 => 2),
-    DeepEquilibriumNetwork(Parallel(+,
-            Dense(2 => 2; use_bias=false),
-            Dense(2 => 2; use_bias=false)),
-        ContinuousDEQSolver(; abstol=0.1f0, reltol=0.1f0, abstol_termination=0.1f0,
-            reltol_termination=0.1f0);
-        save_everystep=true))
+    DeepEquilibriumNetwork(Parallel(+, Dense(2 => 2; use_bias=false),
+            Dense(2 => 2; use_bias=false)), NewtonRaphson()))
 
 gdev = gpu_device()
 cdev = cpu_device()
 
 ps, st = Lux.setup(rng, model) |> gdev
-x = rand(rng, Float32, 2, 1) |> gdev
-y = rand(rng, Float32, 2, 1) |> gdev
+x = rand(rng, Float32, 2, 3) |> gdev
+y = rand(rng, Float32, 2, 3) |> gdev
 
 model(x, ps, st)
````
ext/DeepEquilibriumNetworksSciMLSensitivityExt.jl

Lines changed: 1 addition & 1 deletion

```diff
@@ -4,7 +4,7 @@ using SciMLBase, SciMLSensitivity
 import DeepEquilibriumNetworks: __default_sensealg
 
 @inline __default_sensealg(::SteadyStateProblem) = SteadyStateAdjoint(;
-    autojacvec=ZygoteVJP())
+    autojacvec=ZygoteVJP(), linsolve_kwargs=(; maxiters=10, abstol=1e-3, reltol=1e-3))
 @inline __default_sensealg(::ODEProblem) = GaussAdjoint(; autojacvec=ZygoteVJP())
 
 end
```
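For context (editor's note, not part of the diff): `SteadyStateAdjoint` differentiates through the equilibrium via the implicit function theorem rather than by unrolling solver iterations. Writing the fixed point as $z^* = f(z^*, x, \theta)$, the backward pass requires one linear solve of roughly the form

$$\left(I - \frac{\partial f}{\partial z}\Big|_{z^*}\right)^{\top} \lambda = \left(\frac{\partial L}{\partial z^*}\right)^{\top},$$

so the new `linsolve_kwargs = (; maxiters=10, abstol=1e-3, reltol=1e-3)` default caps the cost of that iterative solve, trading a small amount of adjoint accuracy for speed, in the spirit of inexact (Jacobian-free) DEQ backpropagation.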

src/layers.jl

Lines changed: 23 additions & 8 deletions

````diff
@@ -22,6 +22,22 @@ Stores the solution of a DeepEquilibriumNetwork and its variants.
     original
 end
 
+function CRC.rrule(::Type{<:DeepEquilibriumSolution}, z_star, u0, residual, jacobian_loss,
+        nfe, original)
+    sol = DeepEquilibriumSolution(z_star, u0, residual, jacobian_loss, nfe, original)
+    ∇DeepEquilibriumSolution(::CRC.NoTangent) = ntuple(_ -> CRC.NoTangent(), 7)
+    function ∇DeepEquilibriumSolution(∂sol)
+        ∂z_star = ∂sol.z_star
+        ∂u0 = ∂sol.u0
+        ∂residual = ∂sol.residual
+        ∂jacobian_loss = ∂sol.jacobian_loss
+        ∂nfe = ∂sol.nfe
+        ∂original = CRC.NoTangent()
+        return (CRC.NoTangent(), ∂z_star, ∂u0, ∂residual, ∂jacobian_loss, ∂nfe, ∂original)
+    end
+    return sol, ∇DeepEquilibriumSolution
+end
+
 function DeepEquilibriumSolution()
     return DeepEquilibriumSolution(ntuple(Returns(nothing), 4)..., 0, nothing)
 end
@@ -66,8 +82,8 @@ function (deq::DEQ)(x, ps, st::NamedTuple, ::Val{true})
     repeated_model = RepeatedLayer(deq.model; repeats=st.fixed_depth)
 
     zˢᵗᵃʳ, st_ = repeated_model((z, x), ps.model, st.model)
-    model = Lux.Experimental.StatefulLuxLayer(deq.model, ps.model, st_)
-    resid = CRC.ignore_derivatives(zˢᵗᵃʳ .- model((zˢᵗᵃʳ, x)))
+    model = Lux.Experimental.StatefulLuxLayer(deq.model, nothing, st_)
+    resid = CRC.ignore_derivatives(zˢᵗᵃʳ .- model((zˢᵗᵃʳ, x), ps.model))
 
     rng = Lux.replicate(st.rng)
     jac_loss = __estimate_jacobian_trace(__getproperty(deq, Val(:jacobian_regularization)),
@@ -156,7 +172,7 @@ function DeepEquilibriumNetwork(model, solver; init=missing,
     if init === missing # Regular DEQ
         init = WrappedFunction(Base.Fix1(__zeros_init, __getproperty(model, Val(:scales))))
     elseif init === nothing # SkipRegDEQ
-        init = nothing
+        init = NoOpLayer()
     elseif !(init isa AbstractExplicitLayer)
         init = Lux.transform(init)
     end
@@ -225,8 +241,7 @@ model(x, ps, st)
 ```
 """
 function MultiScaleDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Matrix,
-        post_fuse_layer::Union{Nothing, Tuple}, solver,
-        scales::NTuple{N, NTuple{L, Int64}}; kwargs...) where {N, L}
+        post_fuse_layer::Union{Nothing, Tuple}, solver, scales; kwargs...)
     l1 = Parallel(nothing, main_layers...)
     l2 = BranchLayer(Parallel.(+, map(x -> tuple(x...), eachrow(mapping_layers))...)...)
 
@@ -254,8 +269,7 @@ creates a [`MultiScaleDeepEquilibriumNetwork`](@ref) with `init` kwarg set to pa
 If `init` is not passed, it creates a MultiScale Regularized Deep Equilibrium Network.
 """
 function MultiScaleSkipDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Matrix,
-        post_fuse_layer::Union{Nothing, Tuple}, init::Tuple, solver,
-        scales::NTuple{N, NTuple{L, Int64}}; kwargs...) where {N, L}
+        post_fuse_layer::Union{Nothing, Tuple}, init::Tuple, solver, scales; kwargs...)
     init = Chain(Parallel(nothing, init...), x -> mapreduce(__flatten, vcat, x))
     return MultiScaleDeepEquilibriumNetwork(main_layers, mapping_layers, post_fuse_layer,
         solver, scales; init, kwargs...)
@@ -279,7 +293,8 @@ function MultiScaleNeuralODE(args...; kwargs...)
 end
 
 ## Generate Initial Condition
-@inline function __get_initial_condition(deq::DEQ{pType, Nothing}, x, ps, st) where {pType}
+@inline function __get_initial_condition(deq::DEQ{pType, NoOpLayer}, x, ps,
+        st) where {pType}
     zₓ = __zeros_init(__getproperty(deq.model, Val(:scales)), x)
     z, st_ = deq.model((zₓ, x), ps.model, st.model)
     return z, (; st..., model=st_)
````
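Editor's note on the new `CRC.rrule`: without it, Zygote would try to construct a pullback through the `DeepEquilibriumSolution` constructor field by field, including the non-differentiable `original` solver object. The rule instead builds the solution and maps the structural tangent of the output back onto the positional constructor arguments, marking `original` (and the constructor itself) as `NoTangent`. A self-contained toy version of the same constructor-rrule pattern, using a hypothetical `Point` type, looks like this:

```julia
# Toy sketch of the constructor-rrule pattern above; `Point` is hypothetical.
using ChainRulesCore, Zygote

struct Point{T}
    x::T
    y::T
end

function ChainRulesCore.rrule(::Type{<:Point}, x, y)
    p = Point(x, y)
    # Map the structural tangent of the output back to the positional args.
    ∇Point(∂p) = (NoTangent(), ∂p.x, ∂p.y)
    ∇Point(::NoTangent) = (NoTangent(), NoTangent(), NoTangent())
    return p, ∇Point
end

Zygote.gradient((a, b) -> 3.0 * Point(a, b).x + 2.0 * Point(a, b).y, 1.0, 2.0)
# == (3.0, 2.0)
```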

src/utils.jl

Lines changed: 2 additions & 0 deletions

```diff
@@ -93,6 +93,8 @@ function __estimate_jacobian_trace(::AutoFiniteDiff, model, ps, z, x, rng)
             z[idx] = _z
         end
     end
+
+    return res
 end
 
 __estimate_jacobian_trace(::Nothing, model, ps, z, x, rng) = zero(eltype(x))
```
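Editor's note: the bug fixed here is that the `AutoFiniteDiff` branch of `__estimate_jacobian_trace` filled `res` inside the loop but fell off the end of the function without returning it. The quantity being estimated is a Hutchinson-style stochastic estimate of a Jacobian trace; a standalone sketch of that general technique (not the package's exact implementation) follows.

```julia
# Standalone Hutchinson estimator: tr(J) ≈ E[vᵀ J v] for random probe vectors
# v with E[v vᵀ] = I. Generic sketch, not DeepEquilibriumNetworks' code.
using LinearAlgebra, Random, Zygote

function hutchinson_trace(f, z::AbstractVector; rng=Random.default_rng(), nsamples=10)
    y, back = Zygote.pullback(f, z)
    @assert length(y) == length(z) "f must map Rⁿ to Rⁿ for a square Jacobian"
    est = 0.0
    for _ in 1:nsamples
        v = rand(rng, (-1.0, 1.0), length(z))  # Rademacher probe vector
        vJ = only(back(v))                     # vᵀ J via one reverse-mode pass
        est += dot(vJ, v)                      # accumulate vᵀ J v
    end
    return est / nsamples
end

hutchinson_trace(z -> 2.0 .* z, randn(4))  # 8.0 exactly, since J = 2I₄
```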

test/Project.toml

Lines changed: 2 additions & 2 deletions

```diff
@@ -8,6 +8,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
 LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531"
 NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
+NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
 OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
@@ -20,5 +21,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [compat]
-Aqua = "0.8"
-julia = "1.6"
+Aqua = "0.8"
```

test/layers.jl

Lines changed: 23 additions & 2 deletions

```diff
@@ -3,6 +3,14 @@ using ADTypes, DeepEquilibriumNetworks, DiffEqBase, NonlinearSolve, OrdinaryDiff
 
 include("test_utils.jl")
 
+function loss_function(model, x, ps, st)
+    y, st = model(x, ps, st)
+    l1 = y isa Tuple ? sum(Base.Fix1(sum, abs2), y) : sum(abs2, y)
+    l2 = st.solution.jacobian_loss
+    l3 = sum(abs2, st.solution.z_star .- st.solution.u0)
+    return l1 + l2 + l3
+end
+
 @testset "DeepEquilibriumNetwork" begin
     rng = __get_prng(0)
@@ -22,6 +30,8 @@
     @testset "x_size: $(x_size)" for (base_model, init_model, x_size) in zip(base_models,
             init_models, x_sizes)
+        @info solver, mtype, jacobian_regularization, base_model, init_model, x_size
+
         model = if mtype === :deq
             DeepEquilibriumNetwork(base_model, solver; jacobian_regularization)
         elseif mtype === :skipdeq
@@ -47,6 +57,11 @@
         @test st.solution isa DeepEquilibriumSolution
         @test maximum(abs, st.solution.residual) ≤ 1e-3
 
+        _, gs_x, gs_ps, _ = Zygote.gradient(loss_function, model, x, ps, st)
+
+        @test __is_finite_gradient(gs_x)
+        @test __is_finite_gradient(gs_ps)
+
         ps, st = Lux.setup(rng, model)
         st = Lux.update_state(st, :fixed_depth, Val(10))
         @test st.solution == DeepEquilibriumSolution()
@@ -58,6 +73,11 @@
         @test size(z) == size(x)
         @test st.solution isa DeepEquilibriumSolution
         @test st.solution.nfe == 10
+
+        _, gs_x, gs_ps, _ = Zygote.gradient(loss_function, model, x, ps, st)
+
+        @test __is_finite_gradient(gs_x)
+        @test __is_finite_gradient(gs_ps)
         end
     end
 end
@@ -91,11 +111,12 @@
     jacobian_regularizations = (nothing, AutoFiniteDiff(), AutoZygote())
 
     for mtype in model_type, jacobian_regularization in jacobian_regularizations
-
        @testset "Solver: $(__nameof(solver))" for solver in solvers
-
            @testset "x_size: $(x_size)" for (main_layer, mapping_layer, init_layer, x_size, scale) in zip(main_layers,
                    mapping_layers, init_layers, x_sizes, scales)
+                @info solver, mtype, jacobian_regularization, main_layer, mapping_layer,
+                    init_layer, x_size, scale
+
                model = if mtype === :deq
                    MultiScaleDeepEquilibriumNetwork(main_layer, mapping_layer, nothing,
                        solver,
```

test/test_utils.jl

Lines changed: 10 additions & 12 deletions

```diff
@@ -5,19 +5,17 @@ __nameof(::X) where {X} = nameof(X)
 
 __get_prng(seed::Int) = StableRNG(seed)
 
-# is_finite_gradient(x::AbstractArray) = all(isfinite, x)
+__is_finite_gradient(x::AbstractArray) = all(isfinite, x)
 
-# function is_finite_gradient(gs::NamedTuple)
-#     gradient_is_finite = [true]
-#     function _is_gradient_finite(x)
-#         if !isnothing(x) && !all(isfinite, x)
-#             gradient_is_finite[1] = false
-#         end
-#         return x
-#     end
-#     Functors.fmap(_is_gradient_finite, gs)
-#     return gradient_is_finite[1]
-# end
+function __is_finite_gradient(gs::NamedTuple)
+    gradient_is_finite = Ref(true)
+    function __is_gradient_finite(x)
+        !isnothing(x) && !all(isfinite, x) && (gradient_is_finite[] = false)
+        return x
+    end
+    fmap(__is_gradient_finite, gs)
+    return gradient_is_finite[]
+end
 
 function __get_dense_layer(args...; kwargs...)
     init_weight(rng::AbstractRNG, dims...) = randn(rng, Float32, dims) .* 0.001f0
```
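A quick usage sketch of the revived helper (hypothetical gradient values; `fmap` is `Functors.fmap`, which the test utilities are assumed to bring into scope):

```julia
# Hypothetical inputs illustrating __is_finite_gradient.
using Functors

gs_ok  = (; weight = Float32[1 2; 3 4], bias = nothing)
gs_bad = (; weight = Float32[1 NaN; 3 4], bias = nothing)

__is_finite_gradient(gs_ok)   # true: `nothing` leaves are skipped
__is_finite_gradient(gs_bad)  # false: the NaN flips the Ref
```

Using a `Ref` rather than the old one-element vector is the idiomatic way to carry a mutable flag through `fmap`.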
