@@ -239,12 +239,12 @@ Base.isequal(x, ::Symbolic) = false
239
239
Base. isequal (:: Symbolic , :: Missing ) = false
240
240
Base. isequal (:: Missing , :: Symbolic ) = false
241
241
Base. isequal (:: Symbolic , :: Symbolic ) = false
242
- coeff_isequal (a, b) = isequal (a, b) || ((a isa AbstractFloat || b isa AbstractFloat) && (a== b))
243
- function _allarequal (xs, ys):: Bool
242
+ coeff_isequal (a, b; comparator = isequal) = comparator (a, b) || ((a isa AbstractFloat || b isa AbstractFloat) && (a== b))
243
+ function _allarequal (xs, ys; comparator = isequal ):: Bool
244
244
N = length (xs)
245
245
length (ys) == N || return false
246
246
for n = 1 : N
247
- isequal (xs[n], ys[n]) || return false
247
+ comparator (xs[n], ys[n]) || return false
248
248
end
249
249
return true
250
250
end
@@ -258,19 +258,19 @@ function Base.isequal(a::BasicSymbolic{T}, b::BasicSymbolic{S}) where {T,S}
258
258
T === S || return false
259
259
return _isequal (a, b, E):: Bool
260
260
end
261
- function _isequal (a, b, E)
261
+ function _isequal (a, b, E; comparator = isequal )
262
262
if E === SYM
263
263
nameof (a) === nameof (b)
264
264
elseif E === ADD || E === MUL
265
- coeff_isequal (a. coeff, b. coeff) && isequal (a. dict, b. dict)
265
+ coeff_isequal (a. coeff, b. coeff; comparator ) && comparator (a. dict, b. dict)
266
266
elseif E === DIV
267
- isequal (a. num, b. num) && isequal (a. den, b. den)
267
+ comparator (a. num, b. num) && comparator (a. den, b. den)
268
268
elseif E === POW
269
- isequal (a. exp, b. exp) && isequal (a. base, b. base)
269
+ comparator (a. exp, b. exp) && comparator (a. base, b. base)
270
270
elseif E === TERM
271
271
a1 = arguments (a)
272
272
a2 = arguments (b)
273
- isequal (operation (a), operation (b)) && _allarequal (a1, a2)
273
+ comparator (operation (a), operation (b)) && _allarequal (a1, a2; comparator )
274
274
else
275
275
error_on_type ()
276
276
end
@@ -292,8 +292,14 @@ Modifying `Base.isequal` directly breaks numerous tests in `SymbolicUtils.jl` an
292
292
downstream packages like `ModelingToolkit.jl`, hence the need for this separate
293
293
function.
294
294
"""
295
- function isequal_with_metadata (a:: BasicSymbolic , b:: BasicSymbolic ):: Bool
296
- isequal (a, b) && isequal_with_metadata (metadata (a), metadata (b))
295
+ function isequal_with_metadata (a:: BasicSymbolic{T} , b:: BasicSymbolic{S} ):: Bool where {T, S}
296
+ a === b && return true
297
+
298
+ E = exprtype (a)
299
+ E === exprtype (b) || return false
300
+
301
+ T === S || return false
302
+ _isequal (a, b, E; comparator = isequal_with_metadata):: Bool && isequal_with_metadata (metadata (a), metadata (b)) || return false
297
303
end
298
304
299
305
"""
@@ -303,9 +309,9 @@ Compare the metadata of two `BasicSymbolic`s to ensure it is equal, recursively
303
309
`isequal_with_metadata` to ensure symbolic variables in the metadata also have equal
304
310
metadata.
305
311
"""
306
- function isequal_with_metadata (a:: Union{AbstractDict, NamedTuple} , b:: Union{AbstractDict, NamedTuple} )
312
+ function isequal_with_metadata (a:: NamedTuple , b:: NamedTuple )
313
+ a === b && return true
307
314
typeof (a) == typeof (b) || return false
308
- length (a) == length (b) || return false
309
315
310
316
for (k, v) in pairs (a)
311
317
haskey (b, k) || return false
@@ -320,6 +326,36 @@ function isequal_with_metadata(a::Union{AbstractDict, NamedTuple}, b::Union{Abst
320
326
return true
321
327
end
322
328
329
+ function isequal_with_metadata (a:: AbstractDict , b:: AbstractDict )
330
+ a === b && return true
331
+ typeof (a) == typeof (b) || return false
332
+ length (a) == length (b) || return false
333
+
334
+ akeys = collect (keys (a))
335
+ avisited = falses (length (akeys))
336
+ bkeys = collect (keys (b))
337
+ bvisited = falses (length (bkeys))
338
+
339
+ for k in akeys
340
+ idx = findfirst (eachindex (bkeys)) do i
341
+ ! bvisited[i] && isequal_with_metadata (k, bkeys[i])
342
+ end
343
+ idx === nothing && return false
344
+ bvisited[idx] = true
345
+ isequal_with_metadata (a[k], b[bkeys[idx]]) || return false
346
+ end
347
+ for (j, k) in enumerate (bkeys)
348
+ bvisited[j] && continue
349
+ idx = findfirst (eachindex (akeys)) do i
350
+ ! avisited[i] && isequal_with_metadata (k, akeys[i])
351
+ end
352
+ idx === nothing && return false
353
+ avisited[idx] = true
354
+ isequal_with_metadata (b[k], a[akeys[idx]]) || return false
355
+ end
356
+ return true
357
+ end
358
+
323
359
"""
324
360
$(TYPEDSIGNATURES)
325
361
@@ -341,6 +377,7 @@ Check if two arrays/tuples are equal by calling `isequal_with_metadata` on each
341
377
This is to ensure true equality of any symbolic elements, if present.
342
378
"""
343
379
function isequal_with_metadata (a:: Union{AbstractArray, Tuple} , b:: Union{AbstractArray, Tuple} )
380
+ a === b && return true
344
381
typeof (a) == typeof (b) || return false
345
382
if a isa AbstractArray
346
383
size (a) == size (b) || return false
0 commit comments