Skip to content

Commit a587847

Browse files
Merge pull request #658 from JuliaSymbolics/hash-consing
Implement hash consing for `Sym`
2 parents f9b0ade + 13b642b commit a587847

File tree

6 files changed

+123
-11
lines changed

6 files changed

+123
-11
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
2626
TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c"
2727
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
2828
Unityper = "a7c27f48-0311-42f6-a7f8-2c11e75eb415"
29+
WeakValueDicts = "897b6980-f191-5a31-bcb0-bf3c4585e0c1"
2930

3031
[weakdeps]
3132
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
@@ -57,6 +58,7 @@ SymbolicIndexingInterface = "0.3"
5758
TermInterface = "2.0"
5859
TimerOutputs = "0.5"
5960
Unityper = "0.1.2"
61+
WeakValueDicts = "0.1.0"
6062
julia = "1.3"
6163

6264
[extras]

src/SymbolicUtils.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import TermInterface: iscall, isexpr, head, children,
2020
operation, arguments, metadata, maketerm, sorted_arguments
2121
# For ReverseDiffExt
2222
import ArrayInterface
23+
using WeakValueDicts: WeakValueDict
2324

2425
Base.@deprecate istree iscall
2526
export istree, operation, arguments, sorted_arguments, iscall

src/types.jl

Lines changed: 85 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,38 +23,38 @@ const EMPTY_DICT = sdict()
2323
const EMPTY_DICT_T = typeof(EMPTY_DICT)
2424

2525
@compactify show_methods=false begin
26-
@abstract struct BasicSymbolic{T} <: Symbolic{T}
26+
@abstract mutable struct BasicSymbolic{T} <: Symbolic{T}
2727
metadata::Metadata = NO_METADATA
2828
end
29-
struct Sym{T} <: BasicSymbolic{T}
29+
mutable struct Sym{T} <: BasicSymbolic{T}
3030
name::Symbol = :OOF
3131
end
32-
struct Term{T} <: BasicSymbolic{T}
32+
mutable struct Term{T} <: BasicSymbolic{T}
3333
f::Any = identity # base/num if Pow; issorted if Add/Dict
3434
arguments::Vector{Any} = EMPTY_ARGS
3535
hash::RefValue{UInt} = EMPTY_HASH
3636
end
37-
struct Mul{T} <: BasicSymbolic{T}
37+
mutable struct Mul{T} <: BasicSymbolic{T}
3838
coeff::Any = 0 # exp/den if Pow
3939
dict::EMPTY_DICT_T = EMPTY_DICT
4040
hash::RefValue{UInt} = EMPTY_HASH
4141
arguments::Vector{Any} = EMPTY_ARGS
4242
issorted::RefValue{Bool} = NOT_SORTED
4343
end
44-
struct Add{T} <: BasicSymbolic{T}
44+
mutable struct Add{T} <: BasicSymbolic{T}
4545
coeff::Any = 0 # exp/den if Pow
4646
dict::EMPTY_DICT_T = EMPTY_DICT
4747
hash::RefValue{UInt} = EMPTY_HASH
4848
arguments::Vector{Any} = EMPTY_ARGS
4949
issorted::RefValue{Bool} = NOT_SORTED
5050
end
51-
struct Div{T} <: BasicSymbolic{T}
51+
mutable struct Div{T} <: BasicSymbolic{T}
5252
num::Any = 1
5353
den::Any = 1
5454
simplified::Bool = false
5555
arguments::Vector{Any} = EMPTY_ARGS
5656
end
57-
struct Pow{T} <: BasicSymbolic{T}
57+
mutable struct Pow{T} <: BasicSymbolic{T}
5858
base::Any = 1
5959
exp::Any = 1
6060
arguments::Vector{Any} = EMPTY_ARGS
@@ -77,6 +77,8 @@ function exprtype(x::BasicSymbolic)
7777
end
7878
end
7979

80+
const wvd = WeakValueDict{UInt, BasicSymbolic}()
81+
8082
# Same but different error messages
8183
@noinline error_on_type() = error("Internal error: unreachable reached!")
8284
@noinline error_sym() = error("Sym doesn't have a operation or arguments!")
@@ -92,7 +94,11 @@ const SIMPLIFIED = 0x01 << 0
9294
function ConstructionBase.setproperties(obj::BasicSymbolic{T}, patch::NamedTuple)::BasicSymbolic{T} where T
9395
nt = getproperties(obj)
9496
nt_new = merge(nt, patch)
95-
Unityper.rt_constructor(obj){T}(;nt_new...)
97+
# Call outer constructor because hash consing cannot be applied in inner constructor
98+
@compactified obj::BasicSymbolic begin
99+
Sym => Sym{T}(nt_new.name; nt_new...)
100+
_ => Unityper.rt_constructor(obj){T}(;nt_new...)
101+
end
96102
end
97103

98104
###
@@ -265,6 +271,26 @@ function _isequal(a, b, E)
265271
end
266272
end
267273

274+
"""
275+
$(TYPEDSIGNATURES)
276+
277+
Checks for equality between two `BasicSymbolic` objects, considering both their
278+
values and metadata.
279+
280+
The default `Base.isequal` function for `BasicSymbolic` only compares their expressions
281+
and ignores metadata. This does not help deal with hash collisions when metadata is
282+
relevant for distinguishing expressions, particularly in hashing contexts. This function
283+
provides a stricter equality check that includes metadata comparison, preventing
284+
such collisions.
285+
286+
Modifying `Base.isequal` directly breaks numerous tests in `SymbolicUtils.jl` and
287+
downstream packages like `ModelingToolkit.jl`, hence the need for this separate
288+
function.
289+
"""
290+
function isequal_with_metadata(a::BasicSymbolic, b::BasicSymbolic)::Bool
291+
isequal(a, b) && isequal(metadata(a), metadata(b))
292+
end
293+
268294
Base.one( s::Symbolic) = one( symtype(s))
269295
Base.zero(s::Symbolic) = zero(symtype(s))
270296

@@ -307,12 +333,61 @@ function Base.hash(s::BasicSymbolic, salt::UInt)::UInt
307333
end
308334
end
309335

336+
"""
337+
$(TYPEDSIGNATURES)
338+
339+
Calculates a hash value for a `BasicSymbolic` object, incorporating both its metadata and
340+
symtype.
341+
342+
This function provides an alternative hashing strategy to `Base.hash` for `BasicSymbolic`
343+
objects. Unlike `Base.hash`, which only considers the expression structure, `hash2` also
344+
includes the metadata and symtype in the hash calculation. This can be beneficial for hash
345+
consing, allowing for more effective deduplication of symbolically equivalent expressions
346+
with different metadata or symtypes.
347+
"""
348+
hash2(s::BasicSymbolic) = hash2(s, zero(UInt))
349+
function hash2(s::BasicSymbolic{T}, salt::UInt)::UInt where {T}
350+
hash(metadata(s), hash(T, hash(s, salt)))
351+
end
352+
310353
###
311354
### Constructors
312355
###
313356

314-
function Sym{T}(name::Symbol; kw...) where T
315-
Sym{T}(; name=name, kw...)
357+
"""
358+
$(TYPEDSIGNATURES)
359+
360+
Implements hash consing (flyweight design pattern) for `BasicSymbolic` objects.
361+
362+
This function checks if an equivalent `BasicSymbolic` object already exists. It uses a
363+
custom hash function (`hash2`) incorporating metadata and symtypes to search for existing
364+
objects in a `WeakValueDict` (`wvd`). Due to the possibility of hash collisions (where
365+
different objects produce the same hash), a custom equality check (`isequal_with_metadata`)
366+
which includes metadata comparison, is used to confirm the equivalence of objects with
367+
matching hashes. If an equivalent object is found, the existing object is returned;
368+
otherwise, the input `s` is returned. This reduces memory usage, improves compilation time
369+
for runtime code generation, and supports built-in common subexpression elimination,
370+
particularly when working with symbolic objects with metadata.
371+
372+
Using a `WeakValueDict` ensures that only weak references to `BasicSymbolic` objects are
373+
stored, allowing objects that are no longer strongly referenced to be garbage collected.
374+
Custom functions `hash2` and `isequal_with_metadata` are used instead of `Base.hash` and
375+
`Base.isequal` to accommodate metadata without disrupting existing tests reliant on the
376+
original behavior of those functions.
377+
"""
378+
function BasicSymbolic(s::BasicSymbolic)::BasicSymbolic
379+
h = hash2(s)
380+
t = get!(wvd, h, s)
381+
if t === s || isequal_with_metadata(t, s)
382+
return t
383+
else
384+
return s
385+
end
386+
end
387+
388+
function Sym{T}(name::Symbol; kw...) where {T}
389+
s = Sym{T}(; name, kw...)
390+
BasicSymbolic(s)
316391
end
317392

318393
function Term{T}(f, args; kw...) where T

test/basics.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using SymbolicUtils: Symbolic, Sym, FnType, Term, Add, Mul, Pow, symtype, operation, arguments, issym, isterm, BasicSymbolic, term
1+
using SymbolicUtils: Symbolic, Sym, FnType, Term, Add, Mul, Pow, symtype, operation, arguments, issym, isterm, BasicSymbolic, term, isequal_with_metadata
22
using SymbolicUtils
33
using IfElse: ifelse
44
using Setfield
@@ -336,6 +336,13 @@ end
336336

337337
@test !isequal(a, missing)
338338
@test !isequal(missing, b)
339+
340+
a1 = setmetadata(a, Ctx1, "meta_1")
341+
a2 = setmetadata(a, Ctx1, "meta_1")
342+
a3 = setmetadata(a, Ctx2, "meta_2")
343+
@test !isequal_with_metadata(a, a1)
344+
@test isequal_with_metadata(a1, a2)
345+
@test !isequal_with_metadata(a1, a3)
339346
end
340347

341348
@testset "subtyping" begin

test/hash_consing.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
using SymbolicUtils, Test
2+
3+
struct Ctx1 end
4+
struct Ctx2 end
5+
6+
@testset "Sym" begin
7+
x1 = only(@syms x)
8+
x2 = only(@syms x)
9+
@test x1 === x2
10+
x3 = only(@syms x::Float64)
11+
@test x1 !== x3
12+
x4 = only(@syms x::Float64)
13+
@test x1 !== x4
14+
@test x3 === x4
15+
x5 = only(@syms x::Int)
16+
x6 = only(@syms x::Int)
17+
@test x1 !== x5
18+
@test x3 !== x5
19+
@test x5 === x6
20+
21+
xm1 = setmetadata(x1, Ctx1, "meta_1")
22+
xm2 = setmetadata(x1, Ctx1, "meta_1")
23+
@test xm1 === xm2
24+
xm3 = setmetadata(x1, Ctx2, "meta_2")
25+
@test xm1 !== xm3
26+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,6 @@ using Pkg, Test, SafeTestsets
1616
# Disabled until https://github.com/JuliaMath/SpecialFunctions.jl/issues/446 is fixed
1717
@safetestset "Fuzz" begin include("fuzz.jl") end
1818
@safetestset "Adjoints" begin include("adjoints.jl") end
19+
@safetestset "Hash Consing" begin include("hash_consing.jl") end
1920
end
2021
end

0 commit comments

Comments
 (0)