Skip to content

Commit 2438432

Browse files
feat: allow adding constant and nonnumeric parameters to IndexCache after construction
1 parent 16c1b6a commit 2438432

File tree

1 file changed

+62
-0
lines changed

1 file changed

+62
-0
lines changed

src/systems/index_cache.jl

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,16 @@ struct IndexCache
6363
symbol_to_variable::Dict{Symbol, SymbolicParam}
6464
end
6565

66+
function Base.copy(ic::IndexCache)
67+
IndexCache(copy(ic.unknown_idx), copy(ic.discrete_idx), copy(ic.callback_to_clocks),
68+
copy(ic.tunable_idx), copy(ic.initials_idx), copy(ic.constant_idx),
69+
copy(ic.nonnumeric_idx), copy(ic.observed_syms_to_timeseries),
70+
copy(ic.dependent_pars_to_timeseries), copy(ic.discrete_buffer_sizes),
71+
ic.tunable_buffer_size, ic.initials_buffer_size,
72+
copy(ic.constant_buffer_sizes), copy(ic.nonnumeric_buffer_sizes),
73+
copy(ic.symbol_to_variable))
74+
end
75+
6676
function IndexCache(sys::AbstractSystem)
6777
unks = unknowns(sys)
6878
unk_idxs = UnknownIndexMap()
@@ -718,3 +728,55 @@ function subset_unknowns_observed(
718728
@set! ic.observed_syms_to_timeseries = observed_syms_to_timeseries
719729
return ic
720730
end
731+
732+
function with_additional_constant_parameter(sys::AbstractSystem, par)
733+
par = unwrap(par)
734+
ps = copy(get_ps(sys))
735+
push!(ps, par)
736+
@set! sys.ps = ps
737+
is_split(sys) || return sys
738+
739+
ic = copy(get_index_cache(sys))
740+
T = symtype(par)
741+
bufidx = findfirst(buft -> buft.type == T, ic.constant_buffer_sizes)
742+
if bufidx === nothing
743+
push!(ic.constant_buffer_sizes, BufferTemplate(T, 1))
744+
bufidx = length(ic.constant_buffer_sizes)
745+
idx_in_buf = 1
746+
else
747+
buft = ic.constant_buffer_sizes[bufidx]
748+
ic.constant_buffer_sizes[bufidx] = BufferTemplate(T, buft.length + 1)
749+
idx_in_buf = buft.length + 1
750+
end
751+
752+
ic.constant_idx[par] = ic.constant_idx[renamespace(sys, par)] = (bufidx, idx_in_buf)
753+
@set! sys.index_cache = ic
754+
755+
return sys
756+
end
757+
758+
function with_additional_nonnumeric_parameter(sys::AbstractSystem, par)
759+
par = unwrap(par)
760+
ps = copy(get_ps(sys))
761+
push!(ps, par)
762+
@set! sys.ps = ps
763+
is_split(sys) || return sys
764+
765+
ic = copy(get_index_cache(sys))
766+
T = symtype(par)
767+
bufidx = findfirst(buft -> buft.type == T, ic.nonnumeric_buffer_sizes)
768+
if bufidx === nothing
769+
push!(ic.nonnumeric_buffer_sizes, BufferTemplate(T, 1))
770+
bufidx = length(ic.nonnumeric_buffer_sizes)
771+
idx_in_buf = 1
772+
else
773+
buft = ic.nonnumeric_buffer_sizes[bufidx]
774+
ic.nonnumeric_buffer_sizes[bufidx] = BufferTemplate(T, buft.length + 1)
775+
idx_in_buf = buft.length + 1
776+
end
777+
778+
ic.nonnumeric_idx[par] = ic.nonnumeric_idx[renamespace(sys, par)] = (bufidx, idx_in_buf)
779+
@set! sys.index_cache = ic
780+
781+
return sys
782+
end

0 commit comments

Comments
 (0)