@@ -239,12 +239,12 @@ Base.isequal(x, ::Symbolic) = false
239
239
Base. isequal (:: Symbolic , :: Missing ) = false
240
240
Base. isequal (:: Missing , :: Symbolic ) = false
241
241
Base. isequal (:: Symbolic , :: Symbolic ) = false
242
- coeff_isequal (a, b) = isequal (a, b) || ((a isa AbstractFloat || b isa AbstractFloat) && (a== b))
243
- function _allarequal (xs, ys):: Bool
242
+ coeff_isequal (a, b; comparator = isequal) = comparator (a, b) || ((a isa AbstractFloat || b isa AbstractFloat) && (a== b))
243
+ function _allarequal (xs, ys; comparator = isequal ):: Bool
244
244
N = length (xs)
245
245
length (ys) == N || return false
246
246
for n = 1 : N
247
- isequal (xs[n], ys[n]) || return false
247
+ comparator (xs[n], ys[n]) || return false
248
248
end
249
249
return true
250
250
end
@@ -258,19 +258,19 @@ function Base.isequal(a::BasicSymbolic{T}, b::BasicSymbolic{S}) where {T,S}
258
258
T === S || return false
259
259
return _isequal (a, b, E):: Bool
260
260
end
261
- function _isequal (a, b, E)
261
+ function _isequal (a, b, E; comparator = isequal )
262
262
if E === SYM
263
263
nameof (a) === nameof (b)
264
264
elseif E === ADD || E === MUL
265
- coeff_isequal (a. coeff, b. coeff) && isequal (a. dict, b. dict)
265
+ coeff_isequal (a. coeff, b. coeff; comparator ) && comparator (a. dict, b. dict)
266
266
elseif E === DIV
267
- isequal (a. num, b. num) && isequal (a. den, b. den)
267
+ comparator (a. num, b. num) && comparator (a. den, b. den)
268
268
elseif E === POW
269
- isequal (a. exp, b. exp) && isequal (a. base, b. base)
269
+ comparator (a. exp, b. exp) && comparator (a. base, b. base)
270
270
elseif E === TERM
271
271
a1 = arguments (a)
272
272
a2 = arguments (b)
273
- isequal (operation (a), operation (b)) && _allarequal (a1, a2)
273
+ comparator (operation (a), operation (b)) && _allarequal (a1, a2; comparator )
274
274
else
275
275
error_on_type ()
276
276
end
@@ -292,8 +292,100 @@ Modifying `Base.isequal` directly breaks numerous tests in `SymbolicUtils.jl` an
292
292
downstream packages like `ModelingToolkit.jl`, hence the need for this separate
293
293
function.
294
294
"""
295
- function isequal_with_metadata (a:: BasicSymbolic , b:: BasicSymbolic ):: Bool
296
- isequal (a, b) && isequal (metadata (a), metadata (b))
295
+ function isequal_with_metadata (a:: BasicSymbolic{T} , b:: BasicSymbolic{S} ):: Bool where {T, S}
296
+ a === b && return true
297
+
298
+ E = exprtype (a)
299
+ E === exprtype (b) || return false
300
+
301
+ T === S || return false
302
+ _isequal (a, b, E; comparator = isequal_with_metadata):: Bool && isequal_with_metadata (metadata (a), metadata (b)) || return false
303
+ end
304
+
305
+ """
306
+ $(TYPEDSIGNATURES)
307
+
308
+ Compare the metadata of two `BasicSymbolic`s to ensure it is equal, recursively calling
309
+ `isequal_with_metadata` to ensure symbolic variables in the metadata also have equal
310
+ metadata.
311
+ """
312
+ function isequal_with_metadata (a:: NamedTuple , b:: NamedTuple )
313
+ a === b && return true
314
+ typeof (a) == typeof (b) || return false
315
+
316
+ for (k, v) in pairs (a)
317
+ haskey (b, k) || return false
318
+ isequal_with_metadata (v, b[k]) || return false
319
+ end
320
+
321
+ for (k, v) in pairs (b)
322
+ haskey (a, k) || return false
323
+ isequal_with_metadata (v, a[k]) || return false
324
+ end
325
+
326
+ return true
327
+ end
328
+
329
+ function isequal_with_metadata (a:: AbstractDict , b:: AbstractDict )
330
+ a === b && return true
331
+ typeof (a) == typeof (b) || return false
332
+ length (a) == length (b) || return false
333
+
334
+ akeys = collect (keys (a))
335
+ avisited = falses (length (akeys))
336
+ bkeys = collect (keys (b))
337
+ bvisited = falses (length (bkeys))
338
+
339
+ for k in akeys
340
+ idx = findfirst (eachindex (bkeys)) do i
341
+ ! bvisited[i] && isequal_with_metadata (k, bkeys[i])
342
+ end
343
+ idx === nothing && return false
344
+ bvisited[idx] = true
345
+ isequal_with_metadata (a[k], b[bkeys[idx]]) || return false
346
+ end
347
+ for (j, k) in enumerate (bkeys)
348
+ bvisited[j] && continue
349
+ idx = findfirst (eachindex (akeys)) do i
350
+ ! avisited[i] && isequal_with_metadata (k, akeys[i])
351
+ end
352
+ idx === nothing && return false
353
+ avisited[idx] = true
354
+ isequal_with_metadata (b[k], a[akeys[idx]]) || return false
355
+ end
356
+ return true
357
+ end
358
+
359
+ """
360
+ $(TYPEDSIGNATURES)
361
+
362
+ Fallback method which uses `isequal`.
363
+ """
364
+ isequal_with_metadata (a, b) = isequal (a, b)
365
+
366
+ """
367
+ $(TYPEDSIGNATURES)
368
+
369
+ Specialized methods to check if two ranges are equal without comparing each element.
370
+ """
371
+ isequal_with_metadata (a:: AbstractRange , b:: AbstractRange ) = isequal (a, b)
372
+
373
+ """
374
+ $(TYPEDSIGNATURES)
375
+
376
+ Check if two arrays/tuples are equal by calling `isequal_with_metadata` on each element.
377
+ This is to ensure true equality of any symbolic elements, if present.
378
+ """
379
+ function isequal_with_metadata (a:: Union{AbstractArray, Tuple} , b:: Union{AbstractArray, Tuple} )
380
+ a === b && return true
381
+ typeof (a) == typeof (b) || return false
382
+ if a isa AbstractArray
383
+ size (a) == size (b) || return false
384
+ end # otherwise they're tuples and type equality also checks length equality
385
+ for (x, y) in zip (a, b)
386
+ isequal_with_metadata (x, y) || return false
387
+ end
388
+ return true
297
389
end
298
390
299
391
Base. one ( s:: Symbolic ) = one ( symtype (s))
0 commit comments