Skip to content

Commit d4d2d9a

Browse files
Merge pull request #688 from JuliaSymbolics/b/cse-dag
Optimize CSE: Transition to DAG Representation with Hash Consing for Faster Equality Checks
2 parents 4b0d24e + 9c0c27c commit d4d2d9a

File tree

2 files changed

+105
-45
lines changed

2 files changed

+105
-45
lines changed

src/code.jl

Lines changed: 56 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module Code
22

3-
using StaticArrays, SparseArrays, LinearAlgebra, NaNMath, SpecialFunctions
3+
using StaticArrays, SparseArrays, LinearAlgebra, NaNMath, SpecialFunctions,
4+
DocStringExtensions
45

56
export toexpr, Assignment, (←), Let, Func, DestructuredArgs, LiteralExpr,
67
SetArray, MakeArray, MakeSparseArray, MakeTuple, AtIndex,
@@ -696,6 +697,52 @@ end
696697

697698
@inline newsym(::Type{T}) where T = Sym{T}(gensym("cse"))
698699

700+
"""
701+
$(SIGNATURES)
702+
703+
Perform a topological sort on a symbolic expression represented as a Directed Acyclic
704+
Graph (DAG).
705+
706+
This function takes a symbolic expression `graph` (potentially containing shared common
707+
sub-expressions) and returns an array of `Assignment` objects. Each `Assignment`
708+
represents a node in the sorted order, assigning a fresh symbol to its corresponding
709+
expression. The order ensures that all dependencies of a node appear before the node itself
710+
in the array.
711+
712+
Hash consing is assumed, meaning that structurally identical expressions are represented by
713+
the same object in memory. This allows for efficient equality checks using `IdDict`.
714+
"""
715+
function topological_sort(graph)
716+
sorted_nodes = Assignment[]
717+
visited = IdDict()
718+
719+
function dfs(node)
720+
if haskey(visited, node)
721+
return visited[node]
722+
end
723+
if iscall(node)
724+
args = map(dfs, arguments(node))
725+
new_node = maketerm(typeof(node), operation(node), args, metadata(node))
726+
sym = newsym(symtype(new_node))
727+
push!(sorted_nodes, sym ← new_node)
728+
visited[node] = sym
729+
return sym
730+
elseif _is_array_of_symbolics(node)
731+
new_node = map(dfs, node)
732+
sym = newsym(typeof(new_node))
733+
push!(sorted_nodes, sym ← new_node)
734+
visited[node] = sym
735+
return sym
736+
else
737+
visited[node] = node
738+
return node
739+
end
740+
end
741+
742+
dfs(graph)
743+
return sorted_nodes
744+
end
745+
699746
function _cse!(mem, expr)
700747
iscall(expr) || return expr
701748
op = _cse!(mem, operation(expr))
@@ -714,12 +761,16 @@ function _cse!(mem, expr)
714761
end
715762

716763
function cse(expr)
717-
state = Dict{Any, Int}()
718-
cse_state!(state, expr)
719-
cse_block(state, expr)
764+
sorted_nodes = topological_sort(expr)
765+
if isempty(sorted_nodes)
766+
return Let(Assignment[], expr)
767+
else
768+
last_assignment = pop!(sorted_nodes)
769+
body = rhs(last_assignment)
770+
return Let(sorted_nodes, body)
771+
end
720772
end
721773

722-
723774
function _cse(exprs::AbstractArray)
724775
letblock = cse(Term{Any}(tuple, vec(exprs)))
725776
letblock.pairs, reshape(arguments(letblock.body), size(exprs))
@@ -746,41 +797,4 @@ function cse(x::MakeSparseArray)
746797
end
747798
end
748799

749-
750-
function cse_state!(state, t)
751-
!iscall(t) && return t
752-
state[t] = Base.get(state, t, 0) + 1
753-
foreach(x->cse_state!(state, x), arguments(t))
754-
end
755-
756-
function cse_block!(assignments, counter, names, name, state, x)
757-
if get(state, x, 0) > 1
758-
if haskey(names, x)
759-
return names[x]
760-
else
761-
sym = Sym{symtype(x)}(Symbol(name, counter[]))
762-
names[x] = sym
763-
push!(assignments, sym ← x)
764-
counter[] += 1
765-
return sym
766-
end
767-
elseif iscall(x)
768-
args = map(a->cse_block!(assignments, counter, names, name, state,a), arguments(x))
769-
if isterm(x)
770-
return term(operation(x), args...)
771-
else
772-
return maketerm(typeof(x), operation(x), args, metadata(x))
773-
end
774-
else
775-
return x
776-
end
777-
end
778-
779-
function cse_block(state, t, name=Symbol("var-", hash(t)))
780-
assignments = Assignment[]
781-
counter = Ref{Int}(1)
782-
names = Dict{Any, BasicSymbolic}()
783-
Let(assignments, cse_block!(assignments, counter, names, name, state, t))
784-
end
785-
786800
end

test/cse.jl

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,56 @@
11
using SymbolicUtils, SymbolicUtils.Code, Test
2+
using SymbolicUtils.Code: topological_sort
3+
24
@testset "CSE" begin
35
@syms x
46
t = cse(hypot(hypot(cos(x), sin(x)), atan(cos(x), sin(x))))
57

68
@test t isa Let
7-
@test length(t.pairs) == 2
8-
@test occursin(t.pairs[1].lhs, t.body)
9-
@test occursin(t.pairs[2].lhs, t.body)
9+
@test length(t.pairs) == 4
10+
@test occursin(t.pairs[3].lhs, t.body)
11+
@test occursin(t.pairs[4].lhs, t.body)
12+
end
13+
14+
@testset "DAG CSE" begin
15+
@syms a b
16+
expr = sin(a + b) * (a + b)
17+
sorted_nodes = topological_sort(expr)
18+
@test length(sorted_nodes) == 3
19+
@test isequal(sorted_nodes[1].rhs, a + b)
20+
@test isequal(sin(sorted_nodes[1].lhs), sorted_nodes[2].rhs)
21+
22+
expr = (a + b)^(a + b)
23+
sorted_nodes = topological_sort(expr)
24+
@test length(sorted_nodes) == 2
25+
@test isequal(sorted_nodes[1].rhs, a + b)
26+
ab_node = sorted_nodes[1].lhs
27+
@test isequal(ab_node^ab_node, sorted_nodes[2].rhs)
28+
let_expr = cse(expr)
29+
@test length(let_expr.pairs) == 1
30+
@test isequal(let_expr.pairs[1].rhs, a + b)
31+
corresponding_sym = let_expr.pairs[1].lhs
32+
@test isequal(let_expr.body, corresponding_sym^corresponding_sym)
33+
34+
expr = a + b
35+
sorted_nodes = topological_sort(expr)
36+
@test length(sorted_nodes) == 1
37+
@test isequal(sorted_nodes[1].rhs, a + b)
38+
let_expr = cse(expr)
39+
@test isempty(let_expr.pairs)
40+
@test isequal(let_expr.body, a + b)
41+
42+
expr = a
43+
sorted_nodes = topological_sort(expr)
44+
@test isempty(sorted_nodes)
45+
let_expr = cse(expr)
46+
@test isempty(let_expr.pairs)
47+
@test isequal(let_expr.body, a)
48+
49+
# array symbolics
50+
# https://github.com/JuliaSymbolics/SymbolicUtils.jl/pull/688#pullrequestreview-2554931739
51+
@syms c
52+
function foo end
53+
ex = term(foo, [a^2 + b^2, b^2 + c], c; type = Real)
54+
sorted_nodes = topological_sort(ex)
55+
@test length(sorted_nodes) == 6
1056
end

0 commit comments

Comments
 (0)