Skip to content

Commit e2f1a4e

Browse files
Merge pull request #2489 from jClugstor/cache_strip_fix2
Fix cache_strip for default algs
2 parents df11489 + 52256a1 commit e2f1a4e

File tree

3 files changed

+27
-29
lines changed

3 files changed

+27
-29
lines changed

lib/OrdinaryDiffEqCore/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "OrdinaryDiffEqCore"
22
uuid = "bbf590c4-e513-4bbe-9b18-05decba2e5d8"
33
authors = ["ParamThakkar123 <paramthakkar864@gmail.com>"]
4-
version = "1.7.0"
4+
version = "1.7.1"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

lib/OrdinaryDiffEqCore/src/interp_func.jl

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -75,24 +75,14 @@ function SciMLBase.strip_interpolation(id::InterpolationData)
7575
end
7676

7777
function strip_cache(cache)
78-
if hasfield(typeof(cache), :jac_config)
79-
SciMLBase.@reset cache.jac_config = nothing
78+
if !(cache isa OrdinaryDiffEqCore.DefaultCache)
79+
cache = SciMLBase.constructorof(typeof(cache))([nothing
80+
for name in fieldnames(typeof(cache))]...)
81+
else
82+
# need to do something special for default cache
83+
cache = OrdinaryDiffEqCore.DefaultCache{Nothing, Nothing, Nothing, Nothing,
84+
Nothing, Nothing, Nothing, Nothing}(nothing, nothing, 0, nothing)
8085
end
81-
if hasfield(typeof(cache), :grad_config)
82-
SciMLBase.@reset cache.grad_config = nothing
83-
end
84-
if hasfield(typeof(cache), :nlsolver)
85-
SciMLBase.@reset cache.nlsolver = nothing
86-
end
87-
if hasfield(typeof(cache), :tf)
88-
SciMLBase.@reset cache.tf = nothing
89-
end
90-
if hasfield(typeof(cache), :uf)
91-
SciMLBase.@reset cache.uf = nothing
92-
end
93-
if hasfield(typeof(cache),:args)
94-
SciMLBase.@reset cache.args = nothing
95-
end
96-
86+
9787
cache
9888
end

test/interface/ode_strip_test.jl

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,28 +14,36 @@ prob = ODEProblem(lorenz!, u0, tspan)
1414
rosenbrock_sol = solve(prob, Rosenbrock23())
1515
TRBDF_sol = solve(prob, TRBDF2())
1616
vern_sol = solve(prob, Vern7())
17-
17+
default_sol = solve(prob)
1818
@testset "Interpolation Stripping" begin
1919
@test isnothing(SciMLBase.strip_interpolation(rosenbrock_sol.interp).f)
2020
@test isnothing(SciMLBase.strip_interpolation(rosenbrock_sol.interp).cache.jac_config)
2121
@test isnothing(SciMLBase.strip_interpolation(rosenbrock_sol.interp).cache.grad_config)
2222
end
2323

2424
@testset "Rosenbrock Solution Stripping" begin
25-
@test SciMLBase.strip_solution(rosenbrock_sol).prob isa NamedTuple
25+
stripped_sol = SciMLBase.strip_solution(rosenbrock_sol)
26+
@test stripped_sol.prob isa NamedTuple
2627
@test isnothing(SciMLBase.strip_solution(rosenbrock_sol, strip_alg = true).alg)
27-
@test isnothing(SciMLBase.strip_solution(rosenbrock_sol).interp.f)
28-
@test isnothing(SciMLBase.strip_solution(rosenbrock_sol).interp.cache.jac_config)
29-
@test isnothing(SciMLBase.strip_solution(rosenbrock_sol).interp.cache.grad_config)
30-
@test isnothing(SciMLBase.strip_solution(rosenbrock_sol).interp.cache.uf)
31-
@test isnothing(SciMLBase.strip_solution(rosenbrock_sol).interp.cache.tf)
28+
@test isnothing(stripped_sol.interp.f)
29+
@test isnothing(stripped_sol.interp.cache.jac_config)
30+
@test isnothing(stripped_sol.interp.cache.grad_config)
31+
@test isnothing(stripped_sol.interp.cache.uf)
32+
@test isnothing(stripped_sol.interp.cache.tf)
3233
end
3334

3435
@testset "TRBDF Solution Stripping" begin
35-
@test SciMLBase.strip_solution(TRBDF_sol).prob isa NamedTuple
36+
stripped_sol = SciMLBase.strip_solution(TRBDF_sol)
37+
@test stripped_sol.prob isa NamedTuple
3638
@test isnothing(SciMLBase.strip_solution(TRBDF_sol, strip_alg = true).alg)
37-
@test isnothing(SciMLBase.strip_solution(TRBDF_sol).interp.f)
38-
@test isnothing(SciMLBase.strip_solution(TRBDF_sol).interp.cache.nlsolver)
39+
@test isnothing(stripped_sol.interp.f)
40+
@test isnothing(stripped_sol.interp.cache.nlsolver)
41+
end
42+
43+
@testset "Default Solution Stripping" begin
44+
stripped_sol = SciMLBase.strip_solution(default_sol)
45+
@test isnothing(stripped_sol.interp.cache.args)
46+
3947
end
4048

4149
@test_throws SciMLBase.LazyInterpolationException SciMLBase.strip_solution(vern_sol)

0 commit comments

Comments
 (0)