Skip to content

Commit 6a0d6c3

Browse files
Merge pull request #2470 from SciML/fsal_reset_mutable
Check mutable in fsal reset without requiring allocated caches
2 parents a50f363 + f834aa6 commit 6a0d6c3

File tree

5 files changed

+39
-13
lines changed

5 files changed

+39
-13
lines changed

.github/workflows/CI.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,9 @@ jobs:
8282
Pkg.develop(map(path ->Pkg.PackageSpec.(;path="$(@__DIR__)/lib/$(path)"), readdir("./lib")));
8383
'
8484
- uses: julia-actions/julia-runtest@v1
85+
with:
86+
coverage: false
87+
check_bounds: auto
8588
env:
8689
GROUP: ${{ matrix.group }}
8790
- uses: julia-actions/julia-processcoverage@v1

lib/OrdinaryDiffEqCore/src/caches/basic_caches.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@ abstract type OrdinaryDiffEqMutableCache <: OrdinaryDiffEqCache end
44
struct ODEEmptyCache <: OrdinaryDiffEqConstantCache end
55
struct ODEChunkCache{CS} <: OrdinaryDiffEqConstantCache end
66

7+
ismutablecache(cache::OrdinaryDiffEqMutableCache) = true
8+
ismutablecache(cache::OrdinaryDiffEqConstantCache) = false
9+
710
# Don't worry about the potential alloc on a constant cache
811
get_fsalfirstlast(cache::OrdinaryDiffEqConstantCache, u) = (zero(u), zero(u))
912

@@ -13,6 +16,10 @@ mutable struct CompositeCache{T, F} <: OrdinaryDiffEqCache
1316
current::Int
1417
end
1518

19+
function ismutablecache(cache::CompositeCache{T, F}) where {T, F}
20+
eltype(T) <: OrdinaryDiffEqMutableCache
21+
end
22+
1623
function get_fsalfirstlast(cache::CompositeCache, u)
1724
_x = get_fsalfirstlast(cache.caches[1], u)
1825
if first(_x) !== nothing
@@ -44,6 +51,10 @@ function get_fsalfirstlast(cache::DefaultCache, u)
4451
(cache.u, cache.u) # will be overwritten by the cache choice
4552
end
4653

54+
function ismutablecache(cache::DefaultCache{T1, T2, T3, T4, T5, T6, A, F, uType}) where {T1, T2, T3, T4, T5, T6, A, F, uType}
55+
T1 <: OrdinaryDiffEqMutableCache
56+
end
57+
4758
function alg_cache(alg::CompositeAlgorithm, u, rate_prototype, ::Type{uEltypeNoUnits},
4859
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
4960
dt, reltol, p, calck,

lib/OrdinaryDiffEqCore/src/integrators/integrator_utils.jl

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -482,11 +482,7 @@ function reset_fsal!(integrator)
482482
# Ignore DAEs but they already re-ran initialization
483483
# Mass matrix DAEs do need to reset FSAL if available
484484
if !(integrator.sol.prob isa DAEProblem)
485-
if integrator.cache isa OrdinaryDiffEqMutableCache ||
486-
(integrator.cache isa CompositeCache &&
487-
integrator.cache.caches[1] isa OrdinaryDiffEqMutableCache) ||
488-
(integrator.cache isa DefaultCache &&
489-
integrator.cache.cache1 isa OrdinaryDiffEqMutableCache)
485+
if ismutablecache(integrator.cache)
490486
integrator.f(integrator.fsalfirst, integrator.u, integrator.p, integrator.t)
491487
else
492488
integrator.fsalfirst = integrator.f(integrator.u, integrator.p, integrator.t)

test/integrators/callback_allocation_tests.jl

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,18 +32,17 @@ cbs = CallbackSet(ContinuousCallback(cond_1, cb_affect!),
3232
ContinuousCallback(cond_9, cb_affect!))
3333

3434
integrator = init(
35-
ODEProblem(f!, [0.8, 1.0], (0.0, 100.0), [0, 0]), Tsit5(), callback = cbs,
35+
ODEProblem{true, SciMLBase.FullSpecialize}(f!, [0.8, 1.0],
36+
(0.0, 100.0), [0, 0]), Tsit5(), callback = cbs,
3637
save_on = false);
3738
# Force a callback event to occur so we can call handle_callbacks! directly.
3839
# Step to a point where u[1] is still > 0.5, so we can force it below 0.5 and
3940
# call handle callbacks
4041
step!(integrator, 0.1, true)
4142

42-
if VERSION >= v"1.7"
43-
function handle_allocs(integrator)
44-
integrator.u[1] = 0.4
45-
@allocations OrdinaryDiffEqCore.handle_callbacks!(integrator)
46-
end
47-
handle_allocs(integrator)
48-
@test handle_allocs(integrator) == 0
43+
function handle_allocs(integrator)
44+
integrator.u[1] = 0.4
45+
@allocations OrdinaryDiffEqCore.handle_callbacks!(integrator)
4946
end
47+
handle_allocs(integrator)
48+
@test_skip handle_allocs(integrator) == 0

test/interface/composite_algorithm_test.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,3 +80,20 @@ sol = solve(prob,
8080
prob = remake(prob_ode_2Dlinear, u0 = rand(ComplexF64, 2, 2))
8181
sol = solve(prob, AutoTsit5(Rosenbrock23(autodiff = false))) # Complex and AD don't mix
8282
@test sol.retcode == ReturnCode.Success
83+
84+
# https://github.com/SciML/ModelingToolkit.jl/issues/3043
85+
function rober(du, u, p, t)
86+
y₁, y₂, y₃ = u
87+
k₁, k₂, k₃ = p
88+
du[1] = -k₁ * y₁ + k₃ * y₂ * y₃
89+
du[2] = k₁ * y₁ - k₃ * y₂ * y₃ - k₂ * y₂^2
90+
du[3] = y₁ + y₂ + y₃ - 1
91+
nothing
92+
end
93+
M = [1.0 0 0
94+
0 1.0 0
95+
0 0 0]
96+
f = ODEFunction(rober, mass_matrix = M)
97+
prob_mm = ODEProblem(f, [1.0, 0.0, 0.0], (0.0, 1e5), (0.04, 3e7, 1e4))
98+
cb = DiscreteCallback((u,t,integrator)->true, (integrator)->u_modified!(integrator,true))
99+
sol = solve(prob_mm, DefaultODEAlgorithm(), callback = cb)

0 commit comments

Comments
 (0)