Skip to content

Commit d924050

Browse files
committed
Test Nonlinear Solve solvers
1 parent 432f7b0 commit d924050

File tree

6 files changed

+30
-27
lines changed

6 files changed

+30
-27
lines changed

src/layers/core.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ function (deq::AbstractDEQs)(x::AbstractArray, ps, st::NamedTuple)
3030
end
3131

3232
# Utilities
33-
@inline _check_unrolled_mode(::Val{d}) where {d} = Val(d >= 1)
33+
@inline _check_unrolled_mode(::Val{d}) where {d} = Val(d 1)
3434
@inline _check_unrolled_mode(st::NamedTuple) = _check_unrolled_mode(st.fixed_depth)
3535

3636
@inline _get_unrolled_depth(::Val{d}) where {d} = d

src/layers/mdeq.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -335,8 +335,7 @@ function MultiScaleNeuralODE(main_layers::Tuple, mapping_layers::Matrix,
335335
model = MultiScaleInputLayer(Chain(l1, l2), split_idxs, scales)
336336
else
337337
model = MultiScaleInputLayer(Chain(l1, l2, Parallel(nothing, post_fuse_layer...)),
338-
split_idxs,
339-
scales)
338+
split_idxs, scales)
340339
end
341340

342341
return MultiScaleNeuralODE{N}(model, solver, sensealg, scales, split_idxs, kwargs)
@@ -355,8 +354,7 @@ end
355354
@inline _fix_solution_output(::MultiScaleNeuralODE, x) = x[end]
356355

357356
# Shared Functions
358-
@generated function _get_zeros_initial_condition_mdeq(::S,
359-
x::AbstractArray{T, N},
357+
@generated function _get_zeros_initial_condition_mdeq(::S, x::AbstractArray{T, N},
360358
st::NamedTuple{fields}) where {S, T, N, fields}
361359
scales = known(S)
362360
sz = sum(prod.(scales))

src/solve.jl

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,16 @@ more flexibility needed for solving DEQ problems.
5151
5252
See also: [`ContinuousDEQSolver`](@ref)
5353
"""
54-
Base.@kwdef @concrete struct DiscreteDEQSolver <: AbstractDEQSolver
55-
alg = LBroyden(; batched=true,
56-
termination_condition=NLSolveTerminationCondition(NLSolveTerminationMode.RelSafe;
57-
abstol=1.0f-8, reltol=1.0f-6))
54+
struct DiscreteDEQSolver{A} <: AbstractDEQSolver
55+
alg::A
56+
function DiscreteDEQSolver(alg=nothing)
57+
if alg === nothing
58+
alg = LBroyden(; batched=true,
59+
termination_condition=NLSolveTerminationCondition(NLSolveTerminationMode.RelSafe;
60+
abstol=1.0f-8, reltol=1.0f-6))
61+
end
62+
return new{typeof(alg)}(alg)
63+
end
5864
end
5965

6066
"""
@@ -78,9 +84,7 @@ function DiffEqBase.__solve(prob::AbstractSteadyStateProblem, alg::AbstractDEQSo
7884
args...; kwargs...)
7985
sol = solve(prob, alg.alg, args...; kwargs...)
8086

81-
# This is not necessarily true and might fail. But makes the code type stable
82-
u = sol.u::typeof(prob.u0)
83-
du, retcode = sol.resid, sol.retcode
87+
u, du, retcode = sol.u, sol.resid, sol.retcode
8488

8589
return EquilibriumSolution{eltype(u), ndims(u)}(u, du, prob, alg, retcode)
8690
end

test/layers/deq.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -122,12 +122,12 @@ function test_skip_deep_equilibrium_network_v2()
122122
return nothing
123123
end
124124

125-
Test.@testset "DeepEquilibriumNetwork" begin
125+
@testset "DeepEquilibriumNetwork" begin
126126
test_deep_equilibrium_network()
127127
end
128-
Test.@testset "SkipDeepEquilibriumNetwork" begin
128+
@testset "SkipDeepEquilibriumNetwork" begin
129129
test_skip_deep_equilibrium_network()
130130
end
131-
Test.@testset "SkipRegDeepEquilibriumNetwork" begin
131+
@testset "SkipRegDeepEquilibriumNetwork" begin
132132
test_skip_deep_equilibrium_network_v2()
133133
end

test/solve.jl

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,31 +4,32 @@ using Test
44
simple_dudt(u, p, t) = 0.9f0 .* u .- u
55

66
function test_continuous_deq_solver()
7-
prob = SteadyStateProblem(simple_dudt, [1.0f0], SciMLBase.NullParameters())
7+
prob = SteadyStateProblem(simple_dudt, [1.0f0])
88

99
sol = solve(prob, ContinuousDEQSolver(); save_everystep=true)
1010

1111
@test sol isa DEQs.EquilibriumSolution
12-
@test abs(sol.u[1]) <= 1.0f-4
12+
@test abs(sol.u[1]) 1.0f-4
1313

1414
return nothing
1515
end
1616

17-
function test_discrete_deq_solver()
18-
prob = SteadyStateProblem(simple_dudt, reshape([1.0f0], 1, 1),
19-
SciMLBase.NullParameters())
17+
function test_discrete_deq_solver(; solver=nothing)
18+
prob = SteadyStateProblem(simple_dudt, reshape([1.0f0], 1, 1))
2019

21-
sol = solve(prob, DiscreteDEQSolver())
20+
sol = solve(prob, solver === nothing ? DiscreteDEQSolver() : DiscreteDEQSolver(solver))
2221

2322
@test sol isa DEQs.EquilibriumSolution
24-
@test abs(sol.u[1]) <= 1.0f-4
23+
@test abs(sol.u[1]) 1.0f-4
2524

2625
return nothing
2726
end
2827

29-
Test.@testset "Continuous Steady State Solve" begin
28+
@testset "Continuous Steady State Solve" begin
3029
test_continuous_deq_solver()
3130
end
32-
Test.@testset "Discrete Steady State Solve" begin
33-
test_discrete_deq_solver()
31+
@testset "Discrete Steady State Solve" begin
32+
test_discrete_deq_solver(; solver=nothing) # Default
33+
test_discrete_deq_solver(; solver=NewtonRaphson())
34+
test_discrete_deq_solver(; solver=LevenbergMarquardt())
3435
end

test/test_utils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ catch
1313
end
1414

1515
function get_prng(seed::Int)
16-
@static if VERSION >= v"1.7"
16+
@static if VERSION v"1.7"
1717
rng = Xoshiro()
1818
Random.seed!(rng, seed)
1919
return rng
@@ -39,7 +39,7 @@ function is_finite_gradient(gs::NamedTuple)
3939
end
4040

4141
function run_JET_tests(f, args...; call_broken=false, opt_broken=false, kwargs...)
42-
@static if VERSION >= v"1.7"
42+
@static if VERSION v"1.7"
4343
test_call(f, typeof.(args); broken=call_broken, target_modules=(DEQs,))
4444
test_opt(f, typeof.(args); broken=opt_broken, target_modules=(DEQs,))
4545
end

0 commit comments

Comments
 (0)