Skip to content

Commit dd0505a

Browse files
Merge pull request #739 from AayushSabharwal/as/weakset
feat: use `WeakCacheSet` for hashconsing
2 parents 2f8d4fa + 208309c commit dd0505a

File tree

3 files changed

+239
-14
lines changed

3 files changed

+239
-14
lines changed

src/SymbolicUtils.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ import ExproniconLite as EL
2424
import TaskLocalValues: TaskLocalValue
2525
import WeakValueDicts: WeakValueDict
2626

27+
include("WeakCacheSets.jl")
28+
2729
include("cache.jl")
2830
Base.@deprecate istree iscall
2931

src/WeakCacheSets.jl

Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
# These can be changed, to trade off better performance for space
2+
const maxallowedprobe = 16
3+
const maxprobeshift = 6
4+
5+
if VERSION < v"1.11"
6+
const Memory = Vector
7+
end
8+
9+
"""
10+
WeakCacheSet()
11+
12+
`WeakCacheSet{K}()` constructs a cache of items of type `K` which are auto-deleted when dead.
13+
Keys are compared with [`isequal`](@ref) and hashed with [`hash`](@ref).
14+
15+
!!! warning
16+
17+
Keys are allowed to be mutable, but if you do mutate stored
18+
keys, the hash table may become internally inconsistent, in which case
19+
the `WeakCacheSet` will not work properly.
20+
21+
"""
22+
mutable struct WeakCacheSet{K}
23+
# Metadata: empty => 0x00, full or removed => 0b1[7 most significant hash bits]
24+
slots::Memory{UInt8}
25+
keys::Memory{WeakRef}
26+
count::Int
27+
maxprobe::Int
28+
29+
function WeakCacheSet{K}() where K
30+
n = 0
31+
slots = Memory{UInt8}(undef,n)
32+
fill!(slots, 0x0)
33+
keys = Memory{WeakRef}(undef, n)
34+
fill!(keys, WeakRef(nothing))
35+
new(slots, keys, 0, 0)
36+
end
37+
end
38+
39+
# Gets 7 most significant bits from the hash (hsh), first bit is 1
40+
_shorthash7(hsh::UInt) = (hsh >> (8sizeof(UInt)-7))%UInt8 | 0x80
41+
42+
# hashindex (key, sz) - computes optimal position and shorthash7
43+
# idx - optimal position in the hash table
44+
# sh::UInt8 - short hash (7 highest hash bits)
45+
function hashindex(key, sz)
46+
hsh = hash2(key)::UInt
47+
idx = (((hsh % Int) & (sz-1)) + 1)::Int
48+
return idx, _shorthash7(hsh)
49+
end
50+
51+
Base.@propagate_inbounds isslotempty(h::WeakCacheSet, i::Int) = h.slots[i] == 0x00
52+
Base.@propagate_inbounds function isslotfilled(h::WeakCacheSet, i::Int)
53+
return (h.slots[i] != 0) && !isnothing(h.keys[i].value)
54+
end
55+
Base.@propagate_inbounds function isslotmissing(h::WeakCacheSet, i::Int)
56+
return isnothing(h.keys[i].value)
57+
end
58+
_tablesz(x::T) where T <: Integer = x < 16 ? T(16) : one(T)<<(Base.top_set_bit(x-one(T)))
59+
function rehash!(h::WeakCacheSet{K}, newsz = length(h.keys)) where K
60+
olds = h.slots
61+
oldk = h.keys
62+
sz = length(olds)
63+
newsz = _tablesz(newsz)
64+
if h.count == 0
65+
# TODO: tryresize
66+
h.slots = Memory{UInt8}(undef, newsz)
67+
fill!(h.slots, 0x0)
68+
h.keys = Memory{WeakRef}(undef, newsz)
69+
fill!(h.keys, WeakRef(nothing))
70+
h.maxprobe = 0
71+
return h
72+
end
73+
74+
slots = Memory{UInt8}(undef, newsz)
75+
fill!(slots, 0x0)
76+
keys = Memory{WeakRef}(undef, newsz)
77+
fill!(keys, WeakRef(nothing))
78+
count = 0
79+
maxprobe = 0
80+
81+
for i = 1:sz
82+
@inbounds if olds[i] != 0
83+
k = oldk[i].value::Union{K, Nothing}
84+
isnothing(k) && continue
85+
index, sh = hashindex(k, newsz)
86+
index0 = index
87+
while slots[index] != 0
88+
index = (index & (newsz-1)) + 1
89+
end
90+
probe = (index - index0) & (newsz-1)
91+
maxprobe = max(maxprobe, probe)
92+
slots[index] = olds[i]
93+
keys[index] = WeakRef(k)
94+
count += 1
95+
end
96+
end
97+
98+
h.slots = slots
99+
h.keys = keys
100+
h.count = count
101+
h.maxprobe = maxprobe
102+
return h
103+
end
104+
105+
function Base.sizehint!(d::WeakCacheSet{T}, newsz; shrink::Bool=true) where T
106+
oldsz = length(d.slots)
107+
# limit new element count to max_values of the key type
108+
newsz = min(max(newsz, length(d)), Base.max_values(T)::Int)
109+
# need at least 1.5n space to hold n elements
110+
newsz = _tablesz(cld(3 * newsz, 2))
111+
return (shrink ? newsz == oldsz : newsz <= oldsz) ? d : rehash!(d, newsz)
112+
end
113+
114+
# get (index, sh) for the key
115+
# index - where a key is stored, or -pos if not present
116+
# and the key would be inserted at pos
117+
# sh::UInt8 - short hash (7 highest hash bits)
118+
function ht_keyindex2_shorthash!(h::WeakCacheSet{K}, key) where K
119+
sz = length(h.keys)
120+
if sz == 0 # if Dict was empty resize and then return location to insert
121+
rehash!(h, 4)
122+
index, sh = hashindex(key, length(h.keys))
123+
return -index, sh
124+
end
125+
iter = 0
126+
maxprobe = h.maxprobe
127+
index, sh = hashindex(key, sz)
128+
avail = 0
129+
keys = h.keys
130+
131+
@inbounds while true
132+
if isslotempty(h,index)
133+
return (avail < 0 ? avail : -index), sh
134+
end
135+
136+
if isslotmissing(h,index)
137+
if avail == 0
138+
# found an available slot, but need to keep scanning
139+
# in case "key" already exists in a later collided slot.
140+
avail = -index
141+
end
142+
elseif h.slots[index] == sh
143+
k = keys[index].value
144+
if key === k || isequal_with_metadata(key, k)
145+
return index, sh
146+
end
147+
end
148+
149+
index = (index & (sz-1)) + 1
150+
iter += 1
151+
iter > maxprobe && break
152+
end
153+
154+
avail < 0 && return avail, sh
155+
156+
maxallowed = max(maxallowedprobe, sz>>maxprobeshift)
157+
# Check if key is not present, may need to keep searching to find slot
158+
@inbounds while iter < maxallowed
159+
if !isslotfilled(h,index)
160+
h.maxprobe = iter
161+
return -index, sh
162+
end
163+
index = (index & (sz-1)) + 1
164+
iter += 1
165+
end
166+
167+
rehash!(h, h.count > 64000 ? sz*2 : sz*4)
168+
169+
return ht_keyindex2_shorthash!(h, key)
170+
end
171+
172+
Base.@propagate_inbounds function _setindex!(h::WeakCacheSet, key, index, sh = _shorthash7(hash2(key)))
173+
h.slots[index] = sh
174+
h.keys[index] = WeakRef(key)
175+
h.count += !isslotmissing(h, index)
176+
177+
sz = length(h.keys)
178+
# Rehash now if necessary
179+
if h.count*3 > sz*2
180+
# > 2/3 full (including tombstones)
181+
rehash!(h, h.count > 64000 ? h.count*2 : max(h.count*4, 4))
182+
end
183+
nothing
184+
end
185+
186+
function getkey!(h::WeakCacheSet{K}, key0) where K
187+
if key0 isa K
188+
key = key0
189+
else
190+
key = convert(K, key0)::K
191+
if !(isequal_with_metadata(key, key0)::Bool)
192+
throw(KeyTypeError(K, key0))
193+
end
194+
end
195+
getkey!(h, v0, key)
196+
end
197+
198+
function getkey!(h::WeakCacheSet{K}, key::K) where K
199+
index, sh = ht_keyindex2_shorthash!(h, key)
200+
if index > 0
201+
foundkey = h.keys[index].value::Union{K, Nothing}
202+
if isnothing(foundkey)
203+
@inbounds h.keys[index] = key
204+
return key
205+
end
206+
return foundkey
207+
else
208+
@inbounds _setindex!(h, key, -index, sh)
209+
return key
210+
end
211+
end
212+
213+
function Base.show(io::IO, t::WeakCacheSet{K}) where K
214+
recur_io = IOContext(io, :SHOWN_SET => t,
215+
:typeinfo => K)
216+
217+
limit = get(io, :limit, false)::Bool
218+
print(io, "WeakCacheSet(")
219+
n = 0
220+
first = true
221+
for val in t.keys
222+
realval = val.value::Union{K, Nothing}
223+
isnothing(realval) && continue
224+
first || print(io, ", ")
225+
first = false
226+
show(recur_io, realval)
227+
n+=1
228+
limit && n >= 10 && (print(io, ""); break)
229+
end
230+
print(io, ')')
231+
end

src/types.jl

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ function exprtype(x::BasicSymbolic)
8888
end
8989
end
9090

91-
const wvd = TaskLocalValue{WeakValueDict{UInt, BasicSymbolic}}(WeakValueDict{UInt, BasicSymbolic})
91+
const wcs = TaskLocalValue{WeakCacheSet{BasicSymbolic}}(WeakCacheSet{BasicSymbolic})
9292

9393
# Same but different error messages
9494
@noinline error_on_type() = error("Internal error: unreachable reached!")
@@ -547,20 +547,12 @@ function BasicSymbolic(s::BasicSymbolic)::BasicSymbolic
547547
if !ENABLE_HASHCONSING[]
548548
return s
549549
end
550-
cache = wvd[]
551-
h = hash2(s)
552-
k = get!(cache, h, s)
553-
if isequal_with_metadata(k, s)
554-
if iszero(k.id[])
555-
k.id[] = @atomic ID_COUNTER.x += 1
556-
end
557-
return k
558-
else
559-
if iszero(s.id[])
560-
s.id[] = @atomic ID_COUNTER.x += 1
561-
end
562-
return s
550+
cache = wcs[]
551+
k = getkey!(cache, s)
552+
if iszero(k.id[])
553+
k.id[] = @atomic ID_COUNTER.x += 1
563554
end
555+
return k
564556
end
565557

566558
function Sym{T}(name::Symbol; kw...) where {T}

0 commit comments

Comments
 (0)