17
17
u, x = z
18
18
u_ = split_and_reshape (u, m. split_idxs, m. scales)
19
19
u_res, st = m. model (($ (inputs... ),), ps, st)
20
- return vcat (flatten .( u_res) ... ), st
20
+ return mapreduce (flatten, vcat, u_res), st
21
21
end
22
22
end
23
23
@@ -80,6 +80,10 @@ See also: [`DeepEquilibriumNetwork`](@ref), [`SkipDeepEquilibriumNetwork`](@ref)
80
80
kwargs
81
81
end
82
82
83
+ function MultiScaleDeepEquilibriumNetwork (model:: MultiScaleInputLayer{N} , args... ) where {N}
84
+ return MultiScaleDeepEquilibriumNetwork {N} (model, args... )
85
+ end
86
+
83
87
@truncate_stacktrace MultiScaleDeepEquilibriumNetwork 1 3
84
88
85
89
function Lux. initialstates (rng:: AbstractRNG , deq:: MultiScaleDeepEquilibriumNetwork )
@@ -104,7 +108,7 @@ function MultiScaleDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Ma
104
108
split_idxs, scales)
105
109
end
106
110
107
- return MultiScaleDeepEquilibriumNetwork {N} (model, solver, sensealg, scales, split_idxs,
111
+ return MultiScaleDeepEquilibriumNetwork (model, solver, sensealg, scales, split_idxs,
108
112
kwargs)
109
113
end
110
114
@@ -205,6 +209,11 @@ See also: [`DeepEquilibriumNetwork`](@ref), [`SkipDeepEquilibriumNetwork`](@ref)
205
209
kwargs
206
210
end
207
211
212
+ function MultiScaleSkipDeepEquilibriumNetwork (model:: MultiScaleInputLayer{N} ,
213
+ args... ) where {N}
214
+ return MultiScaleSkipDeepEquilibriumNetwork {N} (model, args... )
215
+ end
216
+
208
217
@truncate_stacktrace MultiScaleSkipDeepEquilibriumNetwork 1 3 4
209
218
210
219
function Lux. initialstates (rng:: AbstractRNG , deq:: MultiScaleSkipDeepEquilibriumNetwork )
@@ -231,7 +240,7 @@ function MultiScaleSkipDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers
231
240
split_idxs, scales)
232
241
end
233
242
234
- return MultiScaleSkipDeepEquilibriumNetwork {N} (model, shortcut, solver, sensealg,
243
+ return MultiScaleSkipDeepEquilibriumNetwork (model, shortcut, solver, sensealg,
235
244
scales, split_idxs, kwargs)
236
245
end
237
246
0 commit comments