1
1
module Code
2
2
3
- using StaticArrays, SparseArrays, LinearAlgebra, NaNMath, SpecialFunctions
3
+ using StaticArrays, SparseArrays, LinearAlgebra, NaNMath, SpecialFunctions,
4
+ DocStringExtensions
4
5
5
6
export toexpr, Assignment, (← ), Let, Func, DestructuredArgs, LiteralExpr,
6
7
SetArray, MakeArray, MakeSparseArray, MakeTuple, AtIndex,
696
697
697
698
@inline newsym (:: Type{T} ) where T = Sym {T} (gensym (" cse" ))
698
699
700
+ """
701
+ $(SIGNATURES)
702
+
703
+ Perform a topological sort on a symbolic expression represented as a Directed Acyclic
704
+ Graph (DAG).
705
+
706
+ This function takes a symbolic expression `graph` (potentially containing shared common
707
+ sub-expressions) and returns an array of `Assignment` objects. Each `Assignment`
708
+ represents a node in the sorted order, assigning a fresh symbol to its corresponding
709
+ expression. The order ensures that all dependencies of a node appear before the node itself
710
+ in the array.
711
+
712
+ Hash consing is assumed, meaning that structurally identical expressions are represented by
713
+ the same object in memory. This allows for efficient equality checks using `IdDict`.
714
+ """
715
+ function topological_sort (graph)
716
+ sorted_nodes = Assignment[]
717
+ visited = IdDict ()
718
+
719
+ function dfs (node)
720
+ if haskey (visited, node)
721
+ return visited[node]
722
+ end
723
+ if iscall (node)
724
+ args = map (dfs, arguments (node))
725
+ new_node = maketerm (typeof (node), operation (node), args, metadata (node))
726
+ sym = newsym (symtype (new_node))
727
+ push! (sorted_nodes, sym ← new_node)
728
+ visited[node] = sym
729
+ return sym
730
+ elseif _is_array_of_symbolics (node)
731
+ new_node = map (dfs, node)
732
+ sym = newsym (typeof (new_node))
733
+ push! (sorted_nodes, sym ← new_node)
734
+ visited[node] = sym
735
+ return sym
736
+ else
737
+ visited[node] = node
738
+ return node
739
+ end
740
+ end
741
+
742
+ dfs (graph)
743
+ return sorted_nodes
744
+ end
745
+
699
746
function _cse! (mem, expr)
700
747
iscall (expr) || return expr
701
748
op = _cse! (mem, operation (expr))
@@ -714,12 +761,16 @@ function _cse!(mem, expr)
714
761
end
715
762
716
763
function cse (expr)
717
- state = Dict {Any, Int} ()
718
- cse_state! (state, expr)
719
- cse_block (state, expr)
764
+ sorted_nodes = topological_sort (expr)
765
+ if isempty (sorted_nodes)
766
+ return Let (Assignment[], expr)
767
+ else
768
+ last_assignment = pop! (sorted_nodes)
769
+ body = rhs (last_assignment)
770
+ return Let (sorted_nodes, body)
771
+ end
720
772
end
721
773
722
-
723
774
function _cse (exprs:: AbstractArray )
724
775
letblock = cse (Term {Any} (tuple, vec (exprs)))
725
776
letblock. pairs, reshape (arguments (letblock. body), size (exprs))
@@ -746,41 +797,4 @@ function cse(x::MakeSparseArray)
746
797
end
747
798
end
748
799
749
-
750
- function cse_state! (state, t)
751
- ! iscall (t) && return t
752
- state[t] = Base. get (state, t, 0 ) + 1
753
- foreach (x-> cse_state! (state, x), arguments (t))
754
- end
755
-
756
- function cse_block! (assignments, counter, names, name, state, x)
757
- if get (state, x, 0 ) > 1
758
- if haskey (names, x)
759
- return names[x]
760
- else
761
- sym = Sym {symtype(x)} (Symbol (name, counter[]))
762
- names[x] = sym
763
- push! (assignments, sym ← x)
764
- counter[] += 1
765
- return sym
766
- end
767
- elseif iscall (x)
768
- args = map (a-> cse_block! (assignments, counter, names, name, state,a), arguments (x))
769
- if isterm (x)
770
- return term (operation (x), args... )
771
- else
772
- return maketerm (typeof (x), operation (x), args, metadata (x))
773
- end
774
- else
775
- return x
776
- end
777
- end
778
-
779
- function cse_block (state, t, name= Symbol (" var-" , hash (t)))
780
- assignments = Assignment[]
781
- counter = Ref {Int} (1 )
782
- names = Dict {Any, BasicSymbolic} ()
783
- Let (assignments, cse_block! (assignments, counter, names, name, state, t))
784
- end
785
-
786
800
end
0 commit comments