Skip to content

Commit d39f830

Browse files
refactor: use Symbolics.semilinear_form for LinearProblem codegen
1 parent 3503046 commit d39f830

File tree

1 file changed

+6
-26
lines changed

1 file changed

+6
-26
lines changed

src/systems/codegen.jl

Lines changed: 6 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1142,35 +1142,15 @@ Return matrix `A` and vector `b` such that the system `sys` can be represented a
11421142
- `sparse`: return a sparse `A`.
11431143
"""
11441144
function calculate_A_b(sys::System; sparse = false)
1145-
rhss = [eq.rhs for eq in full_equations(sys)]
1145+
rhss = [-eq.rhs for eq in full_equations(sys)]
11461146
dvs = unknowns(sys)
11471147

1148-
A = Matrix{Any}(undef, length(rhss), length(dvs))
1149-
b = Vector{Any}(undef, length(rhss))
1150-
for (i, rhs) in enumerate(rhss)
1151-
# mtkcompile makes this `0 ~ rhs` which typically ends up giving
1152-
# unknowns negative coefficients. If given the equations `A * x ~ b`
1153-
# it will simplify to `0 ~ b - A * x`. Thus this negation usually leads
1154-
# to more comprehensible user API.
1155-
resid = -rhs
1156-
for (j, var) in enumerate(dvs)
1157-
p, q, islinear = Symbolics.linear_expansion(resid, var)
1158-
if !islinear
1159-
throw(ArgumentError("System is not linear. Equation $((0 ~ rhs)) is not linear in unknown $var."))
1160-
end
1161-
A[i, j] = p
1162-
resid = q
1163-
end
1164-
# negate beucause `resid` is the residual on the LHS
1165-
b[i] = -resid
1166-
end
1167-
1168-
@assert all(Base.Fix1(isassigned, A), eachindex(A))
1169-
@assert all(Base.Fix1(isassigned, A), eachindex(b))
1170-
1171-
if sparse
1172-
A = SparseArrays.sparse(A)
1148+
A, b = semilinear_form(rhss, dvs)
1149+
if !sparse
1150+
A = collect(A)
11731151
end
1152+
A = unwrap.(A)
1153+
b = unwrap.(-b)
11741154
return A, b
11751155
end
11761156

0 commit comments

Comments
 (0)