Skip to content

Commit ad58d55

Browse files
feat: add LinearFunction as a late-binding for creating LinearProblem
1 parent 3affea1 commit ad58d55

File tree

1 file changed

+65
-33
lines changed

1 file changed

+65
-33
lines changed

src/problems/linearproblem.jl

Lines changed: 65 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,42 @@
1+
struct LinearFunction{iip, I} <: SciMLBase.AbstractSciMLFunction{iip}
2+
interface::I
3+
A::AbstractMatrix
4+
b::AbstractVector
5+
end
6+
7+
function LinearFunction{iip}(
8+
sys::System; expression = Val{false}, check_compatibility = true,
9+
sparse = false, eval_expression = false, eval_module = @__MODULE__,
10+
checkbounds = false, cse = true, kwargs...) where {iip}
11+
check_complete(sys, LinearProblem)
12+
check_compatibility && check_compatible_system(LinearProblem, sys)
13+
14+
A, b = calculate_A_b(sys; sparse)
15+
update_A = generate_update_A(sys, A; expression, wrap_gfw = Val{true}, eval_expression,
16+
eval_module, checkbounds, cse, kwargs...)
17+
update_b = generate_update_b(sys, b; expression, wrap_gfw = Val{true}, eval_expression,
18+
eval_module, checkbounds, cse, kwargs...)
19+
observedfun = ObservedFunctionCache(
20+
sys; steady_state = false, expression, eval_expression, eval_module, checkbounds,
21+
cse)
22+
23+
if expression == Val{true}
24+
symbolic_interface = quote
25+
update_A = $update_A
26+
update_b = $update_b
27+
sys = $sys
28+
observedfun = $observedfun
29+
$(SciMLBase.SymbolicLinearInterface)(
30+
update_A, update_b, sys, observedfun, nothing)
31+
end
32+
else
33+
symbolic_interface = SciMLBase.SymbolicLinearInterface(
34+
update_A, update_b, sys, observedfun, nothing)
35+
end
36+
37+
return LinearFunction{iip, typeof(symbolic_interface)}(symbolic_interface, A, b)
38+
end
39+
140
function SciMLBase.LinearProblem(sys::System, op; kwargs...)
241
SciMLBase.LinearProblem{true}(sys, op; kwargs...)
342
end
@@ -9,14 +48,14 @@ end
948
function SciMLBase.LinearProblem{iip}(
1049
sys::System, op; check_length = true, expression = Val{false},
1150
check_compatibility = true, sparse = false, eval_expression = false,
12-
eval_module = @__MODULE__, checkbounds = false, cse = true,
13-
u0_constructor = identity, u0_eltype = nothing, kwargs...) where {iip}
51+
eval_module = @__MODULE__, u0_constructor = identity, u0_eltype = nothing,
52+
kwargs...) where {iip}
1453
check_complete(sys, LinearProblem)
1554
check_compatibility && check_compatible_system(LinearProblem, sys)
1655

17-
_, u0,
56+
f, u0,
1857
p = process_SciMLProblem(
19-
EmptySciMLFunction{iip}, sys, op; check_length, expression,
58+
LinearFunction{iip}, sys, op; check_length, expression,
2059
build_initializeprob = false, symbolic_u0 = true, u0_constructor, u0_eltype,
2160
kwargs...)
2261

@@ -33,45 +72,38 @@ function SciMLBase.LinearProblem{iip}(
3372
u0_eltype = something(u0_eltype, floatT)
3473

3574
u0_constructor = get_p_constructor(u0_constructor, u0Type, u0_eltype)
75+
symbolic_interface = f.interface
76+
A,
77+
b = get_A_b_from_LinearFunction(
78+
sys, f, p; eval_expression, eval_module, expression, u0_constructor, sparse)
3679

37-
A, b = calculate_A_b(sys; sparse)
38-
update_A = generate_update_A(sys, A; expression, wrap_gfw = Val{true}, eval_expression,
39-
eval_module, checkbounds, cse, kwargs...)
40-
update_b = generate_update_b(sys, b; expression, wrap_gfw = Val{true}, eval_expression,
41-
eval_module, checkbounds, cse, kwargs...)
42-
observedfun = ObservedFunctionCache(
43-
sys; steady_state = false, expression, eval_expression, eval_module, checkbounds,
44-
cse)
80+
kwargs = (; u0, process_kwargs(sys; kwargs...)..., f = symbolic_interface)
81+
args = (; A, b, p)
4582

83+
return maybe_codegen_scimlproblem(expression, LinearProblem{iip}, args; kwargs...)
84+
end
85+
86+
function get_A_b_from_LinearFunction(
87+
sys::System, f::LinearFunction, p; eval_expression = false,
88+
eval_module = @__MODULE__, expression = Val{false}, u0_constructor = identity,
89+
u0_eltype = float, sparse = false)
90+
@unpack A, b, interface = f
4691
if expression == Val{true}
47-
symbolic_interface = quote
48-
update_A = $update_A
49-
update_b = $update_b
50-
sys = $sys
51-
observedfun = $observedfun
52-
$(SciMLBase.SymbolicLinearInterface)(
53-
update_A, update_b, sys, observedfun, nothing)
54-
end
5592
get_A = build_explicit_observed_function(
5693
sys, A; param_only = true, eval_expression, eval_module)
57-
if sparse
58-
get_A = SparseArrays.sparse get_A
59-
end
6094
get_b = build_explicit_observed_function(
6195
sys, b; param_only = true, eval_expression, eval_module)
62-
A = u0_constructor(get_A(p))
63-
b = u0_constructor(get_b(p))
96+
A = u0_constructor(u0_eltype.(get_A(p)))
97+
b = u0_constructor(u0_eltype.(get_b(p)))
6498
else
65-
symbolic_interface = SciMLBase.SymbolicLinearInterface(
66-
update_A, update_b, sys, observedfun, nothing)
67-
A = u0_constructor(update_A(p))
68-
b = u0_constructor(update_b(p))
99+
A = u0_constructor(u0_eltype.(interface.update_A!(p)))
100+
b = u0_constructor(u0_eltype.(interface.update_b!(p)))
101+
end
102+
if sparse
103+
A = SparseArrays.sparse(A)
69104
end
70105

71-
kwargs = (; u0, process_kwargs(sys; kwargs...)..., f = symbolic_interface)
72-
args = (; A, b, p)
73-
74-
return maybe_codegen_scimlproblem(expression, LinearProblem{iip}, args; kwargs...)
106+
return A, b
75107
end
76108

77109
# For remake

0 commit comments

Comments
 (0)