@@ -23,38 +23,38 @@ const EMPTY_DICT = sdict()
23
23
const EMPTY_DICT_T = typeof (EMPTY_DICT)
24
24
25
25
@compactify show_methods= false begin
26
- @abstract struct BasicSymbolic{T} <: Symbolic{T}
26
+ @abstract mutable struct BasicSymbolic{T} <: Symbolic{T}
27
27
metadata:: Metadata = NO_METADATA
28
28
end
29
- struct Sym{T} <: BasicSymbolic{T}
29
+ mutable struct Sym{T} <: BasicSymbolic{T}
30
30
name:: Symbol = :OOF
31
31
end
32
- struct Term{T} <: BasicSymbolic{T}
32
+ mutable struct Term{T} <: BasicSymbolic{T}
33
33
f:: Any = identity # base/num if Pow; issorted if Add/Dict
34
34
arguments:: Vector{Any} = EMPTY_ARGS
35
35
hash:: RefValue{UInt} = EMPTY_HASH
36
36
end
37
- struct Mul{T} <: BasicSymbolic{T}
37
+ mutable struct Mul{T} <: BasicSymbolic{T}
38
38
coeff:: Any = 0 # exp/den if Pow
39
39
dict:: EMPTY_DICT_T = EMPTY_DICT
40
40
hash:: RefValue{UInt} = EMPTY_HASH
41
41
arguments:: Vector{Any} = EMPTY_ARGS
42
42
issorted:: RefValue{Bool} = NOT_SORTED
43
43
end
44
- struct Add{T} <: BasicSymbolic{T}
44
+ mutable struct Add{T} <: BasicSymbolic{T}
45
45
coeff:: Any = 0 # exp/den if Pow
46
46
dict:: EMPTY_DICT_T = EMPTY_DICT
47
47
hash:: RefValue{UInt} = EMPTY_HASH
48
48
arguments:: Vector{Any} = EMPTY_ARGS
49
49
issorted:: RefValue{Bool} = NOT_SORTED
50
50
end
51
- struct Div{T} <: BasicSymbolic{T}
51
+ mutable struct Div{T} <: BasicSymbolic{T}
52
52
num:: Any = 1
53
53
den:: Any = 1
54
54
simplified:: Bool = false
55
55
arguments:: Vector{Any} = EMPTY_ARGS
56
56
end
57
- struct Pow{T} <: BasicSymbolic{T}
57
+ mutable struct Pow{T} <: BasicSymbolic{T}
58
58
base:: Any = 1
59
59
exp:: Any = 1
60
60
arguments:: Vector{Any} = EMPTY_ARGS
@@ -77,6 +77,8 @@ function exprtype(x::BasicSymbolic)
77
77
end
78
78
end
79
79
80
+ const wvd = WeakValueDict {UInt, BasicSymbolic} ()
81
+
80
82
# Same but different error messages
81
83
@noinline error_on_type () = error (" Internal error: unreachable reached!" )
82
84
@noinline error_sym () = error (" Sym doesn't have a operation or arguments!" )
@@ -92,7 +94,11 @@ const SIMPLIFIED = 0x01 << 0
92
94
function ConstructionBase. setproperties (obj:: BasicSymbolic{T} , patch:: NamedTuple ):: BasicSymbolic{T} where T
93
95
nt = getproperties (obj)
94
96
nt_new = merge (nt, patch)
95
- Unityper. rt_constructor (obj){T}(;nt_new... )
97
+ # Call outer constructor because hash consing cannot be applied in inner constructor
98
+ @compactified obj:: BasicSymbolic begin
99
+ Sym => Sym {T} (nt_new. name; nt_new... )
100
+ _ => Unityper. rt_constructor (obj){T}(;nt_new... )
101
+ end
96
102
end
97
103
98
104
# ##
@@ -265,6 +271,26 @@ function _isequal(a, b, E)
265
271
end
266
272
end
267
273
274
+ """
275
+ $(TYPEDSIGNATURES)
276
+
277
+ Checks for equality between two `BasicSymbolic` objects, considering both their
278
+ values and metadata.
279
+
280
+ The default `Base.isequal` function for `BasicSymbolic` only compares their expressions
281
+ and ignores metadata. This does not help deal with hash collisions when metadata is
282
+ relevant for distinguishing expressions, particularly in hashing contexts. This function
283
+ provides a stricter equality check that includes metadata comparison, preventing
284
+ such collisions.
285
+
286
+ Modifying `Base.isequal` directly breaks numerous tests in `SymbolicUtils.jl` and
287
+ downstream packages like `ModelingToolkit.jl`, hence the need for this separate
288
+ function.
289
+ """
290
+ function isequal_with_metadata (a:: BasicSymbolic , b:: BasicSymbolic ):: Bool
291
+ isequal (a, b) && isequal (metadata (a), metadata (b))
292
+ end
293
+
268
294
Base. one ( s:: Symbolic ) = one ( symtype (s))
269
295
Base. zero (s:: Symbolic ) = zero (symtype (s))
270
296
@@ -307,12 +333,61 @@ function Base.hash(s::BasicSymbolic, salt::UInt)::UInt
307
333
end
308
334
end
309
335
336
+ """
337
+ $(TYPEDSIGNATURES)
338
+
339
+ Calculates a hash value for a `BasicSymbolic` object, incorporating both its metadata and
340
+ symtype.
341
+
342
+ This function provides an alternative hashing strategy to `Base.hash` for `BasicSymbolic`
343
+ objects. Unlike `Base.hash`, which only considers the expression structure, `hash2` also
344
+ includes the metadata and symtype in the hash calculation. This can be beneficial for hash
345
+ consing, allowing for more effective deduplication of symbolically equivalent expressions
346
+ with different metadata or symtypes.
347
+ """
348
+ hash2 (s:: BasicSymbolic ) = hash2 (s, zero (UInt))
349
+ function hash2 (s:: BasicSymbolic{T} , salt:: UInt ):: UInt where {T}
350
+ hash (metadata (s), hash (T, hash (s, salt)))
351
+ end
352
+
310
353
# ##
311
354
# ## Constructors
312
355
# ##
313
356
314
- function Sym {T} (name:: Symbol ; kw... ) where T
315
- Sym {T} (; name= name, kw... )
357
+ """
358
+ $(TYPEDSIGNATURES)
359
+
360
+ Implements hash consing (flyweight design pattern) for `BasicSymbolic` objects.
361
+
362
+ This function checks if an equivalent `BasicSymbolic` object already exists. It uses a
363
+ custom hash function (`hash2`) incorporating metadata and symtypes to search for existing
364
+ objects in a `WeakValueDict` (`wvd`). Due to the possibility of hash collisions (where
365
+ different objects produce the same hash), a custom equality check (`isequal_with_metadata`)
366
+ which includes metadata comparison, is used to confirm the equivalence of objects with
367
+ matching hashes. If an equivalent object is found, the existing object is returned;
368
+ otherwise, the input `s` is returned. This reduces memory usage, improves compilation time
369
+ for runtime code generation, and supports built-in common subexpression elimination,
370
+ particularly when working with symbolic objects with metadata.
371
+
372
+ Using a `WeakValueDict` ensures that only weak references to `BasicSymbolic` objects are
373
+ stored, allowing objects that are no longer strongly referenced to be garbage collected.
374
+ Custom functions `hash2` and `isequal_with_metadata` are used instead of `Base.hash` and
375
+ `Base.isequal` to accommodate metadata without disrupting existing tests reliant on the
376
+ original behavior of those functions.
377
+ """
378
+ function BasicSymbolic (s:: BasicSymbolic ):: BasicSymbolic
379
+ h = hash2 (s)
380
+ t = get! (wvd, h, s)
381
+ if t === s || isequal_with_metadata (t, s)
382
+ return t
383
+ else
384
+ return s
385
+ end
386
+ end
387
+
388
+ function Sym {T} (name:: Symbol ; kw... ) where {T}
389
+ s = Sym {T} (; name, kw... )
390
+ BasicSymbolic (s)
316
391
end
317
392
318
393
function Term {T} (f, args; kw... ) where T
0 commit comments