From 62d29be340bf8d9706a8ecce2e7910f1f8484a0d Mon Sep 17 00:00:00 2001 From: Qingyu Qu <2283984853@qq.com> Date: Mon, 7 Apr 2025 03:01:40 +0800 Subject: [PATCH 1/2] Fix SimpleNonlinearSolve CI --- lib/BracketingNonlinearSolve/src/muller.jl | 32 +++++++++---------- .../test/muller_tests.jl | 18 +++++------ .../test/rootfind_tests.jl | 6 ++-- lib/SimpleNonlinearSolve/Project.toml | 1 + lib/SimpleNonlinearSolve/src/halley.jl | 2 +- lib/SimpleNonlinearSolve/src/utils.jl | 5 +-- 6 files changed, 34 insertions(+), 30 deletions(-) diff --git a/lib/BracketingNonlinearSolve/src/muller.jl b/lib/BracketingNonlinearSolve/src/muller.jl index edde280de..fbd87c34d 100644 --- a/lib/BracketingNonlinearSolve/src/muller.jl +++ b/lib/BracketingNonlinearSolve/src/muller.jl @@ -8,8 +8,8 @@ initial guesses `(left, middle, right)` for the root. ### Keyword Arguments -- `middle`: the initial guess for the middle point. If not provided, the - midpoint of the interval `(left, right)` is used. + - `middle`: the initial guess for the middle point. If not provided, the + midpoint of the interval `(left, right)` is used. """ struct Muller{T} <: AbstractBracketingAlgorithm middle::T @@ -18,7 +18,7 @@ end Muller() = Muller(nothing) function CommonSolve.solve(prob::IntervalNonlinearProblem, alg::Muller, args...; - abstol = nothing, maxiters = 1000, kwargs...) + abstol = nothing, maxiters = 1000, kwargs...) @assert !SciMLBase.isinplace(prob) "`Muller` only supports out-of-place problems." xᵢ₋₂, xᵢ = prob.tspan xᵢ₋₁ = isnothing(alg.middle) ? (xᵢ₋₂ + xᵢ) / 2 : alg.middle @@ -32,19 +32,19 @@ function CommonSolve.solve(prob::IntervalNonlinearProblem, alg::Muller, args...; abstol = abs(NonlinearSolveBase.get_tolerance( xᵢ₋₂, abstol, promote_type(eltype(xᵢ₋₂), eltype(xᵢ)))) - for _ ∈ 1:maxiters - q = (xᵢ - xᵢ₋₁)/(xᵢ₋₁ - xᵢ₋₂) - A = q*fxᵢ - q*(1 + q)*fxᵢ₋₁ + q^2*fxᵢ₋₂ - B = (2*q + 1)*fxᵢ - (1 + q)^2*fxᵢ₋₁ + q^2*fxᵢ₋₂ - C = (1 + q)*fxᵢ + for _ in 1:maxiters + q = (xᵢ - xᵢ₋₁) / (xᵢ₋₁ - xᵢ₋₂) + A = q * fxᵢ - q * (1 + q) * fxᵢ₋₁ + q^2 * fxᵢ₋₂ + B = (2 * q + 1) * fxᵢ - (1 + q)^2 * fxᵢ₋₁ + q^2 * fxᵢ₋₂ + C = (1 + q) * fxᵢ - denom₊ = B + √(B^2 - 4*A*C) - denom₋ = B - √(B^2 - 4*A*C) + denom₊ = B + √(B^2 - 4 * A * C) + denom₋ = B - √(B^2 - 4 * A * C) if abs(denom₊) ≥ abs(denom₋) - xᵢ₊₁ = xᵢ - (xᵢ - xᵢ₋₁)*2*C/denom₊ + xᵢ₊₁ = xᵢ - (xᵢ - xᵢ₋₁) * 2 * C / denom₊ else - xᵢ₊₁ = xᵢ - (xᵢ - xᵢ₋₁)*2*C/denom₋ + xᵢ₊₁ = xᵢ - (xᵢ - xᵢ₋₁) * 2 * C / denom₋ end fxᵢ₊₁ = f(xᵢ₊₁) @@ -52,8 +52,8 @@ function CommonSolve.solve(prob::IntervalNonlinearProblem, alg::Muller, args...; # Termination Check if abstol ≥ abs(fxᵢ₊₁) return SciMLBase.build_solution(prob, alg, xᵢ₊₁, fxᵢ₊₁; - retcode = ReturnCode.Success, - left = xᵢ₊₁, right = xᵢ₊₁) + retcode = ReturnCode.Success, + left = xᵢ₊₁, right = xᵢ₊₁) end xᵢ₋₂, xᵢ₋₁, xᵢ = xᵢ₋₁, xᵢ, xᵢ₊₁ @@ -61,6 +61,6 @@ function CommonSolve.solve(prob::IntervalNonlinearProblem, alg::Muller, args...; end return SciMLBase.build_solution(prob, alg, xᵢ₊₁, fxᵢ₊₁; - retcode = ReturnCode.MaxIters, - left = xᵢ₊₁, right = xᵢ₊₁) + retcode = ReturnCode.MaxIters, + left = xᵢ₊₁, right = xᵢ₊₁) end diff --git a/lib/BracketingNonlinearSolve/test/muller_tests.jl b/lib/BracketingNonlinearSolve/test/muller_tests.jl index f4800b9d4..472e1ec7d 100644 --- a/lib/BracketingNonlinearSolve/test/muller_tests.jl +++ b/lib/BracketingNonlinearSolve/test/muller_tests.jl @@ -1,7 +1,7 @@ @testitem "Muller" begin f(u, p) = u^2 - p g(u, p) = sin(u) - h(u, p) = exp(-u)*sin(u) + h(u, p) = exp(-u) * sin(u) i(u, p) = u^3 - 1 @testset "Quadratic function" begin @@ -30,7 +30,7 @@ prob = IntervalNonlinearProblem{false}(g, tspan) sol = solve(prob, Muller()) - @test sol.u ≈ 2*π + @test sol.u ≈ 2 * π end @testset "Exponential-sine function" begin @@ -44,7 +44,7 @@ prob = IntervalNonlinearProblem{false}(h, tspan) sol = solve(prob, Muller()) - @test sol.u ≈ 0 atol = 1e-15 + @test sol.u≈0 atol=1e-15 tspan = (-1.0, 1.0) prob = IntervalNonlinearProblem{false}(h, tspan) @@ -54,17 +54,17 @@ end @testset "Complex roots" begin - tspan = (-1.0, 1.0*im) + tspan = (-1.0, 1.0 * im) prob = IntervalNonlinearProblem{false}(i, tspan) sol = solve(prob, Muller()) - @test sol.u ≈ (-1 + √3*im)/2 + @test sol.u ≈ (-1 + √3 * im) / 2 - tspan = (-1.0, -1.0*im) + tspan = (-1.0, -1.0 * im) prob = IntervalNonlinearProblem{false}(i, tspan) sol = solve(prob, Muller()) - @test sol.u ≈ (-1 - √3*im)/2 + @test sol.u ≈ (-1 - √3 * im) / 2 end @testset "Middle" begin @@ -87,10 +87,10 @@ @test sol.u ≈ -π - tspan = (-1.0, 1.0*im) + tspan = (-1.0, 1.0 * im) prob = IntervalNonlinearProblem{false}(i, tspan) sol = solve(prob, Muller(0.0)) - @test sol.u ≈ (-1 + √3*im)/2 + @test sol.u ≈ (-1 + √3 * im) / 2 end end diff --git a/lib/BracketingNonlinearSolve/test/rootfind_tests.jl b/lib/BracketingNonlinearSolve/test/rootfind_tests.jl index e8666a8c1..e32c37df0 100644 --- a/lib/BracketingNonlinearSolve/test/rootfind_tests.jl +++ b/lib/BracketingNonlinearSolve/test/rootfind_tests.jl @@ -7,7 +7,8 @@ end @testitem "Interval Nonlinear Problems" setup=[RootfindingTestSnippet] tags=[:core] begin using ForwardDiff - @testset for alg in (Alefeld(), Bisection(), Brent(), Falsi(), ITP(), Muller(), Ridder(), nothing) + @testset for alg in ( + Alefeld(), Bisection(), Brent(), Falsi(), ITP(), Muller(), Ridder(), nothing) tspan = (1.0, 20.0) function g(p) @@ -76,7 +77,8 @@ end end @testitem "Flipped Signs and Reversed Tspan" setup=[RootfindingTestSnippet] tags=[:core] begin - @testset for alg in (Alefeld(), Bisection(), Brent(), Falsi(), ITP(), Muller(), Ridder(), nothing) + @testset for alg in ( + Alefeld(), Bisection(), Brent(), Falsi(), ITP(), Muller(), Ridder(), nothing) f1(u, p) = u * u - p f2(u, p) = p - u * u diff --git a/lib/SimpleNonlinearSolve/Project.toml b/lib/SimpleNonlinearSolve/Project.toml index 199085937..e547c531e 100644 --- a/lib/SimpleNonlinearSolve/Project.toml +++ b/lib/SimpleNonlinearSolve/Project.toml @@ -20,6 +20,7 @@ NonlinearSolveBase = "be0214bd-f91f-a760-ac4e-3421ce2b2da0" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" +SciMLJacobianOperators = "19f34311-ddf3-4b8b-af20-060888a46c0e" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" diff --git a/lib/SimpleNonlinearSolve/src/halley.jl b/lib/SimpleNonlinearSolve/src/halley.jl index 773f4b569..1e00c2234 100644 --- a/lib/SimpleNonlinearSolve/src/halley.jl +++ b/lib/SimpleNonlinearSolve/src/halley.jl @@ -74,7 +74,7 @@ function SciMLBase.__solve( end aᵢ = J_fact \ NLBUtils.safe_vec(fx) - hvvp = Utils.compute_hvvp(prob, autodiff, fx_cache, x, aᵢ) + hvvp = Utils.compute_hvvp(prob, autodiff, fx_cache, NLBUtils.safe_vec(x), aᵢ) bᵢ = J_fact \ NLBUtils.safe_vec(hvvp) cᵢ_ = NLBUtils.safe_vec(cᵢ) diff --git a/lib/SimpleNonlinearSolve/src/utils.jl b/lib/SimpleNonlinearSolve/src/utils.jl index 1090ac5f1..99f1dcd3a 100644 --- a/lib/SimpleNonlinearSolve/src/utils.jl +++ b/lib/SimpleNonlinearSolve/src/utils.jl @@ -166,12 +166,13 @@ function compute_hvvp(prob, autodiff, fx, x, dir) jvp_fn = if SciMLBase.isinplace(prob) @closure (u, p) -> begin du = NLBUtils.safe_similar(fx, promote_type(eltype(fx), eltype(u))) - return only(DI.pushforward(prob.f, du, autodiff, u, (dir,), Constant(p))) + return only(DI.pushforward( + prob.f, NLBUtils.safe_vec(du), autodiff, u, (dir,), Constant(p))) end else @closure (u, p) -> only(DI.pushforward(prob.f, autodiff, u, (dir,), Constant(p))) end - only(DI.pushforward(jvp_fn, autodiff, x, (dir,), Constant(prob.p))) + only(DI.pushforward(jvp_fn, autodiff, x, (dir,), Constant(NLBUtils.safe_vec(prob.p)))) end end From b66b2cdcfb2cd298830b10aa7e1da713da268642 Mon Sep 17 00:00:00 2001 From: Qingyu Qu <52615090+ErikQQY@users.noreply.github.com> Date: Mon, 7 Apr 2025 03:35:15 +0800 Subject: [PATCH 2/2] Fix typos --- lib/SimpleNonlinearSolve/src/utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/SimpleNonlinearSolve/src/utils.jl b/lib/SimpleNonlinearSolve/src/utils.jl index 99f1dcd3a..19173a9a5 100644 --- a/lib/SimpleNonlinearSolve/src/utils.jl +++ b/lib/SimpleNonlinearSolve/src/utils.jl @@ -172,7 +172,7 @@ function compute_hvvp(prob, autodiff, fx, x, dir) else @closure (u, p) -> only(DI.pushforward(prob.f, autodiff, u, (dir,), Constant(p))) end - only(DI.pushforward(jvp_fn, autodiff, x, (dir,), Constant(NLBUtils.safe_vec(prob.p)))) + only(DI.pushforward(jvp_fn, autodiff, x, (dir,), Constant(prob.p))) end end