Skip to content

Commit ce77355

Browse files
committed
Use Lux Stateful layer for type stability
1 parent d924050 commit ce77355

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

src/layers/core.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ const AbstractDEQs = Union{AbstractDeepEquilibriumNetwork,
2626
AbstractSkipDeepEquilibriumNetwork}
2727

2828
function (deq::AbstractDEQs)(x::AbstractArray, ps, st::NamedTuple)
29+
# return deq
2930
return deq(x, ps, st, _check_unrolled_mode(st))
3031
end
3132

src/layers/evaluate.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,29 +36,31 @@ end
3636
function (deq::AbstractDEQs)(x::AbstractArray, ps, st::NamedTuple, ::Val{false})
3737
T = eltype(x)
3838
z, st = _get_initial_condition(deq, x, ps, st)
39-
st_, nfe = st.model, 0
39+
40+
model = Lux.Experimental.StatefulLuxLayer(deq.model, nothing, st.model)
41+
nfe::Int = 0
4042

4143
function dudt(u, p, t)
4244
nfe += 1
43-
u_, st_ = deq.model((u, x), p, st_)
45+
u_ = model((u, x), p)
4446
return u_ .- u
4547
end
4648

4749
prob = _construct_problem(deq, dudt, z, ps)
4850
sol = solve(prob, deq.solver; deq.sensealg, deq.kwargs...)
4951

50-
z_star, st_ = deq.model((_fix_solution_output(deq, sol.u), x), ps.model, st_)
52+
z_star = model((_fix_solution_output(deq, sol.u), x), ps.model)
5153

5254
if _jacobian_regularization(deq)
5355
rng = Lux.replicate(st.rng)
54-
jac_loss = estimate_jacobian_trace(Val(:finite_diff), deq.model, ps.model, st.model,
56+
jac_loss = estimate_jacobian_trace(Val(:finite_diff), deq.model, ps.model, model.st,
5557
z_star, x, rng)
5658
else
5759
rng = st.rng
5860
jac_loss = T(0)
5961
end
6062

61-
@set! st.model = st_
63+
@set! st.model = model.st
6264
@set! st.solution = build_solution(deq, z_star, z, x, ps, st, nfe, jac_loss)
6365
@set! st.rng = rng
6466

0 commit comments

Comments
 (0)