Skip to content

Commit dda196c

Browse files
committed
oWe can now support non flat vector parameters
1 parent 71b0901 commit dda196c

File tree

3 files changed

+15
-12
lines changed

3 files changed

+15
-12
lines changed

src/chainrules.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,11 @@ function CRC.rrule(::typeof(Setfield.set), obj, l::Setfield.PropertyLens{field},
3737
return res, setfield_pullback
3838
end
3939

40-
# Honestly no clue why this is needed! -- probably a whacky fix which shouldn't be ever
41-
# needed.
42-
ZygoteRules.gradtuple1(::NamedTuple{()}) = (nothing, nothing, nothing, nothing, nothing)
43-
ZygoteRules.gradtuple1(x::NamedTuple) = collect(values(x))
40+
function CRC.rrule(::typeof(_construct_problem), deq::AbstractDEQs, dudt, z, ps, x)
41+
prob = _construct_problem(deq, dudt, z, ps, x)
42+
function ∇_construct_problem(Δ)
43+
return (CRC.NoTangent(), CRC.NoTangent(), CRC.NoTangent(), Δ.u0,
44+
(; model = Δ.p.ps), Δ.p.x)
45+
end
46+
return prob, ∇_construct_problem
47+
end

src/layers/evaluate.jl

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@ end
1313

1414
@inline _postprocess_output(_, z_star) = z_star
1515

16-
@inline function _construct_problem(::AbstractDEQs, dudt, z, ps)
17-
return SteadyStateProblem(ODEFunction{false}(dudt), z, ps.model)
16+
@inline function _construct_problem(::AbstractDEQs, dudt, z, ps, x)
17+
return SteadyStateProblem(ODEFunction{false}(dudt), z,
18+
NamedTuple{(:ps, :x)}((ps.model, x)))
1819
end
1920

2021
@inline _fix_solution_output(_, x) = x
@@ -42,14 +43,12 @@ function (deq::AbstractDEQs)(x::AbstractArray, ps, st::NamedTuple, ::Val{false})
4243

4344
function dudt(u, p, t)
4445
nfe += 1
45-
u_ = model((u, x), p)
46-
return u_ .- u
46+
return model((u, p.x), p.ps) .- u
4747
end
4848

49-
prob = _construct_problem(deq, dudt, z, ps)
49+
prob = _construct_problem(deq, dudt, z, ps, x)
5050
sol = solve(prob, deq.solver; deq.sensealg, deq.kwargs...)
51-
52-
z_star = model((_fix_solution_output(deq, sol.u), x), ps.model)
51+
z_star = sol.u
5352

5453
if _jacobian_regularization(deq)
5554
rng = Lux.replicate(st.rng)

src/layers/mdeq.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@ function _get_initial_condition(deq::MultiScaleNeuralODE, x, ps, st)
347347
return _get_zeros_initial_condition_mdeq(deq.scales, x, st)
348348
end
349349

350-
@inline function _construct_problem(::MultiScaleNeuralODE, dudt, z, ps)
350+
@inline function _construct_problem(::MultiScaleNeuralODE, dudt, z, ps, x)
351351
return ODEProblem(ODEFunction{false}(dudt), z, (0.0f0, 1.0f0), ps.model)
352352
end
353353

0 commit comments

Comments
 (0)