Skip to content

Commit 1d05b16

Browse files
Merge pull request #2486 from AayushSabharwal/as/symbolic-save-idxs
feat: support passing symbolic variables to `save_idxs`
2 parents c87f6f2 + 9e45804 commit 1d05b16

File tree

3 files changed

+30
-2
lines changed

3 files changed

+30
-2
lines changed

lib/OrdinaryDiffEqCore/Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
3434
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
3535
StaticArrayInterface = "0d7ed370-da01-4f52-bd93-41d350b8b718"
3636
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
37+
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
3738
TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"
3839

3940
[weakdeps]
@@ -69,13 +70,14 @@ Random = "<0.0.1, 1"
6970
RecursiveArrayTools = "2.36, 3"
7071
Reexport = "1.0"
7172
SafeTestsets = "0.1.0"
72-
SciMLBase = "2.56"
73+
SciMLBase = "2.57.2"
7374
SciMLOperators = "0.3"
7475
SciMLStructures = "1"
7576
SimpleUnPack = "1"
7677
Static = "0.8, 1"
7778
StaticArrayInterface = "1.2"
7879
StaticArraysCore = "1.0"
80+
SymbolicIndexingInterface = "0.3.31"
7981
Test = "<0.0.1, 1"
8082
TruncatedStacktraces = "1.2"
8183
julia = "1.10"

lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ import Accessors: @reset
7676

7777
using SciMLStructures: canonicalize, Tunable, isscimlstructure
7878

79+
using SymbolicIndexingInterface: parameter_values, is_variable, variable_index, symbolic_type, NotSymbolic
80+
7981
const CompiledFloats = Union{Float32, Float64}
8082
import Preferences
8183

lib/OrdinaryDiffEqCore/src/solve.jl

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,30 @@ function DiffEqBase.__init(
264264
end
265265

266266
### Algorithm-specific defaults ###
267+
if save_idxs === nothing
268+
saved_subsystem = nothing
269+
else
270+
if !(save_idxs isa AbstractArray) || symbolic_type(save_idxs) != NotSymbolic()
271+
_save_idxs = [save_idxs]
272+
else
273+
_save_idxs = save_idxs
274+
end
275+
saved_subsystem = SciMLBase.SavedSubsystem(prob, parameter_values(prob), _save_idxs)
276+
if saved_subsystem !== nothing
277+
_save_idxs = SciMLBase.get_saved_state_idxs(saved_subsystem)
278+
if isempty(_save_idxs)
279+
# no states to save
280+
save_idxs = Int[]
281+
elseif !(save_idxs isa AbstractArray) || symbolic_type(save_idxs) != NotSymbolic()
282+
# only a single state to save, and save it as a scalar timeseries instead of
283+
# single-element array
284+
save_idxs = only(_save_idxs)
285+
else
286+
save_idxs = _save_idxs
287+
end
288+
end
289+
end
290+
267291
if save_idxs === nothing
268292
ksEltype = Vector{rateType}
269293
else
@@ -427,7 +451,7 @@ function DiffEqBase.__init(
427451
f, timeseries, ts, ks, alg_choice, dense, cache, differential_vars, false)
428452
sol = DiffEqBase.build_solution(prob, _alg, ts, timeseries,
429453
dense = dense, k = ks, interp = id, alg_choice = alg_choice,
430-
calculate_error = false, stats = stats)
454+
calculate_error = false, stats = stats, saved_subsystem = saved_subsystem)
431455

432456
if recompile_flag == true
433457
FType = typeof(f)

0 commit comments

Comments
 (0)