Skip to content

Commit 42a35fe

Browse files
committed
fix unbound type parameters in NTuple
1 parent 25eece3 commit 42a35fe

File tree

2 files changed

+11
-8
lines changed

2 files changed

+11
-8
lines changed

src/layers/mdeq.jl

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,9 @@ function Lux.initialstates(rng::AbstractRNG, deq::MultiScaleDeepEquilibriumNetwo
9494
end
9595

9696
function MultiScaleDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Matrix,
97-
post_fuse_layer::Union{Nothing, Tuple}, solver, scales::NTuple{N, NTuple{L, Int64}};
98-
sensealg=SteadyStateAdjoint(), kwargs...) where {N, L}
97+
post_fuse_layer::Union{Nothing, Tuple}, solver,
98+
scales::Tuple{NTuple{L, Int64}, Vararg{NTuple{L, Int64}, nMinus1}};
99+
sensealg=SteadyStateAdjoint(), kwargs...) where {nMinus1, L}
99100
l1 = Parallel(nothing, main_layers...)
100101
l2 = BranchLayer(Parallel.(+, map(x -> tuple(x...), eachrow(mapping_layers))...)...)
101102

@@ -226,8 +227,9 @@ end
226227

227228
function MultiScaleSkipDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Matrix,
228229
post_fuse_layer::Union{Nothing, Tuple}, shortcut_layers::Union{Nothing, Tuple},
229-
solver, scales::NTuple{N, NTuple{L, Int64}};
230-
sensealg=SteadyStateAdjoint(), kwargs...) where {N, L}
230+
solver,
231+
scales::Tuple{NTuple{L, Int64}, Vararg{NTuple{L, Int64}, nMinus1}};
232+
sensealg=SteadyStateAdjoint(), kwargs...) where {nMinus1, L}
231233
l1 = Parallel(nothing, main_layers...)
232234
l2 = BranchLayer(Parallel.(+, map(x -> tuple(x...), eachrow(mapping_layers))...)...)
233235
shortcut = shortcut_layers === nothing ? nothing : Parallel(nothing, shortcut_layers...)
@@ -330,8 +332,9 @@ model(x, ps, st)
330332
See also: [`DeepEquilibriumNetwork`](@ref), [`SkipDeepEquilibriumNetwork`](@ref), [`MultiScaleDeepEquilibriumNetwork`](@ref), [`MultiScaleSkipDeepEquilibriumNetwork`](@ref)
331333
"""
332334
function MultiScaleNeuralODE(main_layers::Tuple, mapping_layers::Matrix,
333-
post_fuse_layer::Union{Nothing, Tuple}, solver, scales::NTuple{N, NTuple{L, Int64}};
334-
sensealg=GaussAdjoint(; autojacvec=ZygoteVJP()), kwargs...) where {N, L}
335+
post_fuse_layer::Union{Nothing, Tuple}, solver,
336+
scales::Tuple{NTuple{L, Int64}, Vararg{NTuple{L, Int64}, nMinus1}};
337+
sensealg=GaussAdjoint(; autojacvec=ZygoteVJP()), kwargs...) where {nMinus1, L}
335338
l1 = Parallel(nothing, main_layers...)
336339
l2 = BranchLayer(Parallel.(+, map(x -> tuple(x...), eachrow(mapping_layers))...)...)
337340

@@ -344,7 +347,7 @@ function MultiScaleNeuralODE(main_layers::Tuple, mapping_layers::Matrix,
344347
split_idxs, scales)
345348
end
346349

347-
return MultiScaleNeuralODE{N}(model, solver, sensealg, scales, split_idxs, kwargs)
350+
return MultiScaleNeuralODE{nMinus1+1}(model, solver, sensealg, scales, split_idxs, kwargs)
348351
end
349352

350353
_jacobian_regularization(::MultiScaleNeuralODE) = false

test/qa.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,6 @@ using DeepEquilibriumNetworks, Aqua
66
Aqua.test_piracies(DeepEquilibriumNetworks; broken=true)
77
Aqua.test_project_extras(DeepEquilibriumNetworks)
88
Aqua.test_stale_deps(DeepEquilibriumNetworks)
9-
Aqua.test_unbound_args(DeepEquilibriumNetworks; broken=true)
9+
Aqua.test_unbound_args(DeepEquilibriumNetworks)
1010
Aqua.test_undefined_exports(DeepEquilibriumNetworks; broken=true)
1111
end

0 commit comments

Comments
 (0)