3
3
# ### Symbolic
4
4
# --------------------
5
5
abstract type Symbolic{T} end
6
+ using Setfield: MacroTools
6
7
7
8
# ################### SafeReal #########################
8
9
export SafeReal, LiteralReal
404
405
using Base. ScopedValues
405
406
406
407
const SV_COMPARE = ScopedValue {Int} ()
408
+ const COMPARE_TYPE = TaskLocalValue {Int} (Returns (0 ))
409
+
410
+ macro manually_scope (val, expr, is_forced = false )
411
+ @assert Meta. isexpr (val, :call )
412
+ @assert val. args[1 ] == :(=> )
413
+
414
+ var_name = val. args[2 ]
415
+ new_val = val. args[3 ]
416
+ old_name = gensym (:old_val )
417
+ cur_name = gensym (:cur_val )
418
+ retval_name = gensym (:retval )
419
+ close_expr = :($ var_name[] = $ old_name)
420
+ interpolated_expr = MacroTools. postwalk (expr) do ex
421
+ if Meta. isexpr (ex, :return )
422
+ return Expr (:block , close_expr, ex)
423
+ elseif Meta. isexpr (ex, :$ ) && length (ex. args) == 1 && ex. args[1 ] == :$
424
+ return cur_name
425
+ else
426
+ return ex
427
+ end
428
+ end
429
+ basic_result = quote
430
+ $ cur_name = $ var_name[] = $ new_val
431
+ $ retval_name = begin
432
+ $ interpolated_expr
433
+ end
434
+ $ close_expr
435
+ $ retval_name
436
+ end
437
+ is_forced && return quote
438
+ $ old_name = $ var_name[]
439
+ $ basic_result
440
+ end |> esc
441
+
442
+ return quote
443
+ $ old_name = $ var_name[]
444
+ if $ iszero ($ old_name)
445
+ $ basic_result
446
+ else
447
+ $ cur_name = $ old_name
448
+ $ retval_name = begin
449
+ $ interpolated_expr
450
+ end
451
+ end
452
+ $ retval_name
453
+ end |> esc
454
+ end
407
455
408
456
function isequal_symdict (a:: Dict , b:: Dict , val)
409
457
if val == 2
@@ -413,11 +461,11 @@ function isequal_symdict(a::Dict, b::Dict, val)
413
461
for (k, v) in a
414
462
k2 = nothing
415
463
v2 = nothing
416
- @with SV_COMPARE => 2 begin
464
+ @manually_scope COMPARE_TYPE => 2 begin
417
465
k2 = getkey (b, k, nothing )
418
466
k2 === nothing && return false
419
467
v2 = b[k2]
420
- end
468
+ end true
421
469
v == v2 && isequal (k, k2) || return false
422
470
end
423
471
return true
@@ -462,13 +510,8 @@ function Base.isequal(a::BSImpl.Type, b::BSImpl.Type)
462
510
Tb = MData. variant_type (b)
463
511
Ta === Tb || return false
464
512
465
- val = ScopedValues. get (SV_COMPARE)
466
- if val === nothing
467
- @with SV_COMPARE => 1 begin
468
- isequal_bsimpl (a, b, 1 )
469
- end
470
- else
471
- isequal_bsimpl (a, b, something (val))
513
+ @manually_scope COMPARE_TYPE => 1 begin
514
+ isequal_bsimpl (a, b, $$ )
472
515
end
473
516
end
474
517
@@ -477,12 +520,7 @@ function Base.isequal(a::BasicSymbolic, b::BasicSymbolic)
477
520
typeof (a) === typeof (b) || return false
478
521
479
522
480
- val = ScopedValues. get (SV_COMPARE)
481
- if val === nothing
482
- @with SV_COMPARE => 2 begin
483
- isequal (_unwrap_internal (a), _unwrap_internal (b))
484
- end
485
- else
523
+ @manually_scope COMPARE_TYPE => 2 begin
486
524
isequal (_unwrap_internal (a), _unwrap_internal (b))
487
525
end
488
526
end
492
530
for T1 in [BasicSymbolic, BSImpl. Type], T2 in [BasicSymbolic, BSImpl. Type]
493
531
T1 == T2 && continue
494
532
@eval function Base. isequal (a:: $T1 , b:: $T2 )
495
- val = ScopedValues. get (SV_COMPARE)
496
- if val === nothing
497
- @with SV_COMPARE => 2 begin
498
- isequal (_unwrap_internal (a), _unwrap_internal (b))
499
- end
500
- else
533
+ @manually_scope COMPARE_TYPE => 2 begin
501
534
isequal (_unwrap_internal (a), _unwrap_internal (b))
502
535
end
503
536
end
@@ -639,30 +672,14 @@ function Base.hash(s::BSImpl.Type, h::UInt)
639
672
if ! iszero (h)
640
673
return hash (hash (s, zero (h)), h):: UInt
641
674
end
642
- val = ScopedValues. get (SV_COMPARE)
643
- if val === nothing
644
- @with SV_COMPARE => 1 begin
645
- hash_bsimpl (s, h, 1 )
646
- end
647
- else
648
- hash_bsimpl (s, h, something (val))
675
+ @manually_scope COMPARE_TYPE => 1 begin
676
+ hash_bsimpl (s, h, $$ )
649
677
end
650
678
end
651
679
652
680
Base. @nospecializeinfer function Base. hash (x:: BasicSymbolic , h:: UInt )
653
681
@nospecialize x
654
- val = ScopedValues. get (SV_COMPARE)
655
- if val === nothing
656
- @with SV_COMPARE => 2 begin
657
- if x isa BasicSymbolic{Real}
658
- result = Base. hash (_unwrap_internal (x), h)
659
- elseif x isa BasicSymbolic{Number}
660
- result = Base. hash (_unwrap_internal (x), h)
661
- else
662
- result = Base. hash (_unwrap_internal (x), h)
663
- end
664
- end
665
- else
682
+ @manually_scope COMPARE_TYPE => 2 begin
666
683
if x isa BasicSymbolic{Real}
667
684
result = Base. hash (_unwrap_internal (x), h)
668
685
elseif x isa BasicSymbolic{Number}
0 commit comments