Skip to content

Commit 624e734

Browse files
authored
fix numerous issues with WeakKeyDict (#38180)
Delay cleanup of WeakKeyDict items until the next insertion. And fix `get!`, since previously usage of it would have added keys without finalizers to the dict. Fixes #26939
1 parent aa2a35a commit 624e734

File tree

4 files changed

+185
-56
lines changed

4 files changed

+185
-56
lines changed

base/abstractdict.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,13 @@ function in(p, a::AbstractDict)
3131
end
3232

3333
function summary(io::IO, t::AbstractDict)
34-
n = length(t)
3534
showarg(io, t, true)
36-
print(io, " with ", n, (n==1 ? " entry" : " entries"))
35+
if Base.IteratorSize(t) isa HasLength
36+
n = length(t)
37+
print(io, " with ", n, (n==1 ? " entry" : " entries"))
38+
else
39+
print(io, "(...)")
40+
end
3741
end
3842

3943
struct KeySet{K, T <: AbstractDict{K}} <: AbstractSet{K}

base/weakkeydict.jl

Lines changed: 126 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -10,24 +10,19 @@ references to objects which may be garbage collected even when
1010
referenced in a hash table.
1111
1212
See [`Dict`](@ref) for further help. Note, unlike [`Dict`](@ref),
13-
`WeakKeyDict` does not convert keys on insertion.
13+
`WeakKeyDict` does not convert keys on insertion, as this would imply the key
14+
object was unreferenced anywhere before insertion.
1415
"""
1516
mutable struct WeakKeyDict{K,V} <: AbstractDict{K,V}
1617
ht::Dict{WeakRef,V}
1718
lock::ReentrantLock
1819
finalizer::Function
20+
dirty::Bool
1921

2022
# Constructors mirror Dict's
2123
function WeakKeyDict{K,V}() where V where K
22-
t = new(Dict{Any,V}(), ReentrantLock(), identity)
23-
t.finalizer = function (k)
24-
# when a weak key is finalized, remove from dictionary if it is still there
25-
if islocked(t)
26-
finalizer(t.finalizer, k)
27-
return nothing
28-
end
29-
delete!(t, k)
30-
end
24+
t = new(Dict{Any,V}(), ReentrantLock(), identity, 0)
25+
t.finalizer = k -> t.dirty = true
3126
return t
3227
end
3328
end
@@ -69,56 +64,149 @@ function WeakKeyDict(kv)
6964
end
7065
end
7166

67+
function _cleanup_locked(h::WeakKeyDict)
68+
if h.dirty
69+
h.dirty = false
70+
idx = skip_deleted_floor!(h.ht)
71+
while idx != 0
72+
if h.ht.keys[idx].value === nothing
73+
_delete!(h.ht, idx)
74+
end
75+
idx = skip_deleted(h.ht, idx + 1)
76+
end
77+
end
78+
return h
79+
end
80+
7281
sizehint!(d::WeakKeyDict, newsz) = sizehint!(d.ht, newsz)
7382
empty(d::WeakKeyDict, ::Type{K}, ::Type{V}) where {K, V} = WeakKeyDict{K, V}()
7483

84+
IteratorSize(::Type{<:WeakKeyDict}) = SizeUnknown()
85+
7586
islocked(wkh::WeakKeyDict) = islocked(wkh.lock)
7687
lock(f, wkh::WeakKeyDict) = lock(f, wkh.lock)
7788
trylock(f, wkh::WeakKeyDict) = trylock(f, wkh.lock)
7889

7990
function setindex!(wkh::WeakKeyDict{K}, v, key) where K
8091
!isa(key, K) && throw(ArgumentError("$(limitrepr(key)) is not a valid key for type $K"))
81-
finalizer(wkh.finalizer, key)
92+
# 'nothing' is not valid both because 'finalizer' will reject it,
93+
# and because we therefore use it as a sentinel value
94+
key === nothing && throw(ArgumentError("`nothing` is not a valid WeakKeyDict key"))
8295
lock(wkh) do
83-
wkh.ht[WeakRef(key)] = v
96+
_cleanup_locked(wkh)
97+
k = getkey(wkh.ht, key, nothing)
98+
if k === nothing
99+
finalizer(wkh.finalizer, key)
100+
k = WeakRef(key)
101+
else
102+
k.value = key
103+
end
104+
wkh.ht[k] = v
84105
end
85106
return wkh
86107
end
108+
function get!(wkh::WeakKeyDict{K}, key, default) where {K}
109+
v = lock(wkh) do
110+
if key !== nothing && haskey(wkh.ht, key)
111+
wkh.ht[key]
112+
else
113+
wkh[key] = default
114+
end
115+
end
116+
return v
117+
end
118+
function get!(default::Callable, wkh::WeakKeyDict{K}, key) where {K}
119+
v = lock(wkh) do
120+
if key !== nothing && haskey(wkh.ht, key)
121+
wkh.ht[key]
122+
else
123+
wkh[key] = default()
124+
end
125+
end
126+
return v
127+
end
87128

88129
function getkey(wkh::WeakKeyDict{K}, kk, default) where K
89-
return lock(wkh) do
90-
k = getkey(wkh.ht, kk, secret_table_token)
91-
k === secret_table_token && return default
92-
return k.value::K
130+
k = lock(wkh) do
131+
k = getkey(wkh.ht, kk, nothing)
132+
k === nothing && return nothing
133+
return k.value
93134
end
135+
return k === nothing ? default : k::K
94136
end
95137

96-
map!(f,iter::ValueIterator{<:WeakKeyDict})= map!(f, values(iter.dict.ht))
97-
get(wkh::WeakKeyDict{K}, key, default) where {K} = lock(() -> get(wkh.ht, key, default), wkh)
98-
get(default::Callable, wkh::WeakKeyDict{K}, key) where {K} = lock(() -> get(default, wkh.ht, key), wkh)
99-
function get!(wkh::WeakKeyDict{K}, key, default) where {K}
100-
!isa(key, K) && throw(ArgumentError("$(limitrepr(key)) is not a valid key for type $K"))
101-
lock(() -> get!(wkh.ht, WeakRef(key), default), wkh)
138+
map!(f, iter::ValueIterator{<:WeakKeyDict})= map!(f, values(iter.dict.ht))
139+
140+
function get(wkh::WeakKeyDict{K}, key, default) where {K}
141+
key === nothing && throw(KeyError(nothing))
142+
lock(wkh) do
143+
return get(wkh.ht, key, default)
144+
end
102145
end
103-
function get!(default::Callable, wkh::WeakKeyDict{K}, key) where {K}
104-
!isa(key, K) && throw(ArgumentError("$(limitrepr(key)) is not a valid key for type $K"))
105-
lock(() -> get!(default, wkh.ht, WeakRef(key)), wkh)
146+
function get(default::Callable, wkh::WeakKeyDict{K}, key) where {K}
147+
key === nothing && throw(KeyError(nothing))
148+
lock(wkh) do
149+
return get(default, wkh.ht, key)
150+
end
151+
end
152+
function pop!(wkh::WeakKeyDict{K}, key) where {K}
153+
key === nothing && throw(KeyError(nothing))
154+
lock(wkh) do
155+
return pop!(wkh.ht, key)
156+
end
157+
end
158+
function pop!(wkh::WeakKeyDict{K}, key, default) where {K}
159+
key === nothing && return default
160+
lock(wkh) do
161+
return pop!(wkh.ht, key, default)
162+
end
163+
end
164+
function delete!(wkh::WeakKeyDict, key)
165+
key === nothing && return wkh
166+
lock(wkh) do
167+
delete!(wkh.ht, key)
168+
end
169+
return wkh
170+
end
171+
function empty!(wkh::WeakKeyDict)
172+
lock(wkh) do
173+
empty!(wkh.ht)
174+
end
175+
return wkh
176+
end
177+
function haskey(wkh::WeakKeyDict{K}, key) where {K}
178+
key === nothing && return false
179+
lock(wkh) do
180+
return haskey(wkh.ht, key)
181+
end
182+
end
183+
function getindex(wkh::WeakKeyDict{K}, key) where {K}
184+
key === nothing && throw(KeyError(nothing))
185+
lock(wkh) do
186+
return getindex(wkh.ht, key)
187+
end
188+
end
189+
isempty(wkh::WeakKeyDict) = length(wkh) == 0
190+
function length(t::WeakKeyDict)
191+
lock(t) do
192+
_cleanup_locked(t)
193+
return length(t.ht)
194+
end
106195
end
107-
pop!(wkh::WeakKeyDict{K}, key) where {K} = lock(() -> pop!(wkh.ht, key), wkh)
108-
pop!(wkh::WeakKeyDict{K}, key, default) where {K} = lock(() -> pop!(wkh.ht, key, default), wkh)
109-
delete!(wkh::WeakKeyDict, key) = (lock(() -> delete!(wkh.ht, key), wkh); wkh)
110-
empty!(wkh::WeakKeyDict) = (lock(() -> empty!(wkh.ht), wkh); wkh)
111-
haskey(wkh::WeakKeyDict{K}, key) where {K} = lock(() -> haskey(wkh.ht, key), wkh)
112-
getindex(wkh::WeakKeyDict{K}, key) where {K} = lock(() -> getindex(wkh.ht, key), wkh)
113-
isempty(wkh::WeakKeyDict) = isempty(wkh.ht)
114-
length(t::WeakKeyDict) = length(t.ht)
115196

116197
function iterate(t::WeakKeyDict{K,V}, state...) where {K, V}
117-
y = lock(() -> iterate(t.ht, state...), t)
118-
y === nothing && return nothing
119-
wkv, newstate = y
120-
kv = Pair{K,V}(wkv[1].value::K, wkv[2])
121-
return (kv, newstate)
198+
return lock(t) do
199+
while true
200+
y = iterate(t.ht, state...)
201+
y === nothing && return nothing
202+
wkv, state = y
203+
k = wkv[1].value
204+
GC.safepoint() # ensure `k` is now gc-rooted
205+
k === nothing && continue # indicates `k` is scheduled for deletion
206+
kv = Pair{K,V}(k::K, wkv[2])
207+
return (kv, state)
208+
end
209+
end
122210
end
123211

124212
filter!(f, d::WeakKeyDict) = filter_in_one_pass!(f, d)

src/gc.c

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -868,6 +868,20 @@ JL_DLLEXPORT jl_weakref_t *jl_gc_new_weakref_th(jl_ptls_t ptls,
868868
return wr;
869869
}
870870

871+
static void clear_weak_refs(void)
872+
{
873+
for (int i = 0; i < jl_n_threads; i++) {
874+
jl_ptls_t ptls2 = jl_all_tls_states[i];
875+
size_t n, l = ptls2->heap.weak_refs.len;
876+
void **lst = ptls2->heap.weak_refs.items;
877+
for (n = 0; n < l; n++) {
878+
jl_weakref_t *wr = (jl_weakref_t*)lst[n];
879+
if (!gc_marked(jl_astaggedvalue(wr->value)->bits.gc))
880+
wr->value = (jl_value_t*)jl_nothing;
881+
}
882+
}
883+
}
884+
871885
static void sweep_weak_refs(void)
872886
{
873887
for (int i = 0; i < jl_n_threads; i++) {
@@ -880,16 +894,10 @@ static void sweep_weak_refs(void)
880894
continue;
881895
while (1) {
882896
jl_weakref_t *wr = (jl_weakref_t*)lst[n];
883-
if (gc_marked(jl_astaggedvalue(wr)->bits.gc)) {
884-
// weakref itself is alive,
885-
// so the user could still re-set it to a new value
886-
if (!gc_marked(jl_astaggedvalue(wr->value)->bits.gc))
887-
wr->value = (jl_value_t*)jl_nothing;
897+
if (gc_marked(jl_astaggedvalue(wr)->bits.gc))
888898
n++;
889-
}
890-
else {
899+
else
891900
ndel++;
892-
}
893901
if (n >= l - ndel)
894902
break;
895903
void *tmp = lst[n];
@@ -900,6 +908,7 @@ static void sweep_weak_refs(void)
900908
}
901909
}
902910

911+
903912
// big value list
904913

905914
// Size includes the tag and the tag is not cleared!!
@@ -3003,6 +3012,7 @@ static int _jl_gc_collect(jl_ptls_t ptls, jl_gc_collection_t collection)
30033012
// marking is over
30043013

30053014
// 4. check for objects to finalize
3015+
clear_weak_refs();
30063016
// Record the length of the marked list since we need to
30073017
// mark the object moved to the marked list from the
30083018
// `finalizer_list` by `sweep_finalizer_list`

test/dict.jl

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -893,15 +893,40 @@ Dict(1 => rand(2,3), 'c' => "asdf") # just make sure this does not trigger a dep
893893

894894
# issue #26939
895895
d26939 = WeakKeyDict()
896-
d26939[big"1.0" + 1.1] = 1
897-
GC.gc() # make sure this doesn't segfault
896+
(@noinline d -> d[big"1.0" + 1.1] = 1)(d26939)
897+
GC.gc() # primarily to make sure this doesn't segfault
898+
@test count(d26939) == 0
899+
@test length(d26939.ht) == 1
900+
@test length(d26939) == 0
901+
@test isempty(d26939)
902+
empty!(d26939)
903+
for i in 1:8
904+
(@noinline (d, i) -> d[big(i + 12345)] = 1)(d26939, i)
905+
end
906+
lock(GC.gc, d26939)
907+
@test length(d26939.ht) == 8
908+
@test count(d26939) == 0
909+
@test !haskey(d26939, nothing)
910+
@test_throws KeyError(nothing) d26939[nothing]
911+
@test_throws KeyError(nothing) get(d26939, nothing, 1)
912+
@test_throws KeyError(nothing) get(() -> 1, d26939, nothing)
913+
@test_throws KeyError(nothing) pop!(d26939, nothing)
914+
@test getkey(d26939, nothing, 321) === 321
915+
@test pop!(d26939, nothing, 321) === 321
916+
@test delete!(d26939, nothing) === d26939
917+
@test length(d26939.ht) == 8
918+
@test_throws ArgumentError d26939[nothing] = 1
919+
@test_throws ArgumentError get!(d26939, nothing, 1)
920+
@test_throws ArgumentError get!(() -> 1, d26939, nothing)
921+
@test isempty(d26939)
922+
@test length(d26939.ht) == 0
923+
@test length(d26939) == 0
898924

899925
# WeakKeyDict does not convert keys on setting
900926
@test_throws ArgumentError WeakKeyDict{Vector{Int},Any}([5.0]=>1)
901927
wkd = WeakKeyDict(A=>2)
902928
@test_throws ArgumentError get!(wkd, [2.0], 2)
903-
@test_throws ArgumentError get!(wkd, [1.0], 2) # get! fails even if the key is only
904-
# used for getting and not setting
929+
@test get!(wkd, [1.0], 2) === 2
905930

906931
# WeakKeyDict does convert on getting
907932
wkd = WeakKeyDict(A=>2)
@@ -913,16 +938,18 @@ Dict(1 => rand(2,3), 'c' => "asdf") # just make sure this does not trigger a dep
913938

914939
# map! on values of WKD
915940
wkd = WeakKeyDict(A=>2, B=>3)
916-
map!(v->v-1, values(wkd))
941+
map!(v -> v-1, values(wkd))
917942
@test wkd == WeakKeyDict(A=>1, B=>2)
918943

919944
# get!
920945
wkd = WeakKeyDict(A=>2)
921-
get!(wkd, B, 3)
946+
@test get!(wkd, B, 3) == 3
922947
@test wkd == WeakKeyDict(A=>2, B=>3)
923-
get!(()->4, wkd, C)
948+
@test get!(()->4, wkd, C) == 4
924949
@test wkd == WeakKeyDict(A=>2, B=>3, C=>4)
925-
@test_throws ArgumentError get!(()->5, wkd, [1.0])
950+
@test get!(()->5, wkd, [1.0]) == 2
951+
952+
GC.@preserve A B C D nothing
926953
end
927954

928955
@testset "issue #19995, hash of dicts" begin

0 commit comments

Comments
 (0)