Skip to content

Commit 34932dc

Browse files
committed
Remove Unwanted Dependencies and rrules
1 parent 1e414cd commit 34932dc

File tree

8 files changed

+22
-75
lines changed

8 files changed

+22
-75
lines changed

.github/workflows/CI.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ jobs:
2727
- ADJOINT
2828
version:
2929
- '1'
30-
- '1.6'
3130
steps:
3231
- uses: actions/checkout@v4
3332
- uses: julia-actions/setup-julia@v1

Project.toml

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
1010
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1111
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
1212
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
13-
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
1413
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
1514
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
1615
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -22,15 +21,13 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2221
SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f"
2322
TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"
2423
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
25-
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
2624

2725
[compat]
2826
ChainRulesCore = "1"
2927
ConcreteStructs = "0.2"
3028
DiffEqBase = "6.119"
3129
LinearSolve = "1, 2"
32-
Lux = "0.4, 0.5"
33-
MLUtils = "0.2, 0.3, 0.4"
30+
Lux = "0.5.7"
3431
NonlinearSolve = "2"
3532
OrdinaryDiffEq = "6"
3633
Reexport = "1"
@@ -40,5 +37,4 @@ Setfield = "1"
4037
SteadyStateDiffEq = "1.16"
4138
TruncatedStacktraces = "1.1"
4239
Zygote = "0.6.34"
43-
ZygoteRules = "0.2"
44-
julia = "1.6"
40+
julia = "1.9"

src/DeepEquilibriumNetworks.jl

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@ import Reexport: @reexport
44

55
@reexport using Lux, NonlinearSolve, OrdinaryDiffEq, SciMLSensitivity
66

7-
using DiffEqBase, LinearAlgebra, LinearSolve, MLUtils, Random, SciMLBase,
8-
Setfield, Statistics, SteadyStateDiffEq, Zygote
7+
using DiffEqBase,
8+
LinearAlgebra, LinearSolve, Random, SciMLBase, Statistics,
9+
SteadyStateDiffEq, Zygote
910

1011
import DiffEqBase: AbstractSteadyStateProblem
1112
import SciMLBase: AbstractNonlinearSolution, AbstractSteadyStateAlgorithm
@@ -34,14 +35,6 @@ include("layers/evaluate.jl")
3435

3536
include("chainrules.jl")
3637

37-
# Start of Weird Patches
38-
# Honestly no clue why this is needed! -- probably a whacky fix which shouldn't be ever
39-
# needed.
40-
using ZygoteRules
41-
ZygoteRules.gradtuple1(::NamedTuple{()}) = (nothing, nothing, nothing, nothing, nothing)
42-
ZygoteRules.gradtuple1(x::NamedTuple) = collect(values(x))
43-
# End of Weird Patches
44-
4538
# Useful Shorthand
4639
export DEQs
4740

src/chainrules.jl

Lines changed: 2 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,49 +1,8 @@
1-
__backing(Δ::CRC.Tangent) = __backing(CRC.backing(Δ))
2-
__backing(Δ::Tuple) = __backing.(Δ)
3-
__backing(Δ::NamedTuple{F}) where {F} = NamedTuple{F}(__backing(values(Δ)))
4-
__backing(Δ) = Δ
5-
61
function CRC.rrule(::Type{<:DeepEquilibriumSolution}, z_star::T, u0::T, residual::T,
72
jacobian_loss::R, nfe::Int) where {T, R <: AbstractFloat}
8-
function deep_equilibrium_solution_pullback(dsol)
3+
function ∇deep_equilibrium_solution(dsol)
94
return (∂∅, dsol.z_star, dsol.u0, dsol.residual, dsol.jacobian_loss, dsol.nfe)
105
end
116
return (DeepEquilibriumSolution(z_star, u0, residual, jacobian_loss, nfe),
12-
deep_equilibrium_solution_pullback)
13-
end
14-
15-
function _safe_getfield(x::NamedTuple{fields}, field) where {fields}
16-
return field ∈ fields ? getfield(x, field) : ∂∅
17-
end
18-
19-
function CRC.rrule(::Type{T}, args...) where {T <: NamedTuple}
20-
y = T(args...)
21-
function nt_pullback(dy)
22-
fields = fieldnames(T)
23-
dy isa CRC.Tangent && (dy = CRC.backing(dy))
24-
return (∂∅, _safe_getfield.((dy,), fields)...)
25-
end
26-
return y, nt_pullback
27-
end
28-
29-
function CRC.rrule(::typeof(Setfield.set), obj, l::Setfield.PropertyLens{field},
30-
val) where {field}
31-
res = Setfield.set(obj, l, val)
32-
function setfield_pullback(Δres)
33-
Δres = __backing(Δres)
34-
Δobj = Setfield.set(obj, l, ∂∅)
35-
return (∂∅, Δobj, ∂∅, getfield(Δres, field))
36-
end
37-
return res, setfield_pullback
38-
end
39-
40-
function CRC.rrule(::typeof(_construct_problem), deq::AbstractDEQs, dudt, z,
41-
ps::NamedTuple{F}, x) where {F}
42-
prob = _construct_problem(deq, dudt, z, ps, x)
43-
function ∇_construct_problem(Δ)
44-
Δ = __backing(Δ)
45-
nograds = NamedTuple{F}(ntuple(i -> ∂∅, length(F)))
46-
return (∂∅, ∂∅, ∂∅, Δ.u0, merge(nograds, (; model=Δ.p.ps)), Δ.p.x)
47-
end
48-
return prob, ∇_construct_problem
7+
∇deep_equilibrium_solution)
498
end

src/layers/deq.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -122,12 +122,10 @@ _jacobian_regularization(::SkipDeepEquilibriumNetwork{J}) where {J} = J
122122
function _get_initial_condition(deq::SkipDeepEquilibriumNetwork{J, M, Nothing}, x, ps,
123123
st) where {J, M}
124124
z, st_ = deq.model((zero(x), x), ps.model, st.model)
125-
@set! st.model = st_
126-
return z, st
125+
return z, merge(st, (; model=st_))
127126
end
128127

129128
function _get_initial_condition(deq::SkipDeepEquilibriumNetwork, x, ps, st)
130129
z, st_ = deq.shortcut(x, ps.shortcut, st.shortcut)
131-
@set! st.shortcut = st_
132-
return z, st
130+
return z, merge(st, (; shortcut=st_))
133131
end

src/layers/evaluate.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,10 @@ function (deq::AbstractDEQs)(x::AbstractArray{T}, ps, st::NamedTuple, ::Val{true
2727
z_star, st_ = _evaluate_unrolled_model(deq, deq.model, z, x, ps.model, st.model,
2828
st.fixed_depth)
2929

30-
@set! st.model = st_
31-
@set! st.solution = build_solution(deq, z_star, z, x, ps, st, depth, T(0))
30+
st__ = merge(st,
31+
(; model=st_, solution=build_solution(deq, z_star, z, x, ps, st, depth, T(0))))
3232

33-
return _postprocess_output(deq, z_star), st
33+
return _postprocess_output(deq, z_star), st__
3434
end
3535

3636
function (deq::AbstractDEQs)(x::AbstractArray, ps, st::NamedTuple, ::Val{false})
@@ -60,9 +60,9 @@ function (deq::AbstractDEQs)(x::AbstractArray, ps, st::NamedTuple, ::Val{false})
6060
jac_loss = T(0)
6161
end
6262

63-
@set! st.model = model.st
64-
@set! st.solution = build_solution(deq, z_star, z, x, ps, st, nfe, jac_loss)
65-
@set! st.rng = rng
63+
st_ = merge(st,
64+
(; model=model.st, rng,
65+
solution=build_solution(deq, z_star, z, x, ps, st, nfe, jac_loss)))
6666

67-
return _postprocess_output(deq, z_star), st
67+
return _postprocess_output(deq, z_star), st_
6868
end

src/layers/mdeq.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -250,15 +250,13 @@ function _get_initial_condition(deq::MultiScaleSkipDeepEquilibriumNetwork{N, M,
250250
x, ps, st) where {N, M}
251251
u0, st = _get_zeros_initial_condition_mdeq(deq.scales, x, st)
252252
z, st_ = deq.model((u0, x), ps.model, st.model)
253-
@set! st.model = st_
254-
return z, st
253+
return z, merge(st, (; model=st_))
255254
end
256255

257256
function _get_initial_condition(deq::MultiScaleSkipDeepEquilibriumNetwork, x, ps, st)
258257
z0, st_ = deq.shortcut(x, ps.shortcut, st.shortcut)
259258
z = mapreduce(flatten, vcat, z0)
260-
@set! st.shortcut = st_
261-
return z, st
259+
return z, merge(st, (; shortcut=st_))
262260
end
263261

264262
@concrete struct MultiScaleNeuralODE{N} <: AbstractDeepEquilibriumNetwork

src/utils.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,7 @@ DEQs.split_and_reshape(x, split_idxs, shapes)
3737
push!(calls, :(return tuple($(varnames...))))
3838
return Expr(:block, calls...)
3939
end
40+
41+
@inline flatten(x::AbstractVector) = reshape(x, length(x), 1)
42+
@inline flatten(x::AbstractMatrix) = x
43+
@inline flatten(x::AbstractArray) = reshape(x, :, size(x, ndims(x)))

0 commit comments

Comments
 (0)