Skip to content

Commit e0babe8

Browse files
authored
fix return type of get! on IdDict (#36383)
1 parent 76a2e36 commit e0babe8

File tree

2 files changed

+25
-8
lines changed

2 files changed

+25
-8
lines changed

base/iddict.jl

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,9 @@ function get(d::IdDict{K,V}, @nospecialize(key), @nospecialize(default)) where {
8585
val = ccall(:jl_eqtable_get, Any, (Any, Any, Any), d.ht, key, default)
8686
val === default ? default : val::V
8787
end
88+
8889
function getindex(d::IdDict{K,V}, @nospecialize(key)) where {K, V}
89-
val = get(d, key, secret_table_token)
90+
val = ccall(:jl_eqtable_get, Any, (Any, Any, Any), d.ht, key, secret_table_token)
9091
val === secret_table_token && throw(KeyError(key))
9192
return val::V
9293
end
@@ -134,23 +135,38 @@ length(d::IdDict) = d.count
134135

135136
copy(d::IdDict) = typeof(d)(d)
136137

137-
get!(d::IdDict{K,V}, @nospecialize(key), @nospecialize(default)) where {K, V} = (d[key] = get(d, key, default))::V
138+
function get!(d::IdDict{K,V}, @nospecialize(key), @nospecialize(default)) where {K, V}
139+
val = ccall(:jl_eqtable_get, Any, (Any, Any, Any), d.ht, key, secret_table_token)
140+
if val === secret_table_token
141+
val = isa(default, V) ? default : convert(V, default)
142+
setindex!(d, val, key)
143+
return val
144+
else
145+
return val::V
146+
end
147+
end
138148

139149
function get(default::Callable, d::IdDict{K,V}, @nospecialize(key)) where {K, V}
140-
val = get(d, key, secret_table_token)
150+
val = ccall(:jl_eqtable_get, Any, (Any, Any, Any), d.ht, key, secret_table_token)
141151
if val === secret_table_token
142-
val = default()
152+
return default()
153+
else
154+
return val::V
143155
end
144-
return val
145156
end
146157

147158
function get!(default::Callable, d::IdDict{K,V}, @nospecialize(key)) where {K, V}
148-
val = get(d, key, secret_table_token)
159+
val = ccall(:jl_eqtable_get, Any, (Any, Any, Any), d.ht, key, secret_table_token)
149160
if val === secret_table_token
150161
val = default()
162+
if !isa(val, V)
163+
val = convert(V, val)
164+
end
151165
setindex!(d, val, key)
166+
return val
167+
else
168+
return val::V
152169
end
153-
return val
154170
end
155171

156172
in(@nospecialize(k), v::KeySet{<:Any,<:IdDict}) = get(v.dict, k, secret_table_token) !== secret_table_token

test/dict.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -554,7 +554,8 @@ end
554554
@test delete!(d, "a") === d
555555
@test !haskey(d, "a")
556556
@test_throws ArgumentError get!(IdDict{Symbol,Any}(), 2, "b")
557-
557+
@test get!(IdDict{Int,Int}(), 1, 2.0) === 2
558+
@test get!(()->2.0, IdDict{Int,Int}(), 1) === 2
558559

559560
# sizehint! & rehash!
560561
d = IdDict()

0 commit comments

Comments
 (0)