Skip to content

Commit 9020755

Browse files
fix: compare metadata of entire expression tree in hashconsing
1 parent 0e8ecbc commit 9020755

File tree

2 files changed

+58
-12
lines changed

2 files changed

+58
-12
lines changed

src/types.jl

Lines changed: 49 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -239,12 +239,12 @@ Base.isequal(x, ::Symbolic) = false
239239
Base.isequal(::Symbolic, ::Missing) = false
240240
Base.isequal(::Missing, ::Symbolic) = false
241241
Base.isequal(::Symbolic, ::Symbolic) = false
242-
coeff_isequal(a, b) = isequal(a, b) || ((a isa AbstractFloat || b isa AbstractFloat) && (a==b))
243-
function _allarequal(xs, ys)::Bool
242+
coeff_isequal(a, b; comparator = isequal) = comparator(a, b) || ((a isa AbstractFloat || b isa AbstractFloat) && (a==b))
243+
function _allarequal(xs, ys; comparator = isequal)::Bool
244244
N = length(xs)
245245
length(ys) == N || return false
246246
for n = 1:N
247-
isequal(xs[n], ys[n]) || return false
247+
comparator(xs[n], ys[n]) || return false
248248
end
249249
return true
250250
end
@@ -258,19 +258,19 @@ function Base.isequal(a::BasicSymbolic{T}, b::BasicSymbolic{S}) where {T,S}
258258
T === S || return false
259259
return _isequal(a, b, E)::Bool
260260
end
261-
function _isequal(a, b, E)
261+
function _isequal(a, b, E; comparator = isequal)
262262
if E === SYM
263263
nameof(a) === nameof(b)
264264
elseif E === ADD || E === MUL
265-
coeff_isequal(a.coeff, b.coeff) && isequal(a.dict, b.dict)
265+
coeff_isequal(a.coeff, b.coeff; comparator) && comparator(a.dict, b.dict)
266266
elseif E === DIV
267-
isequal(a.num, b.num) && isequal(a.den, b.den)
267+
comparator(a.num, b.num) && comparator(a.den, b.den)
268268
elseif E === POW
269-
isequal(a.exp, b.exp) && isequal(a.base, b.base)
269+
comparator(a.exp, b.exp) && comparator(a.base, b.base)
270270
elseif E === TERM
271271
a1 = arguments(a)
272272
a2 = arguments(b)
273-
isequal(operation(a), operation(b)) && _allarequal(a1, a2)
273+
comparator(operation(a), operation(b)) && _allarequal(a1, a2; comparator)
274274
else
275275
error_on_type()
276276
end
@@ -292,8 +292,14 @@ Modifying `Base.isequal` directly breaks numerous tests in `SymbolicUtils.jl` an
292292
downstream packages like `ModelingToolkit.jl`, hence the need for this separate
293293
function.
294294
"""
295-
function isequal_with_metadata(a::BasicSymbolic, b::BasicSymbolic)::Bool
296-
isequal(a, b) && isequal_with_metadata(metadata(a), metadata(b))
295+
function isequal_with_metadata(a::BasicSymbolic{T}, b::BasicSymbolic{S})::Bool where {T, S}
296+
a === b && return true
297+
298+
E = exprtype(a)
299+
E === exprtype(b) || return false
300+
301+
T === S || return false
302+
_isequal(a, b, E; comparator = isequal_with_metadata)::Bool && isequal_with_metadata(metadata(a), metadata(b)) || return false
297303
end
298304

299305
"""
@@ -303,9 +309,9 @@ Compare the metadata of two `BasicSymbolic`s to ensure it is equal, recursively
303309
`isequal_with_metadata` to ensure symbolic variables in the metadata also have equal
304310
metadata.
305311
"""
306-
function isequal_with_metadata(a::Union{AbstractDict, NamedTuple}, b::Union{AbstractDict, NamedTuple})
312+
function isequal_with_metadata(a::NamedTuple, b::NamedTuple)
313+
a === b && return true
307314
typeof(a) == typeof(b) || return false
308-
length(a) == length(b) || return false
309315

310316
for (k, v) in pairs(a)
311317
haskey(b, k) || return false
@@ -320,6 +326,36 @@ function isequal_with_metadata(a::Union{AbstractDict, NamedTuple}, b::Union{Abst
320326
return true
321327
end
322328

329+
function isequal_with_metadata(a::AbstractDict, b::AbstractDict)
330+
a === b && return true
331+
typeof(a) == typeof(b) || return false
332+
length(a) == length(b) || return false
333+
334+
akeys = collect(keys(a))
335+
avisited = falses(length(akeys))
336+
bkeys = collect(keys(b))
337+
bvisited = falses(length(bkeys))
338+
339+
for k in akeys
340+
idx = findfirst(eachindex(bkeys)) do i
341+
!bvisited[i] && isequal_with_metadata(k, bkeys[i])
342+
end
343+
idx === nothing && return false
344+
bvisited[idx] = true
345+
isequal_with_metadata(a[k], b[bkeys[idx]]) || return false
346+
end
347+
for (j, k) in enumerate(bkeys)
348+
bvisited[j] && continue
349+
idx = findfirst(eachindex(akeys)) do i
350+
!avisited[i] && isequal_with_metadata(k, akeys[i])
351+
end
352+
idx === nothing && return false
353+
avisited[idx] = true
354+
isequal_with_metadata(b[k], a[akeys[idx]]) || return false
355+
end
356+
return true
357+
end
358+
323359
"""
324360
$(TYPEDSIGNATURES)
325361
@@ -341,6 +377,7 @@ Check if two arrays/tuples are equal by calling `isequal_with_metadata` on each
341377
This is to ensure true equality of any symbolic elements, if present.
342378
"""
343379
function isequal_with_metadata(a::Union{AbstractArray, Tuple}, b::Union{AbstractArray, Tuple})
380+
a === b && return true
344381
typeof(a) == typeof(b) || return false
345382
if a isa AbstractArray
346383
size(a) == size(b) || return false

test/hash_consing.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,3 +119,12 @@ end
119119
@test metadata(metadata(a1)[Int]) === nothing
120120
@test metadata(metadata(a2)[Int])[Int] == 3
121121
end
122+
123+
@testset "Compare metadata of expression tree" begin
124+
@syms a b
125+
aa = setmetadata(a, Int, b)
126+
@test aa !== a
127+
@test isequal(a, aa)
128+
@test !SymbolicUtils.isequal_with_metadata(a, aa)
129+
@test !SymbolicUtils.isequal_with_metadata(2a, 2aa)
130+
end

0 commit comments

Comments
 (0)