@@ -94,8 +94,9 @@ function Lux.initialstates(rng::AbstractRNG, deq::MultiScaleDeepEquilibriumNetwo
94
94
end
95
95
96
96
function MultiScaleDeepEquilibriumNetwork (main_layers:: Tuple , mapping_layers:: Matrix ,
97
- post_fuse_layer:: Union{Nothing, Tuple} , solver, scales:: NTuple{N, NTuple{L, Int64}} ;
98
- sensealg= SteadyStateAdjoint (), kwargs... ) where {N, L}
97
+ post_fuse_layer:: Union{Nothing, Tuple} , solver,
98
+ scales:: Tuple{NTuple{L, Int64}, Vararg{NTuple{L, Int64}, nMinus1}} ;
99
+ sensealg= SteadyStateAdjoint (), kwargs... ) where {nMinus1, L}
99
100
l1 = Parallel (nothing , main_layers... )
100
101
l2 = BranchLayer (Parallel .(+ , map (x -> tuple (x... ), eachrow (mapping_layers))... )... )
101
102
226
227
227
228
function MultiScaleSkipDeepEquilibriumNetwork (main_layers:: Tuple , mapping_layers:: Matrix ,
228
229
post_fuse_layer:: Union{Nothing, Tuple} , shortcut_layers:: Union{Nothing, Tuple} ,
229
- solver, scales:: NTuple{N, NTuple{L, Int64}} ;
230
- sensealg= SteadyStateAdjoint (), kwargs... ) where {N, L}
230
+ solver,
231
+ scales:: Tuple{NTuple{L, Int64}, Vararg{NTuple{L, Int64}, nMinus1}} ;
232
+ sensealg= SteadyStateAdjoint (), kwargs... ) where {nMinus1, L}
231
233
l1 = Parallel (nothing , main_layers... )
232
234
l2 = BranchLayer (Parallel .(+ , map (x -> tuple (x... ), eachrow (mapping_layers))... )... )
233
235
shortcut = shortcut_layers === nothing ? nothing : Parallel (nothing , shortcut_layers... )
@@ -330,8 +332,9 @@ model(x, ps, st)
330
332
See also: [`DeepEquilibriumNetwork`](@ref), [`SkipDeepEquilibriumNetwork`](@ref), [`MultiScaleDeepEquilibriumNetwork`](@ref), [`MultiScaleSkipDeepEquilibriumNetwork`](@ref)
331
333
"""
332
334
function MultiScaleNeuralODE (main_layers:: Tuple , mapping_layers:: Matrix ,
333
- post_fuse_layer:: Union{Nothing, Tuple} , solver, scales:: NTuple{N, NTuple{L, Int64}} ;
334
- sensealg= GaussAdjoint (; autojacvec= ZygoteVJP ()), kwargs... ) where {N, L}
335
+ post_fuse_layer:: Union{Nothing, Tuple} , solver,
336
+ scales:: Tuple{NTuple{L, Int64}, Vararg{NTuple{L, Int64}, nMinus1}} ;
337
+ sensealg= GaussAdjoint (; autojacvec= ZygoteVJP ()), kwargs... ) where {nMinus1, L}
335
338
l1 = Parallel (nothing , main_layers... )
336
339
l2 = BranchLayer (Parallel .(+ , map (x -> tuple (x... ), eachrow (mapping_layers))... )... )
337
340
@@ -344,7 +347,7 @@ function MultiScaleNeuralODE(main_layers::Tuple, mapping_layers::Matrix,
344
347
split_idxs, scales)
345
348
end
346
349
347
- return MultiScaleNeuralODE {N } (model, solver, sensealg, scales, split_idxs, kwargs)
350
+ return MultiScaleNeuralODE {nMinus1+1 } (model, solver, sensealg, scales, split_idxs, kwargs)
348
351
end
349
352
350
353
_jacobian_regularization (:: MultiScaleNeuralODE ) = false
0 commit comments