Skip to content

Commit 425cc0f

Browse files
refactor: use Symbolics.semilinear_form for LinearProblem codegen
1 parent 0828cd9 commit 425cc0f

File tree

1 file changed

+6
-26
lines changed

1 file changed

+6
-26
lines changed

src/systems/codegen.jl

Lines changed: 6 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1144,35 +1144,15 @@ Return matrix `A` and vector `b` such that the system `sys` can be represented a
11441144
- `sparse`: return a sparse `A`.
11451145
"""
11461146
function calculate_A_b(sys::System; sparse = false)
1147-
rhss = [eq.rhs for eq in full_equations(sys)]
1147+
rhss = [-eq.rhs for eq in full_equations(sys)]
11481148
dvs = unknowns(sys)
11491149

1150-
A = Matrix{Any}(undef, length(rhss), length(dvs))
1151-
b = Vector{Any}(undef, length(rhss))
1152-
for (i, rhs) in enumerate(rhss)
1153-
# mtkcompile makes this `0 ~ rhs` which typically ends up giving
1154-
# unknowns negative coefficients. If given the equations `A * x ~ b`
1155-
# it will simplify to `0 ~ b - A * x`. Thus this negation usually leads
1156-
# to more comprehensible user API.
1157-
resid = -rhs
1158-
for (j, var) in enumerate(dvs)
1159-
p, q, islinear = Symbolics.linear_expansion(resid, var)
1160-
if !islinear
1161-
throw(ArgumentError("System is not linear. Equation $((0 ~ rhs)) is not linear in unknown $var."))
1162-
end
1163-
A[i, j] = p
1164-
resid = q
1165-
end
1166-
# negate beucause `resid` is the residual on the LHS
1167-
b[i] = -resid
1168-
end
1169-
1170-
@assert all(Base.Fix1(isassigned, A), eachindex(A))
1171-
@assert all(Base.Fix1(isassigned, A), eachindex(b))
1172-
1173-
if sparse
1174-
A = SparseArrays.sparse(A)
1150+
A, b = semilinear_form(rhss, dvs)
1151+
if !sparse
1152+
A = collect(A)
11751153
end
1154+
A = unwrap.(A)
1155+
b = unwrap.(-b)
11761156
return A, b
11771157
end
11781158

0 commit comments

Comments
 (0)