Skip to content

Commit 329595e

Browse files
author
a
committed
use UInt as id
1 parent 0bd3021 commit 329595e

File tree

8 files changed

+94
-85
lines changed

8 files changed

+94
-85
lines changed

src/EGraphs/egraph.jl

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,10 @@ Given an ENode `n`, `make` should return the corresponding analysis value.
2929
"""
3030
function make end
3131

32-
const EClassId = Int64
33-
const TermTypes = Dict{Tuple{Any,Int},Type}
32+
const EClassId = UInt64
3433
# TODO document bindings
35-
const Bindings = Base.ImmutableDict{Int,Tuple{Int,Int}}
36-
const UNDEF_ARGS = Vector{EClassId}(undef, 0)
34+
const Bindings = Base.ImmutableDict{Int,Tuple{EClassId,Int}}
35+
const UNDEF_ID_VEC = Vector{EClassId}(undef, 0)
3736

3837
# @compactify begin
3938
struct ENode
@@ -44,15 +43,15 @@ struct ENode
4443
args::Vector{EClassId}
4544
hash::Ref{UInt}
4645
ENode(head, operation, args) = new(true, head, operation, args, Ref{UInt}(0))
47-
ENode(literal) = new(false, nothing, literal, UNDEF_ARGS, Ref{UInt}(0))
46+
ENode(literal) = new(false, nothing, literal, UNDEF_ID_VEC, Ref{UInt}(0))
4847
end
4948

5049
TermInterface.istree(n::ENode) = n.istree
5150
TermInterface.head(n::ENode) = n.head
5251
TermInterface.operation(n::ENode) = n.operation
5352
TermInterface.arguments(n::ENode) = n.args
5453
TermInterface.children(n::ENode) = [n.operation; n.args...]
55-
TermInterface.arity(n::ENode) = length(n.args)
54+
TermInterface.arity(n::ENode)::Int = length(n.args)
5655

5756

5857
# This optimization comes from SymbolicUtils
@@ -78,7 +77,7 @@ end
7877

7978
Base.show(io::IO, x::ENode) = print(io, to_expr(x))
8079

81-
function op_key(n)
80+
function op_key(n)::Pair{Any,Int}
8281
op = operation(n)
8382
(op isa Union{Function,DataType} ? nameof(op) : op) => (istree(n) ? arity(n) : -1)
8483
end
@@ -155,7 +154,7 @@ mutable struct EGraph{Head,Analysis}
155154
"Buffer for e-matching which defaults to a global. Use a local buffer for generated functions."
156155
buffer::Vector{Bindings}
157156
"Buffer for rule application which defaults to a global. Use a local buffer for generated functions."
158-
merges_buffer::Vector{Tuple{Int,Int}}
157+
merges_buffer::Vector{EClassId}
159158
lock::ReentrantLock
160159
end
161160

@@ -167,16 +166,16 @@ Construct an EGraph from a starting symbolic expression `expr`.
167166
function EGraph{Head,Analysis}(; needslock::Bool = false) where {Head,Analysis}
168167
EGraph{Head,Analysis}(
169168
UnionFind(),
170-
Dict{EClassId,EClass}(),
169+
Dict{EClassId,EClass{Analysis}}(),
171170
Dict{ENode,EClassId}(),
172171
Pair{ENode,EClassId}[],
173172
UniqueQueue{Pair{ENode,EClassId}}(),
174-
-1,
173+
0,
175174
Dict{Pair{Any,Int},Vector{EClassId}}(),
176175
false,
177176
needslock,
178177
Bindings[],
179-
Tuple{Int,Int}[],
178+
EClassId[],
180179
ReentrantLock(),
181180
)
182181
end
@@ -232,7 +231,7 @@ end
232231

233232
function lookup(g::EGraph, n::ENode)::EClassId
234233
cc = canonicalize(g, n)
235-
haskey(g.memo, cc) ? find(g, g.memo[cc]) : -1
234+
haskey(g.memo, cc) ? find(g, g.memo[cc]) : 0
236235
end
237236

238237

@@ -288,26 +287,22 @@ Recursively traverse an type satisfying the `TermInterface` and insert terms int
288287
[`EGraph`](@ref). If `e` has no children (has an arity of 0) then directly
289288
insert the literal into the [`EGraph`](@ref).
290289
"""
291-
function addexpr!(g::EGraph, se, keepmeta = false)::EClassId
290+
function addexpr!(g::EGraph, se)::EClassId
292291
se isa EClass && return se.id
293292
e = preprocess(se)
294293

295294
n = if istree(se)
296295
args = arguments(e)
297-
ar = length(args)
296+
ar = arity(e)
298297
class_ids = Vector{EClassId}(undef, ar)
299298
for i in 1:ar
300-
@inbounds class_ids[i] = addexpr!(g, args[i], keepmeta)
299+
@inbounds class_ids[i] = addexpr!(g, args[i])
301300
end
302301
ENode(head(e), operation(e), class_ids)
303302
else # constant enode
304303
ENode(e)
305304
end
306305
id = add!(g, n)
307-
if keepmeta
308-
meta = TermInterface.metadata(e)
309-
!isnothing(meta) && setdata!(g.classes[id], :metadata_analysis, meta)
310-
end
311306
return id
312307
end
313308

@@ -512,15 +507,18 @@ function lookup_pat(g::EGraph{Head}, p::PatTerm)::EClassId where {Head}
512507

513508
eh = Head(head_symbol(head(p)))
514509

515-
ids = map(x -> lookup_pat(g, x), args)
516-
!all((>)(0), ids) && return -1
510+
ids = Vector{EClassId}(undef, ar)
511+
for i in 1:ar
512+
@inbounds ids[i] = lookup_pat(g, args[i])
513+
ids[i] <= 0 && return 0
514+
end
517515

518516
if Head == ExprHead && op isa Union{Function,DataType}
519517
id = lookup(g, ENode(eh, op, ids))
520-
id < 0 ? lookup(g, ENode(eh, nameof(op), ids)) : id
518+
id <= 0 ? lookup(g, ENode(eh, nameof(op), ids)) : id
521519
else
522520
lookup(g, ENode(eh, op, ids))
523521
end
524522
end
525523

526-
lookup_pat(g::EGraph, p::Any) = lookup(g, ENode(p))
524+
lookup_pat(g::EGraph, p::Any)::EClassId = lookup(g, ENode(p))

src/EGraphs/saturation.jl

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,12 @@ Base.@kwdef mutable struct SaturationParams
3838
timer::Bool = true
3939
end
4040

41-
function cached_ids(g::EGraph, p::PatTerm)# ::Vector{Int64}
41+
function cached_ids(g::EGraph, p::PatTerm)::Vector{EClassId}
4242
if isground(p)
4343
id = lookup_pat(g, p)
4444
!isnothing(id) && return [id]
4545
else
46-
get(g.classes_by_op, op_key(p), ())
46+
get(g.classes_by_op, op_key(p), UNDEF_ID_VEC)
4747
end
4848
end
4949

@@ -115,13 +115,15 @@ function instantiate_enode!(bindings::Bindings, g::EGraph{Head}, p::PatTerm)::EC
115115
end
116116

117117
function apply_rule!(buf, g::EGraph, rule::RewriteRule, id, direction)
118-
push!(g.merges_buffer, (id, instantiate_enode!(buf, g, rule.right)))
118+
push!(g.merges_buffer, id)
119+
push!(g.merges_buffer, instantiate_enode!(buf, g, rule.right))
119120
nothing
120121
end
121122

122123
function apply_rule!(bindings::Bindings, g::EGraph, rule::EqualityRule, id::EClassId, direction::Int)
123124
pat_to_inst = direction == 1 ? rule.right : rule.left
124-
push!(g.merges_buffer, (id, instantiate_enode!(bindings, g, pat_to_inst)))
125+
push!(g.merges_buffer, id)
126+
push!(g.merges_buffer, instantiate_enode!(bindings, g, pat_to_inst))
125127
nothing
126128
end
127129

@@ -156,7 +158,8 @@ function apply_rule!(bindings::Bindings, g::EGraph, rule::DynamicRule, id::EClas
156158
r = f(id, g, (instantiate_actual_param!(bindings, g, i) for i in 1:length(rule.patvars))...)
157159
isnothing(r) && return nothing
158160
rcid = addexpr!(g, r)
159-
push!(g.merges_buffer, (id, rcid))
161+
push!(g.merges_buffer, id)
162+
push!(g.merges_buffer, rcid)
160163
return nothing
161164
end
162165

@@ -177,7 +180,7 @@ function eqsat_apply!(g::EGraph, theory::Vector{<:AbstractRule}, rep::Saturation
177180
end
178181

179182
bindings = pop!(g.buffer)
180-
rule_idx, id = bindings[0]
183+
id, rule_idx = bindings[0]
181184
direction = sign(rule_idx)
182185
rule_idx = abs(rule_idx)
183186
rule = theory[rule_idx]
@@ -198,7 +201,8 @@ function eqsat_apply!(g::EGraph, theory::Vector{<:AbstractRule}, rep::Saturation
198201
end
199202
maybelock!(g) do
200203
while !isempty(g.merges_buffer)
201-
(l, r) = pop!(g.merges_buffer)
204+
l = pop!(g.merges_buffer)
205+
r = pop!(g.merges_buffer)
202206
union!(g, l, r)
203207
end
204208
end

src/EGraphs/unionfind.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,23 @@
11
struct UnionFind
2-
parents::Vector{Int}
2+
parents::Vector{UInt}
33
end
44

5-
UnionFind() = UnionFind(Int[])
5+
UnionFind() = UnionFind(UInt[])
66

7-
function Base.push!(uf::UnionFind)
7+
function Base.push!(uf::UnionFind)::UInt
88
l = length(uf.parents) + 1
99
push!(uf.parents, l)
1010
l
1111
end
1212

1313
Base.length(uf::UnionFind) = length(uf.parents)
1414

15-
function Base.union!(uf::UnionFind, i::Int, j::Int)
15+
function Base.union!(uf::UnionFind, i::UInt, j::UInt)
1616
uf.parents[j] = i
1717
i
1818
end
1919

20-
function find(uf::UnionFind, i::Int)
20+
function find(uf::UnionFind, i::UInt)
2121
while i != uf.parents[i]
2222
i = uf.parents[i]
2323
end

src/Patterns.jl

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ abstract type AbstractPat end
1313
struct PatHead
1414
head
1515
end
16-
TermInterface.head_symbol(p::PatHead) = p.head
16+
TermInterface.head_symbol(p::PatHead)::Symbol = p.head
1717

1818
PatHead(p::PatHead) = error("recursive!")
1919

@@ -83,34 +83,38 @@ symbol `operation` and expression head `head.head`.
8383
struct PatTerm <: AbstractPat
8484
head::PatHead
8585
children::Vector
86-
PatTerm(h, t::Vector) = new(h, t)
86+
isground::Bool
87+
PatTerm(h, t::Vector) = new(h, t, all(isground, t))
8788
end
8889
PatTerm(eh, op) = PatTerm(eh, [op])
8990
PatTerm(eh, children...) = PatTerm(eh, collect(children))
91+
92+
isground(p::PatTerm)::Bool = p.isground
93+
9094
TermInterface.istree(::PatTerm) = true
9195
TermInterface.head(p::PatTerm)::PatHead = p.head
9296
TermInterface.children(p::PatTerm) = p.children
9397
function TermInterface.operation(p::PatTerm)
9498
hs = head_symbol(head(p))
95-
hs == :call && return first(p.children)
99+
hs in (:call, :macrocall) && return first(p.children)
96100
# hs == :ref && return getindex
97101
hs
98102
end
99103
function TermInterface.arguments(p::PatTerm)
100104
hs = head_symbol(head(p))
101-
hs == :call ? @view(p.children[2:end]) : p.children
105+
hs in (:call, :macrocall) ? @view(p.children[2:end]) : p.children
106+
end
107+
function TermInterface.arity(p::PatTerm)
108+
hs = head_symbol(head(p))
109+
l = length(p.children)
110+
hs in (:call, :macrocall) ? l - 1 : l
102111
end
103-
TermInterface.arity(p::PatTerm) = length(arguments(p))
104112
TermInterface.metadata(p::PatTerm) = nothing
105113

106114
TermInterface.maketerm(head::PatHead, children; type = Any, metadata = nothing) = PatTerm(head, children...)
107115

108-
isground(p::PatTerm) = all(isground, p.children)
109-
110-
111-
# ==============================================
112-
# ================== PATTERN VARIABLES =========
113-
# ==============================================
116+
# ---------------------
117+
# # Pattern Variables.
114118

115119
"""
116120
Collects pattern variables appearing in a pattern into a vector of symbols
@@ -122,9 +126,9 @@ patvars(x, s) = s
122126
patvars(p) = unique!(patvars(p, Symbol[]))
123127

124128

125-
# ==============================================
126-
# ================== DEBRUJIN INDEXING =========
127-
# ==============================================
129+
# ---------------------
130+
# # Debrujin Indexing.
131+
128132

129133
function setdebrujin!(p::Union{PatVar,PatSegment}, pvars)
130134
p.idx = findfirst((==)(p.name), pvars)

src/TermInterface.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ export unsorted_arguments
114114
Returns the number of arguments of `x`. Implicitly defined
115115
if `arguments(x)` is defined.
116116
"""
117-
arity(x) = length(arguments(x))
117+
arity(x)::Int = length(arguments(x))
118118
export arity
119119

120120

@@ -220,7 +220,7 @@ struct ExprHead
220220
end
221221
export ExprHead
222222

223-
head_symbol(eh::ExprHead) = eh.head
223+
head_symbol(eh::ExprHead)::Symbol = eh.head
224224

225225
istree(x::Expr) = true
226226
head(e::Expr) = ExprHead(e.head)
@@ -247,6 +247,11 @@ function arguments(e::Expr)
247247
end
248248
end
249249

250+
function arity(e::Expr)::Int
251+
l = length(e.args)
252+
e.head in (:call, :macrocall) ? l - 1 : l
253+
end
254+
250255
function maketerm(head::ExprHead, children; type = Any, metadata = nothing)
251256
if !isempty(children) && first(children) isa Union{Function,DataType}
252257
Expr(head.head, nameof(first(children)), @view(children[2:end])...)

src/ematch_compiler.jl

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ end
3232
function predicate_ematcher(p::PatVar, pred)
3333
function predicate_ematcher(next, g, data, bindings)
3434
!islist(data) && return
35-
id::Int = car(data)
35+
id::UInt = car(data)
3636
eclass = g[id]
3737
if pred(eclass)
3838
enode_idx = 0
@@ -122,27 +122,27 @@ function ematcher(p::PatTerm)
122122
end
123123

124124

125-
const EMPTY_ECLASS_DICT = Base.ImmutableDict{Int,Tuple{Int,Int}}()
125+
const EMPTY_BINDINGS = Base.ImmutableDict{Int,Tuple{UInt,Int}}()
126126

127127
"""
128-
Substitutions are efficiently represented in memory as vector of tuples of two integers.
129-
This should allow for static allocation of matches and use of LoopVectorization.jl
130-
The buffer has to be fairly big when e-matching.
131-
The size of the buffer should double when there's too many matches.
132-
The format is as follows
133-
* The first pair denotes the index of the rule in the theory and the e-class id
134-
of the node of the e-graph that is being substituted. The rule number should be negative if it's a bidirectional
135-
the direction is right-to-left.
136-
* From the second pair on, it represents (e-class id, literal position) at the position of the pattern variable
137-
* The end of a substitution is delimited by (0,0)
128+
Substitutions are efficiently represented in memory as immutable dictionaries of tuples of two integers.
129+
130+
The format is as follows:
131+
132+
bindings[0] holds
133+
1. e-class-id of the node of the e-graph that is being substituted.
134+
2. the index of the rule in the theory. The rule number should be negative
135+
if it's a bidirectional rule and the direction is right-to-left.
136+
137+
The rest of the immutable dictionary bindings[n>0] represents (e-class id, literal position) at the position of the pattern variable `n`.
138138
"""
139139
function ematcher_yield(p, npvars::Int, direction::Int)
140140
em = ematcher(p)
141141
function ematcher_yield(g, rule_idx, id)::Int
142142
n_matches = 0
143-
em(g, (id,), EMPTY_ECLASS_DICT) do b, n
143+
em(g, (id,), EMPTY_BINDINGS) do b, n
144144
maybelock!(g) do
145-
push!(g.buffer, assoc(b, 0, (rule_idx * direction, id)))
145+
push!(g.buffer, assoc(b, 0, (id, rule_idx * direction)))
146146
n_matches += 1
147147
end
148148
end

0 commit comments

Comments
 (0)