Skip to content

Commit 081e4ae

Browse files
committed
Topological sort directed acyclic graph
1 parent 982e6b1 commit 081e4ae

File tree

2 files changed

+35
-0
lines changed

2 files changed

+35
-0
lines changed

src/code.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -696,6 +696,31 @@ end
696696

697697
@inline newsym(::Type{T}) where T = Sym{T}(gensym("cse"))
698698

699+
function topological_sort(graph)
700+
sorted_nodes = Assignment[]
701+
visited = IdDict()
702+
703+
function dfs(node)
704+
if haskey(visited, node)
705+
return visited[node]
706+
end
707+
if iscall(node)
708+
args = map(dfs, arguments(node))
709+
new_node = maketerm(typeof(node), operation(node), args, metadata(node))
710+
sym = newsym(symtype(new_node))
711+
push!(sorted_nodes, sym new_node)
712+
visited[node] = sym
713+
return sym
714+
else
715+
visited[node] = node
716+
return node
717+
end
718+
end
719+
720+
dfs(graph)
721+
return sorted_nodes
722+
end
723+
699724
function _cse!(mem, expr)
700725
iscall(expr) || return expr
701726
op = _cse!(mem, operation(expr))

test/cse.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,13 @@ using SymbolicUtils, SymbolicUtils.Code, Test
88
@test occursin(t.pairs[1].lhs, t.body)
99
@test occursin(t.pairs[2].lhs, t.body)
1010
end
11+
12+
@testset "DAG CSE" begin
13+
@syms a b
14+
expr = sin(a + b) * (a + b)
15+
sorted_nodes = topological_sort(expr)
16+
@test length(sorted_nodes) == 3
17+
expr = (a + b)^(a + b)
18+
sorted_nodes = topological_sort(expr)
19+
@test length(sorted_nodes) == 2
20+
end

0 commit comments

Comments
 (0)