|
36 | 36 | function (deq::AbstractDEQs)(x::AbstractArray, ps, st::NamedTuple, ::Val{false})
|
37 | 37 | T = eltype(x)
|
38 | 38 | z, st = _get_initial_condition(deq, x, ps, st)
|
39 |
| - st_, nfe = st.model, 0 |
| 39 | + |
| 40 | + model = Lux.Experimental.StatefulLuxLayer(deq.model, nothing, st.model) |
| 41 | + nfe::Int = 0 |
40 | 42 |
|
41 | 43 | function dudt(u, p, t)
|
42 | 44 | nfe += 1
|
43 |
| - u_, st_ = deq.model((u, x), p, st_) |
| 45 | + u_ = model((u, x), p) |
44 | 46 | return u_ .- u
|
45 | 47 | end
|
46 | 48 |
|
47 | 49 | prob = _construct_problem(deq, dudt, z, ps)
|
48 | 50 | sol = solve(prob, deq.solver; deq.sensealg, deq.kwargs...)
|
49 | 51 |
|
50 |
| - z_star, st_ = deq.model((_fix_solution_output(deq, sol.u), x), ps.model, st_) |
| 52 | + z_star = model((_fix_solution_output(deq, sol.u), x), ps.model) |
51 | 53 |
|
52 | 54 | if _jacobian_regularization(deq)
|
53 | 55 | rng = Lux.replicate(st.rng)
|
54 |
| - jac_loss = estimate_jacobian_trace(Val(:finite_diff), deq.model, ps.model, st.model, |
| 56 | + jac_loss = estimate_jacobian_trace(Val(:finite_diff), deq.model, ps.model, model.st, |
55 | 57 | z_star, x, rng)
|
56 | 58 | else
|
57 | 59 | rng = st.rng
|
58 | 60 | jac_loss = T(0)
|
59 | 61 | end
|
60 | 62 |
|
61 |
| - @set! st.model = st_ |
| 63 | + @set! st.model = model.st |
62 | 64 | @set! st.solution = build_solution(deq, z_star, z, x, ps, st, nfe, jac_loss)
|
63 | 65 | @set! st.rng = rng
|
64 | 66 |
|
|
0 commit comments