Skip to content

Commit 1a0d37b

Browse files
Merge pull request #746 from AayushSabharwal/as/precompile-id-collision
fix: fix hashconsed ID collisions due to precompilation
2 parents ca6a4f2 + bff68cc commit 1a0d37b

File tree

3 files changed

+70
-50
lines changed

3 files changed

+70
-50
lines changed

src/cache.jl

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ Sentinel value used for a cache miss, since cached functions may return `nothing
55
"""
66
struct CacheSentinel end
77

8+
mutable struct IDType end
9+
810
"""
911
$(TYPEDEF)
1012
@@ -13,7 +15,7 @@ Struct wrapping the `id` of a `BasicSymbolic`, since arguments annotated
1315
up a symbolic or a `UInt`.
1416
"""
1517
struct SymbolicKey
16-
id::UInt64
18+
id::IDType
1719
end
1820

1921
"""
@@ -23,7 +25,17 @@ The key stored in the cache for a particular value. Returns a `SymbolicKey` for
2325
`BasicSymbolic` and is the identity function otherwise.
2426
"""
2527
# can't dispatch because `BasicSymbolic` isn't defined here
26-
get_cache_key(x) = x isa BasicSymbolic ? SymbolicKey(x.id[]) : x
28+
function get_cache_key(x)
29+
if x isa BasicSymbolic
30+
id = x.id[]
31+
if id === nothing
32+
return CacheSentinel()
33+
end
34+
return SymbolicKey(id)
35+
else
36+
x
37+
end
38+
end
2739

2840
"""
2941
associated_cache(fn)
@@ -241,6 +253,7 @@ macro cache(args...)
241253
argexprs = []
242254
# The name of the variable storing the result of looking up the cache
243255
cache_value_name = :val
256+
valid_key_condition = :(true)
244257
# The condition for a cache hit
245258
cache_hit_condition = :(!($cache_value_name isa $CacheSentinel))
246259

@@ -254,6 +267,7 @@ macro cache(args...)
254267
push!(keyexprs, :($get_cache_key($arg)))
255268
push!(argexprs, arg)
256269
push!(keytypes, Any)
270+
valid_key_condition = :($valid_key_condition && !(key[$(length(keyexprs))] isa $CacheSentinel))
257271
continue
258272
end
259273
argname, Texpr = arg.args
@@ -263,6 +277,7 @@ macro cache(args...)
263277
# if the type is `Any`, branch on it being a `BasicSymbolic`
264278
push!(keyexprs, :($get_cache_key($argname)))
265279
push!(keytypes, Any)
280+
valid_key_condition = :($valid_key_condition && !(key[$(length(keyexprs))] isa $CacheSentinel))
266281
continue
267282
end
268283

@@ -275,6 +290,7 @@ macro cache(args...)
275290
push!(keytypes, Union{keyTs...})
276291
if maybe_basicsymbolic
277292
push!(keyexprs, :($get_cache_key($argname)))
293+
valid_key_condition = :($valid_key_condition && !(key[$(length(keyexprs))] isa $CacheSentinel))
278294
else
279295
push!(keyexprs, argname)
280296
end
@@ -286,6 +302,7 @@ macro cache(args...)
286302
if T <: BasicSymbolic
287303
push!(keytypes, SymbolicKey)
288304
push!(keyexprs, :($get_cache_key($argname)))
305+
valid_key_condition = :($valid_key_condition && !(key[$(length(keyexprs))] isa $CacheSentinel))
289306
else
290307
push!(keytypes, T)
291308
push!(keyexprs, argname)
@@ -345,6 +362,9 @@ macro cache(args...)
345362
if $conditions
346363
# construct the `Tuple` key
347364
key = $keyexpr
365+
if !($valid_key_condition)
366+
return $innercall
367+
end
348368
# get the cache and stats from the `TaskLocalValue` to avoid accessing
349369
# `task_local_storage` repeatedly.
350370
cachedict, cachestats = $cachename.tlv[]

src/types.jl

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,13 @@ const EMPTY_HASH = RefValue(UInt(0))
2222
const EMPTY_DICT = sdict()
2323
const EMPTY_DICT_T = typeof(EMPTY_DICT)
2424
const ENABLE_HASHCONSING = Ref(true)
25+
const TID = Union{IDType, Nothing}
26+
const DID = nothing
2527

2628
@compactify show_methods=false begin
2729
@abstract mutable struct BasicSymbolic{T} <: Symbolic{T}
2830
metadata::Metadata = NO_METADATA
29-
id::RefValue{UInt64} = Ref{UInt64}(0)
31+
id::RefValue{TID} = Ref{TID}(DID)
3032
end
3133
mutable struct Sym{T} <: BasicSymbolic{T}
3234
name::Symbol = :OOF
@@ -88,8 +90,6 @@ function exprtype(x::BasicSymbolic)
8890
end
8991
end
9092

91-
const wvd = TaskLocalValue{WeakValueDict{UInt, BasicSymbolic}}(WeakValueDict{UInt, BasicSymbolic})
92-
9393
# Same but different error messages
9494
@noinline error_on_type() = error("Internal error: unreachable reached!")
9595
@noinline error_sym() = error("Sym doesn't have a operation or arguments!")
@@ -108,11 +108,11 @@ function ConstructionBase.setproperties(obj::BasicSymbolic{T}, patch::NamedTuple
108108
# Call outer constructor because hash consing cannot be applied in inner constructor
109109
@compactified obj::BasicSymbolic begin
110110
Sym => Sym{T}(nt_new.name; nt_new...)
111-
Term => Term{T}(nt_new.f, nt_new.arguments; nt_new..., hash = RefValue(UInt(0)), hash2 = RefValue(UInt(0)), id = Ref{UInt64}(0))
112-
Add => Add(T, nt_new.coeff, nt_new.dict; nt_new..., hash = RefValue(UInt(0)), hash2 = RefValue(UInt(0)), id = Ref{UInt64}(0))
113-
Mul => Mul(T, nt_new.coeff, nt_new.dict; nt_new..., hash = RefValue(UInt(0)), hash2 = RefValue(UInt(0)), id = Ref{UInt64}(0))
114-
Div => Div{T}(nt_new.num, nt_new.den, nt_new.simplified; nt_new..., hash = RefValue(UInt(0)), hash2 = RefValue(UInt(0)), id = Ref{UInt64}(0))
115-
Pow => Pow{T}(nt_new.base, nt_new.exp; nt_new..., hash = RefValue(UInt(0)), hash2 = RefValue(UInt(0)), id = Ref{UInt64}(0))
111+
Term => Term{T}(nt_new.f, nt_new.arguments; nt_new..., hash = RefValue(UInt(0)), hash2 = RefValue(UInt(0)), id = Ref{TID}(DID))
112+
Add => Add(T, nt_new.coeff, nt_new.dict; nt_new..., hash = RefValue(UInt(0)), hash2 = RefValue(UInt(0)), id = Ref{TID}(DID))
113+
Mul => Mul(T, nt_new.coeff, nt_new.dict; nt_new..., hash = RefValue(UInt(0)), hash2 = RefValue(UInt(0)), id = Ref{TID}(DID))
114+
Div => Div{T}(nt_new.num, nt_new.den, nt_new.simplified; nt_new..., hash = RefValue(UInt(0)), hash2 = RefValue(UInt(0)), id = Ref{TID}(DID))
115+
Pow => Pow{T}(nt_new.base, nt_new.exp; nt_new..., hash = RefValue(UInt(0)), hash2 = RefValue(UInt(0)), id = Ref{TID}(DID))
116116
_ => Unityper.rt_constructor(obj){T}(;nt_new...)
117117
end
118118
end
@@ -516,11 +516,11 @@ end
516516
### Constructors
517517
###
518518

519-
mutable struct AtomicIDCounter
520-
@atomic x::UInt64
521-
end
519+
const wvd = TaskLocalValue{WeakValueDict{UInt, BasicSymbolic}}(WeakValueDict{UInt, BasicSymbolic})
522520

523-
const ID_COUNTER = AtomicIDCounter(0)
521+
function generate_id()
522+
return IDType()
523+
end
524524

525525
"""
526526
$(TYPEDSIGNATURES)
@@ -552,27 +552,27 @@ function BasicSymbolic(s::BasicSymbolic)::BasicSymbolic
552552
h = hash2(s)
553553
k = get!(cache, h, s)
554554
if isequal_with_metadata(k, s)
555-
if iszero(k.id[])
556-
k.id[] = @atomic ID_COUNTER.x += 1
555+
if isnothing(k.id[])
556+
k.id[] = generate_id()
557557
end
558558
return k
559559
else
560-
if iszero(s.id[])
561-
s.id[] = @atomic ID_COUNTER.x += 1
560+
if isnothing(s.id[])
561+
s.id[] = generate_id()
562562
end
563563
return s
564564
end
565565
end
566566

567567
function Sym{T}(name::Symbol; kw...) where {T}
568-
s = Sym{T}(; name, kw..., id = Ref{UInt}(0))
568+
s = Sym{T}(; name, kw..., id = Ref{TID}(DID))
569569
BasicSymbolic(s)
570570
end
571571

572572
function Term{T}(f, args; kw...) where T
573573
args = SmallV{Any}(args)
574574

575-
s = Term{T}(;f=f, arguments=args, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), kw..., id = Ref{UInt64}(0))
575+
s = Term{T}(;f=f, arguments=args, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), kw..., id = Ref{TID}(DID))
576576
BasicSymbolic(s)
577577
end
578578

@@ -602,7 +602,7 @@ function Add(::Type{T}, coeff, dict; metadata=NO_METADATA, kw...) where T
602602
end
603603
end
604604

605-
s = Add{T}(; coeff, dict, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), metadata, arguments=SmallV{Any}(), kw..., id = Ref{UInt64}(0))
605+
s = Add{T}(; coeff, dict, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), metadata, arguments=SmallV{Any}(), kw..., id = Ref{TID}(DID))
606606
BasicSymbolic(s)
607607
end
608608

@@ -620,7 +620,7 @@ function Mul(T, a, b; metadata=NO_METADATA, kw...)
620620
else
621621
coeff = a
622622
dict = b
623-
s = Mul{T}(; coeff, dict, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), metadata, arguments=SmallV{Any}(), kw..., id = Ref{UInt64}(0))
623+
s = Mul{T}(; coeff, dict, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), metadata, arguments=SmallV{Any}(), kw..., id = Ref{TID}(DID))
624624
BasicSymbolic(s)
625625
end
626626
end
@@ -688,7 +688,7 @@ function Div{T}(n, d, simplified=false; metadata=nothing, kwargs...) where {T}
688688
end
689689
end
690690

691-
s = Div{T}(; num=n, den=d, simplified, arguments=SmallV{Any}(), metadata, id = Ref{UInt64}(0))
691+
s = Div{T}(; num=n, den=d, simplified, arguments=SmallV{Any}(), metadata, id = Ref{TID}(DID))
692692
BasicSymbolic(s)
693693
end
694694

@@ -708,7 +708,7 @@ function Pow{T}(a, b; metadata=NO_METADATA, kwargs...) where {T}
708708
b = unwrap(b)
709709
_iszero(b) && return 1
710710
_isone(b) && return a
711-
s = Pow{T}(; base=a, exp=b, arguments=SmallV{Any}(), metadata, id = Ref{UInt64}(0))
711+
s = Pow{T}(; base=a, exp=b, arguments=SmallV{Any}(), metadata, id = Ref{TID}(DID))
712712
BasicSymbolic(s)
713713
end
714714

test/cache_macro.jl

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -70,32 +70,32 @@ end
7070
return 2x + 1
7171
end
7272

73-
@testset "::Union (with `UInt`)" begin
74-
@syms x
75-
val = f2(x)
76-
@test isequal(val, 2x + 1)
77-
cachestruct = associated_cache(f2)
78-
cache, stats = cachestruct.tlv[]
79-
@test cache isa Dict{Tuple{Union{SymbolicKey, UInt}}, Union{BasicSymbolic, UInt}}
80-
@test length(cache) == 1
81-
@test cache[(get_cache_key(x),)] === val
82-
@test stats.hits == 0
83-
@test stats.misses == 1
84-
f2(x)
85-
@test stats.hits == 1
86-
@test stats.misses == 1
87-
88-
y = get_cache_key(x).id
89-
val = f2(y)
90-
@test val == 2y + 1
91-
@test length(cache) == 2
92-
@test cache[(y,)] == val
93-
@test stats.misses == 2
94-
95-
clear_cache!(f2)
96-
@test length(cache) == 0
97-
@test stats.hits == stats.misses == stats.clears == 0
98-
end
73+
# @testset "::Union (with `UInt`)" begin
74+
# @syms x
75+
# val = f2(x)
76+
# @test isequal(val, 2x + 1)
77+
# cachestruct = associated_cache(f2)
78+
# cache, stats = cachestruct.tlv[]
79+
# @test cache isa Dict{Tuple{Union{SymbolicKey, UInt}}, Union{BasicSymbolic, UInt}}
80+
# @test length(cache) == 1
81+
# @test cache[(get_cache_key(x),)] === val
82+
# @test stats.hits == 0
83+
# @test stats.misses == 1
84+
# f2(x)
85+
# @test stats.hits == 1
86+
# @test stats.misses == 1
87+
88+
# y = get_cache_key(x).id
89+
# val = f2(y)
90+
# @test val == 2y + 1
91+
# @test length(cache) == 2
92+
# @test cache[(y,)] == val
93+
# @test stats.misses == 2
94+
95+
# clear_cache!(f2)
96+
# @test length(cache) == 0
97+
# @test stats.hits == stats.misses == stats.clears == 0
98+
# end
9999

100100
@cache function f3(x)::Union{BasicSymbolic, Int}
101101
return 2x + 1

0 commit comments

Comments
 (0)