diff --git a/lib/OrdinaryDiffEqStabilizedRK/src/rkc_perform_step.jl b/lib/OrdinaryDiffEqStabilizedRK/src/rkc_perform_step.jl index 9bcdb59674..cfac627868 100644 --- a/lib/OrdinaryDiffEqStabilizedRK/src/rkc_perform_step.jl +++ b/lib/OrdinaryDiffEqStabilizedRK/src/rkc_perform_step.jl @@ -35,6 +35,7 @@ end μ, κ = recf[cache.start + (i - 2) * 2 + 1], recf[cache.start + (i - 2) * 2 + 2] ν = -1 - κ u = f(uᵢ₋₁, p, tᵢ₋₁) + OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) tᵢ₋₁ = dt * μ - ν * tᵢ₋₂ - κ * tᵢ₋₃ u = (dt * μ) * u - ν * uᵢ₋₁ - κ * uᵢ₋₂ i < cache.mdeg && (uᵢ₋₂ = uᵢ₋₁; @@ -110,6 +111,7 @@ end μ, κ = recf[ccache.start + (i - 2) * 2 + 1], recf[ccache.start + (i - 2) * 2 + 2] ν = -1 - κ f(k, uᵢ₋₁, p, tᵢ₋₁) + OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) tᵢ₋₁ = dt * μ - ν * tᵢ₋₂ - κ * tᵢ₋₃ @.. broadcast=false u=(dt * μ) * k - ν * uᵢ₋₁ - κ * uᵢ₋₂ if i < ccache.mdeg @@ -192,6 +194,7 @@ end μ, κ = recf[cache.start + (i - 2) * 2 + 1], recf[cache.start + (i - 2) * 2 + 2] ν = -1 - κ u = f(uᵢ₋₁, p, tᵢ₋₁) + OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) tᵢ₋₁ = dt * μ - ν * tᵢ₋₂ - κ * tᵢ₋₃ u = (dt * μ) * u - ν * uᵢ₋₁ - κ * uᵢ₋₂ i < cache.mdeg && (uᵢ₋₂ = uᵢ₋₁; @@ -314,6 +317,7 @@ end μ, κ = recf[ccache.start + (i - 2) * 2 + 1], recf[ccache.start + (i - 2) * 2 + 2] ν = -1 - κ f(k, uᵢ₋₁, p, tᵢ₋₁) + OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1) tᵢ₋₁ = (dt * μ) - ν * tᵢ₋₂ - κ * tᵢ₋₃ @.. broadcast=false u=(dt * μ) * k - ν * uᵢ₋₁ - κ * uᵢ₋₂ if i < ccache.mdeg diff --git a/lib/OrdinaryDiffEqStabilizedRK/test/rkc_tests.jl b/lib/OrdinaryDiffEqStabilizedRK/test/rkc_tests.jl index 3849513e88..beef67023c 100644 --- a/lib/OrdinaryDiffEqStabilizedRK/test/rkc_tests.jl +++ b/lib/OrdinaryDiffEqStabilizedRK/test/rkc_tests.jl @@ -72,3 +72,33 @@ end @test sim.𝒪est[:l∞]≈5 atol=testTol end end + +@testset "Number of function evaluations" begin + x = Ref(0) + u0 = [1.0, 1.0] + tspan = (0.0, 1.0) + probop = ODEProblem(u0, tspan) do u, p, t + x[] += 1 + return -500 * u + end + probip = ODEProblem(u0, tspan) do du, u, p, t + x[] += 1 + @. du = -500 * u + return nothing + end + + @testset "$prob" for prob in [probop, probip] + eigen_est = (integrator) -> integrator.eigen_est = 500 + algs = [ROCK2(), ROCK2(eigen_est = eigen_est), + ROCK4(), ROCK4(eigen_est = eigen_est), + RKC(), RKC(eigen_est = eigen_est), + SERK2(), SERK2(eigen_est = eigen_est), + ESERK4(), ESERK4(eigen_est = eigen_est), + ESERK5(), ESERK5(eigen_est = eigen_est)] + @testset "$alg" for alg in algs + x[] = 0 + sol = solve(prob, alg) + @test x[] == sol.stats.nf + end + end +end