Skip to content

Commit f5d3f48

Browse files
Merge pull request #3740 from AayushSabharwal/as/fix-affect-unpre-callable
fix: fix non-numeric parameters in callbacks, `distribute_shift` for called parameters
2 parents be9b130 + 84dbc1c commit f5d3f48

File tree

9 files changed

+59
-11
lines changed

9 files changed

+59
-11
lines changed

src/problems/daeproblem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ end
7272
eval_module, check_compatibility, implicit_dae = true, expression, kwargs...)
7373

7474
kwargs = process_kwargs(sys; expression, callback, eval_expression, eval_module,
75-
kwargs...)
75+
op, kwargs...)
7676

7777
diffvars = collect_differential_variables(sys)
7878
sts = unknowns(sys)

src/problems/ddeproblem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ end
6666
end
6767

6868
kwargs = process_kwargs(
69-
sys; expression, callback, eval_expression, eval_module, kwargs...)
69+
sys; expression, callback, eval_expression, eval_module, op, kwargs...)
7070
args = (; f, u0, h, tspan, p)
7171

7272
return maybe_codegen_scimlproblem(expression, DDEProblem{iip}, args; kwargs...)

src/problems/jumpproblem.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,8 @@
8080
end
8181

8282
# handle events, making sure to reset aggregators in the generated affect functions
83-
cbs = process_events(sys; callback, eval_expression, eval_module, reset_jumps = true)
83+
cbs = process_events(
84+
sys; callback, eval_expression, eval_module, op, reset_jumps = true)
8485

8586
if rng !== nothing
8687
kwargs = (; kwargs..., rng)

src/problems/odeproblem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ end
7575
eval_module, expression, check_compatibility, kwargs...)
7676

7777
kwargs = process_kwargs(
78-
sys; expression, callback, eval_expression, eval_module, kwargs...)
78+
sys; expression, callback, eval_expression, eval_module, op, kwargs...)
7979

8080
ptype = getmetadata(sys, ProblemTypeCtx, StandardODEProblem())
8181
args = (; f, u0, tspan, p, ptype)

src/problems/sddeproblem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ end
6868
end
6969

7070
noise, noise_rate_prototype = calculate_noise_and_rate_prototype(sys, u0; sparsenoise)
71-
kwargs = process_kwargs(sys; callback, eval_expression, eval_module, kwargs...)
71+
kwargs = process_kwargs(sys; callback, eval_expression, eval_module, op, kwargs...)
7272

7373
if expression == Val{true}
7474
g = :(f.g)

src/problems/sdeproblem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ end
7878

7979
noise, noise_rate_prototype = calculate_noise_and_rate_prototype(sys, u0; sparsenoise)
8080
kwargs = process_kwargs(sys; expression, callback, eval_expression, eval_module,
81-
kwargs...)
81+
op, kwargs...)
8282

8383
args = (; f, u0, tspan, p)
8484
kwargs = (; noise, noise_rate_prototype, kwargs...)

src/structural_transformation/utils.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -596,7 +596,8 @@ function _distribute_shift(expr, shift)
596596
(op isa Union{Pre, Initial, Sample, Hold}) && return expr
597597
args = arguments(expr)
598598

599-
if ModelingToolkit.isvariable(expr) && operation(expr) !== getindex
599+
if ModelingToolkit.isvariable(expr) && operation(expr) !== getindex &&
600+
!ModelingToolkit.iscalledparameter(expr)
600601
(length(args) == 1 && isequal(shift.t, only(args))) ? (return shift(expr)) :
601602
(return expr)
602603
elseif op isa Shift

src/systems/callbacks.jl

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,8 @@ Base.show(io::IO, x::Pre) = print(io, "Pre")
6363
input_timedomain(::Pre, _ = nothing) = ContinuousClock()
6464
output_timedomain(::Pre, _ = nothing) = ContinuousClock()
6565
unPre(x::Num) = unPre(unwrap(x))
66-
unPre(x::BasicSymbolic) = (iscall(x) && operation(x) isa Pre) ? only(arguments(x)) : x
66+
unPre(x::Symbolics.Arr) = unPre(unwrap(x))
67+
unPre(x::Symbolic) = (iscall(x) && operation(x) isa Pre) ? only(arguments(x)) : x
6768

6869
function (p::Pre)(x)
6970
iw = Symbolics.iswrapped(x)
@@ -797,16 +798,34 @@ function add_integrator_header(
797798
expr.body)
798799
end
799800

801+
function default_operating_point(affsys::AffectSystem)
802+
sys = system(affsys)
803+
804+
op = Dict(unknowns(sys) .=> 0.0)
805+
for p in parameters(sys)
806+
T = symtype(p)
807+
if T <: Number
808+
op[p] = false
809+
elseif T <: Array{<:Real} && is_sized_array_symbolic(p)
810+
op[p] = zeros(size(p))
811+
end
812+
end
813+
return op
814+
end
815+
800816
"""
801817
Compile an affect defined by a set of equations. Systems with algebraic equations will solve implicit discrete problems to obtain their next state. Systems without will generate functions that perform explicit updates.
802818
"""
803819
function compile_equational_affect(
804820
aff::Union{AffectSystem, Vector{Equation}}, sys; reset_jumps = false,
805-
eval_expression = false, eval_module = @__MODULE__, kwargs...)
821+
eval_expression = false, eval_module = @__MODULE__, op = nothing, kwargs...)
806822
if aff isa AbstractVector
807823
aff = make_affect(
808824
aff; iv = get_iv(sys), warn_no_algebraic = false)
809825
end
826+
if op === nothing
827+
op = default_operating_point(aff)
828+
end
810829
affsys = system(aff)
811830
ps_to_update = discretes(aff)
812831
dvs_to_update = setdiff(unknowns(aff), getfield.(observed(sys), :lhs))
@@ -871,10 +890,10 @@ function compile_equational_affect(
871890
p_getter = getsym(affsys, ps_to_update)
872891

873892
affprob = ImplicitDiscreteProblem(
874-
affsys, Pair[unknowns(affsys) .=> 0; parameters(affsys) .=> 0],
893+
affsys, op,
875894
(0, 0);
876895
build_initializeprob = false, check_length = false, eval_expression,
877-
eval_module, check_compatibility = false)
896+
eval_module, check_compatibility = false, kwargs...)
878897

879898
function implicit_affect!(integ)
880899
new_u0 = affu_getter(integ)

test/symbolic_events.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1348,3 +1348,30 @@ end
13481348
@test SciMLBase.successful_retcode(sol)
13491349
@test sol[inner.p][end] 1.0
13501350
end
1351+
1352+
mutable struct ParamTest
1353+
y::Any
1354+
end
1355+
1356+
@testset "callable parameter and symbolic affect" begin
1357+
(pt::ParamTest)(x) = pt.y - x
1358+
1359+
p1 = ParamTest(1)
1360+
tp1 = typeof(p1)
1361+
@parameters (p_1::tp1)(..) = p1
1362+
@parameters p2(t) = 1.0
1363+
@variables x(t) = 0.0
1364+
@variables x2(t)
1365+
event = [0.5] => [p2 ~ Pre(t)]
1366+
1367+
eq = [
1368+
D(x) ~ p2,
1369+
x2 ~ p_1(x)
1370+
]
1371+
@mtkcompile sys = ODESystem(eq, t, [x, x2], [p_1, p2], discrete_events = [event])
1372+
1373+
prob = ODEProblem(sys, [], (0.0, 1.0))
1374+
sol = solve(prob)
1375+
@test SciMLBase.successful_retcode(sol)
1376+
@test sol[x, end]1.0 atol=1e-6
1377+
end

0 commit comments

Comments
 (0)