Skip to content

Commit 2f8d4fa

Browse files
Merge pull request #716 from AayushSabharwal/as/uint-id
feat: use a global atomic `UInt64` to give each `BasicSymbolic` a unique ID
2 parents 96f819a + a4dabb3 commit 2f8d4fa

File tree

3 files changed

+54
-84
lines changed

3 files changed

+54
-84
lines changed

src/cache.jl

Lines changed: 19 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,23 @@ struct CacheSentinel end
88
"""
99
$(TYPEDEF)
1010
11-
Struct wrapping the `objectid` of a `BasicSymbolic`, since arguments annotated
11+
Struct wrapping the `id` of a `BasicSymbolic`, since arguments annotated
1212
`::Union{BasicSymbolic, UInt}` would not be able to differentiate between looking
1313
up a symbolic or a `UInt`.
1414
"""
1515
struct SymbolicKey
16-
id::UInt
16+
id::UInt64
1717
end
1818

19+
"""
20+
$(TYPEDSIGNATURES)
21+
22+
The key stored in the cache for a particular value. Returns a `SymbolicKey` for
23+
`BasicSymbolic` and is the identity function otherwise.
24+
"""
25+
# can't dispatch because `BasicSymbolic` isn't defined here
26+
get_cache_key(x) = x isa BasicSymbolic ? SymbolicKey(x.id[]) : x
27+
1928
"""
2029
associated_cache(fn)
2130
@@ -234,10 +243,6 @@ macro cache(args...)
234243
cache_value_name = :val
235244
# The condition for a cache hit
236245
cache_hit_condition = :(!($cache_value_name isa $CacheSentinel))
237-
# Type of additional data stored with cached result. Used to compare
238-
# equality of `BasicSymbolic` arguments, since `objectid` is a hash.
239-
cache_additional_types = []
240-
cache_additional_values = []
241246

242247
for arg in fn.args
243248
# handle arguments with defaults
@@ -246,24 +251,18 @@ macro cache(args...)
246251
end
247252
if !Meta.isexpr(arg, :(::))
248253
# if the type is `Any`, branch on it being a `BasicSymbolic`
249-
push!(keyexprs, :($arg isa BasicSymbolic ? $SymbolicKey(objectid($arg)) : $arg))
254+
push!(keyexprs, :($get_cache_key($arg)))
250255
push!(argexprs, arg)
251256
push!(keytypes, Any)
252-
push!(cache_additional_types, Any)
253-
push!(cache_additional_values, arg)
254-
cache_hit_condition = :($cache_hit_condition && (!($arg isa BasicSymbolic) || $arg === $cache_value_name[$(length(cache_additional_values))]))
255257
continue
256258
end
257259
argname, Texpr = arg.args
258260
push!(argexprs, argname)
259261

260262
if Texpr == :Any
261263
# if the type is `Any`, branch on it being a `BasicSymbolic`
262-
push!(keyexprs, :($argname isa BasicSymbolic ? $SymbolicKey(objectid($argname)) : $argname))
264+
push!(keyexprs, :($get_cache_key($argname)))
263265
push!(keytypes, Any)
264-
push!(cache_additional_types, Any)
265-
push!(cache_additional_values, argname)
266-
cache_hit_condition = :($cache_hit_condition && (!($argname isa BasicSymbolic) || $argname === $cache_value_name[$(length(cache_additional_values))]))
267266
continue
268267
end
269268

@@ -275,10 +274,7 @@ macro cache(args...)
275274
maybe_basicsymbolic = any(x -> x <: BasicSymbolic, Ts)
276275
push!(keytypes, Union{keyTs...})
277276
if maybe_basicsymbolic
278-
push!(keyexprs, :($argname isa BasicSymbolic ? $SymbolicKey(objectid($argname)) : $argname))
279-
push!(cache_additional_types, Texpr)
280-
push!(cache_additional_values, argname)
281-
cache_hit_condition = :($cache_hit_condition && (!($argname isa BasicSymbolic) || $argname === $cache_value_name[$(length(cache_additional_values))]))
277+
push!(keyexprs, :($get_cache_key($argname)))
282278
else
283279
push!(keyexprs, argname)
284280
end
@@ -289,10 +285,7 @@ macro cache(args...)
289285
T = Base.eval(__module__, Texpr)
290286
if T <: BasicSymbolic
291287
push!(keytypes, SymbolicKey)
292-
push!(keyexprs, :($SymbolicKey(objectid($argname))))
293-
push!(cache_additional_types, T)
294-
push!(cache_additional_values, argname)
295-
cache_hit_condition = :($cache_hit_condition && $argname === $cache_value_name[$(length(cache_additional_values))])
288+
push!(keyexprs, :($get_cache_key($argname)))
296289
else
297290
push!(keytypes, T)
298291
push!(keyexprs, argname)
@@ -313,9 +306,7 @@ macro cache(args...)
313306
# construct an expression for the type of the cache keys
314307
keyT = Expr(:curly, Tuple)
315308
append!(keyT.args, keytypes)
316-
valT = Expr(:curly, Tuple)
317-
append!(valT.args, cache_additional_types)
318-
push!(valT.args, rettype)
309+
valT = rettype
319310
# the type of the cache
320311
cacheT = :(Dict{$keyT, $valT})
321312
# type of the `TaskLocalValue`
@@ -364,7 +355,7 @@ macro cache(args...)
364355
if $cache_hit_condition
365356
# cache hit
366357
cachestats.hits += 1
367-
return $cache_value_name[end]
358+
return $cache_value_name
368359
end
369360
# cache miss
370361
cachestats.misses += 1
@@ -375,8 +366,8 @@ macro cache(args...)
375366
$(filter!)($cachename, cachedict)
376367
end
377368
# add to cache
378-
cachedict[key] = ($(cache_additional_values...), val)
379-
return val
369+
cachedict[key] = $cache_value_name
370+
return $cache_value_name
380371
end
381372

382373
# if we're not doing caching

src/types.jl

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ const ENABLE_HASHCONSING = Ref(true)
2626
@compactify show_methods=false begin
2727
@abstract mutable struct BasicSymbolic{T} <: Symbolic{T}
2828
metadata::Metadata = NO_METADATA
29+
id::RefValue{UInt64} = Ref{UInt64}(0)
2930
end
3031
mutable struct Sym{T} <: BasicSymbolic{T}
3132
name::Symbol = :OOF
@@ -107,11 +108,11 @@ function ConstructionBase.setproperties(obj::BasicSymbolic{T}, patch::NamedTuple
107108
# Call outer constructor because hash consing cannot be applied in inner constructor
108109
@compactified obj::BasicSymbolic begin
109110
Sym => Sym{T}(nt_new.name; nt_new...)
110-
Term => Term{T}(nt_new.f, nt_new.arguments; nt_new..., hash = RefValue(UInt(0)), hash2 = RefValue(UInt(0)))
111-
Add => Add(T, nt_new.coeff, nt_new.dict; nt_new..., hash = RefValue(UInt(0)), hash2 = RefValue(UInt(0)))
112-
Mul => Mul(T, nt_new.coeff, nt_new.dict; nt_new..., hash = RefValue(UInt(0)), hash2 = RefValue(UInt(0)))
113-
Div => Div{T}(nt_new.num, nt_new.den, nt_new.simplified; nt_new..., hash = RefValue(UInt(0)), hash2 = RefValue(UInt(0)))
114-
Pow => Pow{T}(nt_new.base, nt_new.exp; nt_new..., hash = RefValue(UInt(0)), hash2 = RefValue(UInt(0)))
111+
Term => Term{T}(nt_new.f, nt_new.arguments; nt_new..., hash = RefValue(UInt(0)), hash2 = RefValue(UInt(0)), id = Ref{UInt64}(0))
112+
Add => Add(T, nt_new.coeff, nt_new.dict; nt_new..., hash = RefValue(UInt(0)), hash2 = RefValue(UInt(0)), id = Ref{UInt64}(0))
113+
Mul => Mul(T, nt_new.coeff, nt_new.dict; nt_new..., hash = RefValue(UInt(0)), hash2 = RefValue(UInt(0)), id = Ref{UInt64}(0))
114+
Div => Div{T}(nt_new.num, nt_new.den, nt_new.simplified; nt_new..., hash = RefValue(UInt(0)), hash2 = RefValue(UInt(0)), id = Ref{UInt64}(0))
115+
Pow => Pow{T}(nt_new.base, nt_new.exp; nt_new..., hash = RefValue(UInt(0)), hash2 = RefValue(UInt(0)), id = Ref{UInt64}(0))
115116
_ => Unityper.rt_constructor(obj){T}(;nt_new...)
116117
end
117118
end
@@ -255,6 +256,7 @@ end
255256

256257
function Base.isequal(a::BasicSymbolic{T}, b::BasicSymbolic{S}) where {T,S}
257258
a === b && return true
259+
a.id == b.id && a.id != 0 && return true
258260

259261
E = exprtype(a)
260262
E === exprtype(b) || return false
@@ -298,6 +300,7 @@ function.
298300
"""
299301
function isequal_with_metadata(a::BasicSymbolic{T}, b::BasicSymbolic{S})::Bool where {T, S}
300302
a === b && return true
303+
a.id == b.id && a.id != 0 && return true
301304

302305
E = exprtype(a)
303306
E === exprtype(b) || return false
@@ -513,6 +516,12 @@ end
513516
### Constructors
514517
###
515518

519+
mutable struct AtomicIDCounter
520+
@atomic x::UInt64
521+
end
522+
523+
const ID_COUNTER = AtomicIDCounter(0)
524+
516525
"""
517526
$(TYPEDSIGNATURES)
518527
@@ -542,21 +551,27 @@ function BasicSymbolic(s::BasicSymbolic)::BasicSymbolic
542551
h = hash2(s)
543552
k = get!(cache, h, s)
544553
if isequal_with_metadata(k, s)
554+
if iszero(k.id[])
555+
k.id[] = @atomic ID_COUNTER.x += 1
556+
end
545557
return k
546558
else
559+
if iszero(s.id[])
560+
s.id[] = @atomic ID_COUNTER.x += 1
561+
end
547562
return s
548563
end
549564
end
550565

551566
function Sym{T}(name::Symbol; kw...) where {T}
552-
s = Sym{T}(; name, kw...)
567+
s = Sym{T}(; name, kw..., id = Ref{UInt}(0))
553568
BasicSymbolic(s)
554569
end
555570

556571
function Term{T}(f, args; kw...) where T
557572
args = SmallV{Any}(args)
558573

559-
s = Term{T}(;f=f, arguments=args, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), kw...)
574+
s = Term{T}(;f=f, arguments=args, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), kw..., id = Ref{UInt64}(0))
560575
BasicSymbolic(s)
561576
end
562577

@@ -586,7 +601,7 @@ function Add(::Type{T}, coeff, dict; metadata=NO_METADATA, kw...) where T
586601
end
587602
end
588603

589-
s = Add{T}(; coeff, dict, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), metadata, arguments=SmallV{Any}(), kw...)
604+
s = Add{T}(; coeff, dict, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), metadata, arguments=SmallV{Any}(), kw..., id = Ref{UInt64}(0))
590605
BasicSymbolic(s)
591606
end
592607

@@ -604,7 +619,7 @@ function Mul(T, a, b; metadata=NO_METADATA, kw...)
604619
else
605620
coeff = a
606621
dict = b
607-
s = Mul{T}(; coeff, dict, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), metadata, arguments=SmallV{Any}(), kw...)
622+
s = Mul{T}(; coeff, dict, hash=Ref(UInt(0)), hash2=Ref(UInt(0)), metadata, arguments=SmallV{Any}(), kw..., id = Ref{UInt64}(0))
608623
BasicSymbolic(s)
609624
end
610625
end
@@ -672,7 +687,7 @@ function Div{T}(n, d, simplified=false; metadata=nothing, kwargs...) where {T}
672687
end
673688
end
674689

675-
s = Div{T}(; num=n, den=d, simplified, arguments=SmallV{Any}(), metadata)
690+
s = Div{T}(; num=n, den=d, simplified, arguments=SmallV{Any}(), metadata, id = Ref{UInt64}(0))
676691
BasicSymbolic(s)
677692
end
678693

@@ -692,7 +707,7 @@ function Pow{T}(a, b; metadata=NO_METADATA, kwargs...) where {T}
692707
b = unwrap(b)
693708
_iszero(b) && return 1
694709
_isone(b) && return a
695-
s = Pow{T}(; base=a, exp=b, arguments=SmallV{Any}(), metadata)
710+
s = Pow{T}(; base=a, exp=b, arguments=SmallV{Any}(), metadata, id = Ref{UInt64}(0))
696711
BasicSymbolic(s)
697712
end
698713

test/cache_macro.jl

Lines changed: 9 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using SymbolicUtils
22
using SymbolicUtils: BasicSymbolic, @cache, associated_cache, set_limit!, get_limit,
3-
clear_cache!, SymbolicKey, metadata, maketerm
3+
clear_cache!, SymbolicKey, metadata, maketerm, get_cache_key
44
using OhMyThreads: tmap
55
using Random
66

@@ -14,9 +14,9 @@ end
1414
@test isequal(val, 2x + 1)
1515
cachestruct = associated_cache(f1)
1616
cache, stats = cachestruct.tlv[]
17-
@test cache isa Dict{Tuple{SymbolicKey}, Tuple{BasicSymbolic, BasicSymbolic}}
17+
@test cache isa Dict{Tuple{SymbolicKey}, BasicSymbolic}
1818
@test length(cache) == 1
19-
@test cache[(SymbolicKey(objectid(x)),)][end] === val
19+
@test cache[(get_cache_key(x),)] === val
2020
@test stats.hits == 0
2121
@test stats.misses == 1
2222
f1(x)
@@ -76,20 +76,20 @@ end
7676
@test isequal(val, 2x + 1)
7777
cachestruct = associated_cache(f2)
7878
cache, stats = cachestruct.tlv[]
79-
@test cache isa Dict{Tuple{Union{SymbolicKey, UInt}}, NTuple{2, Union{BasicSymbolic, UInt}}}
79+
@test cache isa Dict{Tuple{Union{SymbolicKey, UInt}}, Union{BasicSymbolic, UInt}}
8080
@test length(cache) == 1
81-
@test cache[(SymbolicKey(objectid(x)),)][end] === val
81+
@test cache[(get_cache_key(x),)] === val
8282
@test stats.hits == 0
8383
@test stats.misses == 1
8484
f2(x)
8585
@test stats.hits == 1
8686
@test stats.misses == 1
8787

88-
y = objectid(x)
88+
y = get_cache_key(x).id
8989
val = f2(y)
9090
@test val == 2y + 1
9191
@test length(cache) == 2
92-
@test cache[(y,)][end] == val
92+
@test cache[(y,)] == val
9393
@test stats.misses == 2
9494

9595
clear_cache!(f2)
@@ -111,9 +111,9 @@ end
111111
@test isequal(val, 2x + 1)
112112
cachestruct = associated_cache(fn)
113113
cache, stats = cachestruct.tlv[]
114-
@test cache isa Dict{Tuple{Any}, Tuple{Any, Union{BasicSymbolic, Int}}}
114+
@test cache isa Dict{Tuple{Any}, Union{BasicSymbolic, Int}}
115115
@test length(cache) == 1
116-
@test cache[(SymbolicKey(objectid(x)),)][end] === val
116+
@test cache[(get_cache_key(x),)] === val
117117
@test stats.hits == 0
118118
@test stats.misses == 1
119119
fn(x)
@@ -161,42 +161,6 @@ end
161161
@test isequal(result, truevals)
162162
end
163163

164-
@cache function f5(x::BasicSymbolic, y::Union{BasicSymbolic, Int}, z)::BasicSymbolic
165-
return x + y + z
166-
end
167-
168-
# temporary definition to induce objectid collisions
169-
Base.objectid(x::BasicSymbolic) = 0x42
170-
171-
@testset "`objectid` collision handling" begin
172-
@syms x y z
173-
@test objectid(x) == objectid(y) == objectid(z) == 0x42
174-
cachestruct = associated_cache(f5)
175-
cache, stats = cachestruct.tlv[]
176-
val = f5(x, 1, 2)
177-
@test isequal(val, x + 3)
178-
@test length(cache) == 1
179-
@test stats.misses == 1
180-
val2 = f5(y, 1, 2)
181-
@test isequal(val2, y + 3)
182-
@test length(cache) == 1
183-
@test stats.misses == 2
184-
185-
clear_cache!(f5)
186-
val = f5(x, y, z)
187-
@test isequal(val, x + y + z)
188-
@test length(cache) == 1
189-
@test stats.misses == 1
190-
val2 = f5(y, 2z, x)
191-
@test isequal(val2, x + y + 2z)
192-
@test length(cache) == 1
193-
@test stats.misses == 2
194-
end
195-
196-
Base.delete_method(only(methods(objectid, @__MODULE__)))
197-
@syms x
198-
@test objectid(x) != 0x42
199-
200164
@cache limit = 10 retain_fraction = 0.1 function f6(x::BasicSymbolic, y::Union{BasicSymbolic, Int}, z)::BasicSymbolic
201165
return x + y + z
202166
end

0 commit comments

Comments
 (0)