Skip to content

Commit 672d894

Browse files
committed
fix stats
1 parent b0f957c commit 672d894

File tree

2 files changed

+24
-76
lines changed

2 files changed

+24
-76
lines changed

lib/OrdinaryDiffEqDifferentiation/src/derivative_wrappers.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ function jacobian!(J::AbstractMatrix{<:Number}, f, x::AbstractArray{<:Number},
239239
else
240240
forwarddiff_color_jacobian!(J, f, x, jac_config)
241241
end
242-
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
242+
OrdinaryDiffEqCore.increment_nf!(integrator.stats, maximum(jac_config.colorvec))
243243
elseif alg_autodiff(alg) isa AutoFiniteDiff
244244
isforward = alg_difftype(alg) === Val{:forward}
245245
if isforward

test/interface/stats_tests.jl

Lines changed: 23 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -5,84 +5,32 @@ function f(u, p, t)
55
x[] += 1
66
return 5 * u
77
end
8-
u0 = [1.0, 1.0]
9-
tspan = (0.0, 1.0)
10-
prob = ODEProblem(f, u0, tspan)
11-
12-
x[] = 0
13-
sol = solve(prob, Vern7())
14-
@test x[] == sol.stats.nf
15-
16-
x[] = 0
17-
sol = solve(prob, Vern8())
18-
@test x[] == sol.stats.nf
19-
20-
x[] = 0
21-
sol = solve(prob, Vern9())
22-
@test x[] == sol.stats.nf
23-
24-
x[] = 0
25-
sol = solve(prob, Tsit5())
26-
@test x[] == sol.stats.nf
27-
28-
x[] = 0
29-
sol = solve(prob, BS3())
30-
@test x[] == sol.stats.nf
31-
32-
x[] = 0
33-
sol = solve(prob, KenCarp4(; autodiff = true))
34-
@test x[] == sol.stats.nf
35-
36-
x[] = 0
37-
sol = solve(prob, KenCarp4(; autodiff = false, diff_type = Val{:forward}))
38-
@test x[] == sol.stats.nf
39-
40-
x[] = 0
41-
sol = solve(prob, KenCarp4(; autodiff = false, diff_type = Val{:central}))
42-
@test x[] == sol.stats.nf
43-
44-
x[] = 0
45-
sol = solve(prob, KenCarp4(; autodiff = false, diff_type = Val{:complex}))
46-
@test x[] == sol.stats.nf
47-
48-
x[] = 0
49-
sol = solve(prob, Rosenbrock23(; autodiff = true))
50-
@test x[] == sol.stats.nf
51-
52-
x[] = 0
53-
sol = solve(prob, Rosenbrock23(; autodiff = false, diff_type = Val{:forward}))
54-
@test x[] == sol.stats.nf
55-
56-
x[] = 0
57-
sol = solve(prob, Rosenbrock23(; autodiff = false, diff_type = Val{:central}))
58-
@test x[] == sol.stats.nf
59-
60-
x[] = 0
61-
sol = solve(prob, Rosenbrock23(; autodiff = false, diff_type = Val{:complex}))
62-
@test x[] == sol.stats.nf
63-
64-
x[] = 0
65-
sol = solve(prob, Rodas5(; autodiff = true))
66-
@test x[] == sol.stats.nf
67-
68-
x[] = 0
69-
sol = solve(prob, Rodas5(; autodiff = false, diff_type = Val{:forward}))
70-
@test x[] == sol.stats.nf
71-
72-
x[] = 0
73-
sol = solve(prob, Rodas5(; autodiff = false, diff_type = Val{:central}))
74-
@test x[] == sol.stats.nf
75-
76-
x[] = 0
77-
sol = solve(prob, Rodas5(; autodiff = false, diff_type = Val{:complex}))
78-
@test x[] == sol.stats.nf
79-
808
function g(du, u, p, t)
819
x[] += 1
8210
@. du = 5 * u
8311
end
12+
13+
u0 = [1.0, 1.0]
14+
tspan = (0.0, 1.0)
15+
probop = ODEProblem(f, u0, tspan)
8416
probip = ODEProblem(g, u0, tspan)
8517

86-
x[] = 0
87-
sol = solve(probip, ROCK4())
88-
@test x[] == sol.stats.nf
18+
@testset "stats_tests" begin
19+
@testset "$prob" for prob in [probop, probip]
20+
@testset "$alg" for alg in [BS3, Tsit5, Vern7, Vern9, ROCK4]
21+
x[] = 0
22+
sol = solve(prob, alg())
23+
@test x[] == sol.stats.nf
24+
end
25+
@testset "$alg" for alg in [Rodas5P, KenCarp4]
26+
@testset "$kwargs" for kwargs in [(autodiff = true,),
27+
(autodiff = false, diff_type = Val{:forward}),
28+
(autodiff = false, diff_type = Val{:central}),
29+
(autodiff = false, diff_type = Val{:complex}),]
30+
x[] = 0
31+
sol = solve(prob, alg(;kwargs...))
32+
@test x[] == sol.stats.nf
33+
end
34+
end
35+
end
36+
end

0 commit comments

Comments
 (0)