Skip to content

Commit 5273028

Browse files
feat: add LinearFunction as a late-binding for creating LinearProblem
1 parent af1a950 commit 5273028

File tree

1 file changed

+57
-26
lines changed

1 file changed

+57
-26
lines changed

src/problems/linearproblem.jl

Lines changed: 57 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,42 @@
1+
struct LinearFunction{iip, I} <: SciMLBase.AbstractSciMLFunction{iip}
2+
interface::I
3+
A::AbstractMatrix
4+
b::AbstractVector
5+
end
6+
7+
function LinearFunction{iip}(
8+
sys::System; expression = Val{false}, check_compatibility = true,
9+
sparse = false, eval_expression = false, eval_module = @__MODULE__,
10+
checkbounds = false, cse = true, kwargs...) where {iip}
11+
check_complete(sys, LinearProblem)
12+
check_compatibility && check_compatible_system(LinearProblem, sys)
13+
14+
A, b = calculate_A_b(sys; sparse)
15+
update_A = generate_update_A(sys, A; expression, wrap_gfw = Val{true}, eval_expression,
16+
eval_module, checkbounds, cse, kwargs...)
17+
update_b = generate_update_b(sys, b; expression, wrap_gfw = Val{true}, eval_expression,
18+
eval_module, checkbounds, cse, kwargs...)
19+
observedfun = ObservedFunctionCache(
20+
sys; steady_state = false, expression, eval_expression, eval_module, checkbounds,
21+
cse)
22+
23+
if expression == Val{true}
24+
symbolic_interface = quote
25+
update_A = $update_A
26+
update_b = $update_b
27+
sys = $sys
28+
observedfun = $observedfun
29+
$(SciMLBase.SymbolicLinearInterface)(
30+
update_A, update_b, sys, observedfun, nothing)
31+
end
32+
else
33+
symbolic_interface = SciMLBase.SymbolicLinearInterface(
34+
update_A, update_b, sys, observedfun, nothing)
35+
end
36+
37+
return LinearFunction{iip, typeof(symbolic_interface)}(symbolic_interface, A, b)
38+
end
39+
140
function SciMLBase.LinearProblem(sys::System, op; kwargs...)
241
SciMLBase.LinearProblem{true}(sys, op; kwargs...)
342
end
@@ -14,9 +53,9 @@ function SciMLBase.LinearProblem{iip}(
1453
check_complete(sys, LinearProblem)
1554
check_compatibility && check_compatible_system(LinearProblem, sys)
1655

17-
_, u0,
56+
f, u0,
1857
p = process_SciMLProblem(
19-
EmptySciMLFunction{iip}, sys, op; check_length, expression,
58+
LinearFunction{iip}, sys, op; check_length, expression,
2059
build_initializeprob = false, symbolic_u0 = true, u0_constructor, u0_eltype,
2160
kwargs...)
2261

@@ -33,25 +72,22 @@ function SciMLBase.LinearProblem{iip}(
3372
u0_eltype = something(u0_eltype, floatT)
3473

3574
u0_constructor = get_p_constructor(u0_constructor, u0Type, u0_eltype)
75+
symbolic_interface = f.interface
76+
A,
77+
b = get_A_b_from_LinearFunction(
78+
sys, f, p; eval_expression, eval_module, expression, u0_constructor)
3679

37-
A, b = calculate_A_b(sys; sparse)
38-
update_A = generate_update_A(sys, A; expression, wrap_gfw = Val{true}, eval_expression,
39-
eval_module, checkbounds, cse, kwargs...)
40-
update_b = generate_update_b(sys, b; expression, wrap_gfw = Val{true}, eval_expression,
41-
eval_module, checkbounds, cse, kwargs...)
42-
observedfun = ObservedFunctionCache(
43-
sys; steady_state = false, expression, eval_expression, eval_module, checkbounds,
44-
cse)
80+
kwargs = (; u0, process_kwargs(sys; kwargs...)..., f = symbolic_interface)
81+
args = (; A, b, p)
4582

83+
return maybe_codegen_scimlproblem(expression, LinearProblem{iip}, args; kwargs...)
84+
end
85+
86+
function get_A_b_from_LinearFunction(
87+
sys::System, f::LinearFunction, p; eval_expression = false,
88+
eval_module = @__MODULE__, expression = Val{false}, u0_constructor = identity)
89+
@unpack A, b, interface = f
4690
if expression == Val{true}
47-
symbolic_interface = quote
48-
update_A = $update_A
49-
update_b = $update_b
50-
sys = $sys
51-
observedfun = $observedfun
52-
$(SciMLBase.SymbolicLinearInterface)(
53-
update_A, update_b, sys, observedfun, nothing)
54-
end
5591
get_A = build_explicit_observed_function(
5692
sys, A; param_only = true, eval_expression, eval_module)
5793
if sparse
@@ -62,16 +98,11 @@ function SciMLBase.LinearProblem{iip}(
6298
A = u0_constructor(get_A(p))
6399
b = u0_constructor(get_b(p))
64100
else
65-
symbolic_interface = SciMLBase.SymbolicLinearInterface(
66-
update_A, update_b, sys, observedfun, nothing)
67-
A = u0_constructor(update_A(p))
68-
b = u0_constructor(update_b(p))
101+
A = u0_constructor(interface.update_A!(p))
102+
b = u0_constructor(interface.update_b!(p))
69103
end
70104

71-
kwargs = (; u0, process_kwargs(sys; kwargs...)..., f = symbolic_interface)
72-
args = (; A, b, p)
73-
74-
return maybe_codegen_scimlproblem(expression, LinearProblem{iip}, args; kwargs...)
105+
return A, b
75106
end
76107

77108
# For remake

0 commit comments

Comments
 (0)