Skip to content

Commit b403cef

Browse files
Merge pull request #1570 from AayushSabharwal/as/fast-sub-sparse
feat: implement `fast_substitute` and `fixpoint_sub` for `SparseMatrixCSC`
2 parents 8d5fa06 + 7634a49 commit b403cef

File tree

2 files changed

+24
-1
lines changed

2 files changed

+24
-1
lines changed

src/variable.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -589,6 +589,12 @@ function fixpoint_sub(x, dict; operator = Nothing, maxiters = 1000)
589589

590590
return x
591591
end
592+
function fixpoint_sub(x::SparseMatrixCSC, dict; operator = Nothing, maxiters = 1000)
593+
I, J, V = findnz(x)
594+
V = fixpoint_sub(V, dict; operator, maxiters)
595+
m, n = size(x)
596+
return sparse(I, J, V, m, n)
597+
end
592598

593599
const Eq = Union{Equation, Inequality}
594600
"""
@@ -620,6 +626,12 @@ end
620626
function fast_substitute(eqs::AbstractArray, subs::Pair; operator = Nothing)
621627
fast_substitute.(eqs, (subs,); operator)
622628
end
629+
function fast_substitute(eqs::SparseMatrixCSC, subs; operator = Nothing)
630+
I, J, V = findnz(eqs)
631+
V = fast_substitute(V, subs; operator)
632+
m, n = size(eqs)
633+
return sparse(I, J, V, m, n)
634+
end
623635
for (exprType, subsType) in Iterators.product((Num, Symbolics.Arr), (Any, Pair))
624636
@eval function fast_substitute(expr::$exprType, subs::$subsType; operator = Nothing)
625637
fast_substitute(value(expr), subs; operator)

test/utils.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ import Symbolics: symbolic_to_float, var_from_nested_derivative, unwrap,
33
isblock, flatten_expr!, build_expr, get_variables,
44
is_singleton, diff2term, tosymbol, lower_varname,
55
makesubscripts, degree, coeff
6-
6+
using SparseArrays
77
using Test
88

99
@testset "get_variables" begin
@@ -188,6 +188,17 @@ end
188188
@test isequal(Symbolics.fast_substitute(x[1], unwrap(x) => collect(unwrap(x))), x[1])
189189
end
190190

191+
@testset "`fixpoint_sub` and `fast_substitute` on sparse arrays" begin
192+
@variables x y z
193+
mat = Num[x 0 0; 0 y 0; 0 0 z]
194+
mat = sparse(mat)
195+
mat = unwrap.(mat)
196+
rules = Dict(x => y, y => z, z => 1)
197+
res = Symbolics.fixpoint_sub(mat, rules)
198+
@test res isa SparseMatrixCSC
199+
@test res[1, 1] == res[2, 2] == res[3, 3] == 1
200+
end
201+
191202
@testset "numerator and denominator" begin
192203
@variables x y
193204
num_den(x) = (numerator(x), denominator(x))

0 commit comments

Comments
 (0)