Skip to content

Commit 3147749

Browse files
Merge pull request #699 from AayushSabharwal/as/cse-fix
fix: fix `Code.cse` for different symbolic types
2 parents 0a08221 + bdb8832 commit 3147749

File tree

2 files changed

+10
-8
lines changed

2 files changed

+10
-8
lines changed

src/code.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -722,7 +722,9 @@ function topological_sort(graph)
722722
end
723723
if iscall(node)
724724
args = map(dfs, arguments(node))
725-
new_node = maketerm(typeof(node), operation(node), args, metadata(node))
725+
# use `term` instead of `maketerm` because we only care about the operation being performed
726+
# and not the representation. This avoids issues with `newsym` symbols not having sizes, etc.
727+
new_node = term(operation(node), args...)
726728
sym = newsym(symtype(new_node))
727729
push!(sorted_nodes, sym new_node)
728730
visited[node] = sym

test/cse.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,28 +16,28 @@ end
1616
expr = sin(a + b) * (a + b)
1717
sorted_nodes = topological_sort(expr)
1818
@test length(sorted_nodes) == 3
19-
@test isequal(sorted_nodes[1].rhs, a + b)
19+
@test isequal(sorted_nodes[1].rhs, term(+, a, b))
2020
@test isequal(sin(sorted_nodes[1].lhs), sorted_nodes[2].rhs)
2121

2222
expr = (a + b)^(a + b)
2323
sorted_nodes = topological_sort(expr)
2424
@test length(sorted_nodes) == 2
25-
@test isequal(sorted_nodes[1].rhs, a + b)
25+
@test isequal(sorted_nodes[1].rhs, term(+, a, b))
2626
ab_node = sorted_nodes[1].lhs
27-
@test isequal(ab_node^ab_node, sorted_nodes[2].rhs)
27+
@test isequal(term(^, ab_node, ab_node), sorted_nodes[2].rhs)
2828
let_expr = cse(expr)
2929
@test length(let_expr.pairs) == 1
30-
@test isequal(let_expr.pairs[1].rhs, a + b)
30+
@test isequal(let_expr.pairs[1].rhs, term(+, a, b))
3131
corresponding_sym = let_expr.pairs[1].lhs
32-
@test isequal(let_expr.body, corresponding_sym^corresponding_sym)
32+
@test isequal(let_expr.body, term(^, corresponding_sym, corresponding_sym))
3333

3434
expr = a + b
3535
sorted_nodes = topological_sort(expr)
3636
@test length(sorted_nodes) == 1
37-
@test isequal(sorted_nodes[1].rhs, a + b)
37+
@test isequal(sorted_nodes[1].rhs, term(+, a, b))
3838
let_expr = cse(expr)
3939
@test isempty(let_expr.pairs)
40-
@test isequal(let_expr.body, a + b)
40+
@test isequal(let_expr.body, term(+, a, b))
4141

4242
expr = a
4343
sorted_nodes = topological_sort(expr)

0 commit comments

Comments
 (0)