Skip to content

Commit 9d52810

Browse files
feat: revert usage of WeakCacheSet for hashconsing
1 parent fda0d2e commit 9d52810

File tree

3 files changed

+19
-7
lines changed

3 files changed

+19
-7
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ TaskLocalValues = "ed4db957-447d-4319-bfb6-7fa9ae7ecf34"
2727
TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c"
2828
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
2929
Unityper = "a7c27f48-0311-42f6-a7f8-2c11e75eb415"
30+
WeakValueDicts = "897b6980-f191-5a31-bcb0-bf3c4585e0c1"
3031

3132
[weakdeps]
3233
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
@@ -61,6 +62,7 @@ TaskLocalValues = "0.1.2"
6162
TermInterface = "2.0"
6263
TimerOutputs = "0.5"
6364
Unityper = "0.1.2"
65+
WeakValueDicts = "0.1.0"
6466
julia = "1.10"
6567

6668
[extras]

src/SymbolicUtils.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,9 @@ import TermInterface: iscall, isexpr, head, children,
2222
import ArrayInterface
2323
import ExproniconLite as EL
2424
import TaskLocalValues: TaskLocalValue
25+
using WeakValueDicts: WeakValueDict
2526

26-
include("WeakCacheSets.jl")
27+
# include("WeakCacheSets.jl")
2728

2829
include("cache.jl")
2930
Base.@deprecate istree iscall

src/types.jl

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ function exprtype(x::BasicSymbolic)
8888
end
8989
end
9090

91-
const wcs = TaskLocalValue{WeakCacheSet{BasicSymbolic}}(WeakCacheSet{BasicSymbolic})
91+
const wvd = TaskLocalValue{WeakValueDict{UInt, BasicSymbolic}}(WeakValueDict{UInt, BasicSymbolic})
9292

9393
# Same but different error messages
9494
@noinline error_on_type() = error("Internal error: unreachable reached!")
@@ -547,12 +547,21 @@ function BasicSymbolic(s::BasicSymbolic)::BasicSymbolic
547547
if !ENABLE_HASHCONSING[]
548548
return s
549549
end
550-
cache = wcs[]
551-
k = getkey!(cache, s)
552-
if iszero(k.id[])
553-
k.id[] = @atomic ID_COUNTER.x += 1
550+
551+
cache = wvd[]
552+
h = hash2(s)
553+
k = get!(cache, h, s)
554+
if isequal_with_metadata(k, s)
555+
if iszero(k.id[])
556+
k.id[] = @atomic ID_COUNTER.x += 1
557+
end
558+
return k
559+
else
560+
if iszero(s.id[])
561+
s.id[] = @atomic ID_COUNTER.x += 1
562+
end
563+
return s
554564
end
555-
return k
556565
end
557566

558567
function Sym{T}(name::Symbol; kw...) where {T}

0 commit comments

Comments
 (0)