Skip to content

Commit 0378e1d

Browse files
committed
Final few tests
1 parent d4f6903 commit 0378e1d

File tree

4 files changed

+30
-22
lines changed

4 files changed

+30
-22
lines changed

docs/Project.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,12 @@
22
DeepEquilibriumNetworks = "6748aba7-0e9b-415e-a410-ae3cc0ecb334"
33
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
44
DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"
5+
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
6+
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
7+
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
8+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
9+
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
10+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
511

612
[compat]
713
DeepEquilibriumNetworks = "2"

docs/src/index.md

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# DeepEquilibriumNetworks: (Fast) Deep Equilibrium Networks
1+
# DeepEquilibriumNetworks.jl
22

33
DeepEquilibriumNetworks.jl is a framework built on top of
44
[DifferentialEquations.jl](https://docs.sciml.ai/DiffEqDocs/stable/) and
@@ -16,8 +16,8 @@ Pkg.add("DeepEquilibriumNetworks")
1616

1717
## Quick-start
1818

19-
```julia
20-
using DeepEquilibriumNetworks, Lux, Random, NonlinearSolve, Zygote
19+
```@example
20+
using DeepEquilibriumNetworks, Lux, Random, NonlinearSolve, Zygote, SciMLSensitivity
2121
# using LuxCUDA, LuxAMDGPU ## Install and Load for GPU Support
2222
2323
seed = 0
@@ -46,14 +46,11 @@ If you are using this project for research or other academic purposes, consider
4646
paper:
4747

4848
```bibtex
49-
@misc{pal2022mixing,
50-
title={Mixing Implicit and Explicit Deep Learning with Skip DEQs and Infinite Time Neural
51-
ODEs (Continuous DEQs)},
52-
author={Avik Pal and Alan Edelman and Christopher Rackauckas},
53-
year={2022},
54-
eprint={2201.12240},
55-
archivePrefix={arXiv},
56-
primaryClass={cs.LG}
49+
@article{pal2022continuous,
50+
title={Continuous Deep Equilibrium Models: Training Neural ODEs Faster by Integrating Them to Infinity},
51+
author={Pal, Avik and Edelman, Alan and Rackauckas, Christopher},
52+
booktitle={2023 IEEE High Performance Extreme Computing Conference (HPEC)},
53+
year={2023}
5754
}
5855
```
5956

src/layers.jl

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,8 @@ function CRC.rrule(::Type{<:DeepEquilibriumSolution}, z_star, u0, residual, jaco
2727
sol = DeepEquilibriumSolution(z_star, u0, residual, jacobian_loss, nfe, original)
2828
∇DeepEquilibriumSolution(::CRC.NoTangent) = ntuple(_ -> CRC.NoTangent(), 7)
2929
function ∇DeepEquilibriumSolution(∂sol)
30-
∂z_star = ∂sol.z_star
31-
∂u0 = ∂sol.u0
32-
∂residual = ∂sol.residual
33-
∂jacobian_loss = ∂sol.jacobian_loss
34-
∂nfe = ∂sol.nfe
35-
∂original = CRC.NoTangent()
36-
return (CRC.NoTangent(), ∂z_star, ∂u0, ∂residual, ∂jacobian_loss, ∂nfe, ∂original)
30+
return (CRC.NoTangent(), ∂sol.z_star, ∂sol.u0, ∂sol.residual, ∂sol.jacobian_loss,
31+
∂sol.nfe, CRC.NoTangent())
3732
end
3833
return sol, ∇DeepEquilibriumSolution
3934
end
@@ -149,11 +144,11 @@ Deep Equilibrium Network as proposed in [baideep2019](@cite) and [pal2022mixing]
149144
150145
## Example
151146
152-
```julia
147+
```@example
153148
using DeepEquilibriumNetworks, Lux, Random, OrdinaryDiffEq
154149
155150
model = DeepEquilibriumNetwork(Parallel(+, Dense(2, 2; use_bias=false),
156-
Dense(2, 2; use_bias=false)), VCABM3(); save_everystep=true)
151+
Dense(2, 2; use_bias=false)), VCABM3())
157152
158153
rng = Random.default_rng()
159154
ps, st = Lux.setup(rng, model)
@@ -218,7 +213,7 @@ For keyword arguments, see [`DeepEquilibriumNetwork`](@ref).
218213
219214
## Example
220215
221-
```julia
216+
```@example
222217
using DeepEquilibriumNetworks, Lux, Random, OrdinaryDiffEq
223218
224219
main_layers = (Parallel(+, Dense(4 => 4, tanh; use_bias=false),

test/layers.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using ADTypes, DeepEquilibriumNetworks, DiffEqBase, NonlinearSolve, OrdinaryDiffEq,
2-
SciMLBase, Test, NLsolve
2+
SciMLSensitivity, SciMLBase, Test, NLsolve
33

44
include("test_utils.jl")
55

@@ -152,6 +152,11 @@ end
152152
@test maximum(abs, st.solution.residual) 1e-3
153153
end
154154

155+
_, gs_x, gs_ps, _ = Zygote.gradient(loss_function, model, x, ps, st)
156+
157+
@test __is_finite_gradient(gs_x)
158+
@test __is_finite_gradient(gs_ps)
159+
155160
ps, st = Lux.setup(rng, model)
156161
st = Lux.update_state(st, :fixed_depth, Val(10))
157162
@test st.solution == DeepEquilibriumSolution()
@@ -165,6 +170,11 @@ end
165170
@test size(z_) == (sum(prod, scale), size(x, ndims(x)))
166171
@test st.solution isa DeepEquilibriumSolution
167172
@test st.solution.nfe == 10
173+
174+
_, gs_x, gs_ps, _ = Zygote.gradient(loss_function, model, x, ps, st)
175+
176+
@test __is_finite_gradient(gs_x)
177+
@test __is_finite_gradient(gs_ps)
168178
end
169179
end
170180
end

0 commit comments

Comments
 (0)