Skip to content

Commit 1903cfa

Browse files
committed
Allow debug mode
1 parent dda196c commit 1903cfa

File tree

2 files changed

+14
-4
lines changed

2 files changed

+14
-4
lines changed

src/layers/mdeq.jl

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ end
1717
u, x = z
1818
u_ = split_and_reshape(u, m.split_idxs, m.scales)
1919
u_res, st = m.model(($(inputs...),), ps, st)
20-
return vcat(flatten.(u_res)...), st
20+
return mapreduce(flatten, vcat, u_res), st
2121
end
2222
end
2323

@@ -80,6 +80,10 @@ See also: [`DeepEquilibriumNetwork`](@ref), [`SkipDeepEquilibriumNetwork`](@ref)
8080
kwargs
8181
end
8282

83+
function MultiScaleDeepEquilibriumNetwork(model::MultiScaleInputLayer{N}, args...) where {N}
84+
return MultiScaleDeepEquilibriumNetwork{N}(model, args...)
85+
end
86+
8387
@truncate_stacktrace MultiScaleDeepEquilibriumNetwork 1 3
8488

8589
function Lux.initialstates(rng::AbstractRNG, deq::MultiScaleDeepEquilibriumNetwork)
@@ -104,7 +108,7 @@ function MultiScaleDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Ma
104108
split_idxs, scales)
105109
end
106110

107-
return MultiScaleDeepEquilibriumNetwork{N}(model, solver, sensealg, scales, split_idxs,
111+
return MultiScaleDeepEquilibriumNetwork(model, solver, sensealg, scales, split_idxs,
108112
kwargs)
109113
end
110114

@@ -205,6 +209,11 @@ See also: [`DeepEquilibriumNetwork`](@ref), [`SkipDeepEquilibriumNetwork`](@ref)
205209
kwargs
206210
end
207211

212+
function MultiScaleSkipDeepEquilibriumNetwork(model::MultiScaleInputLayer{N},
213+
args...) where {N}
214+
return MultiScaleSkipDeepEquilibriumNetwork{N}(model, args...)
215+
end
216+
208217
@truncate_stacktrace MultiScaleSkipDeepEquilibriumNetwork 1 3 4
209218

210219
function Lux.initialstates(rng::AbstractRNG, deq::MultiScaleSkipDeepEquilibriumNetwork)
@@ -231,7 +240,7 @@ function MultiScaleSkipDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers
231240
split_idxs, scales)
232241
end
233242

234-
return MultiScaleSkipDeepEquilibriumNetwork{N}(model, shortcut, solver, sensealg,
243+
return MultiScaleSkipDeepEquilibriumNetwork(model, shortcut, solver, sensealg,
235244
scales, split_idxs, kwargs)
236245
end
237246

src/solve.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,9 @@ end
8080

8181
@truncate_stacktrace EquilibriumSolution 1 2
8282

83-
function DiffEqBase.__solve(prob::AbstractSteadyStateProblem, alg::AbstractDEQSolver,
83+
function SciMLBase.__solve(prob::AbstractSteadyStateProblem, alg::AbstractDEQSolver,
8484
args...; kwargs...)
85+
# FIXME: Remove this handle
8586
sol = solve(prob, alg.alg, args...; kwargs...)
8687

8788
u, du, retcode = sol.u, sol.resid, sol.retcode

0 commit comments

Comments
 (0)