Skip to content

Commit d0f5a89

Browse files
use macro and TLV for scoping
1 parent 929e477 commit d0f5a89

File tree

3 files changed

+60
-44
lines changed

3 files changed

+60
-44
lines changed

Project.toml

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
1717
ExproniconLite = "55351af7-c7e9-48d6-89ff-24e801d99491"
1818
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
1919
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
20+
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
2021
Moshi = "2e0e35c7-a2e4-4343-998d-7ef72827ed2d"
2122
MultivariatePolynomials = "102ac46a-7ee4-5c85-9060-abc95bfdeaa3"
2223
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
@@ -54,6 +55,7 @@ DynamicPolynomials = "0.5, 0.6"
5455
EnumX = "1.0.5"
5556
ExproniconLite = "0.10.14"
5657
LabelledArrays = "1.5"
58+
MacroTools = "0.5.16"
5759
Moshi = "0.3.6"
5860
MultivariatePolynomials = "0.5"
5961
NaNMath = "0.3, 1.1.2"
@@ -89,7 +91,3 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
8991

9092
[targets]
9193
test = ["BenchmarkTools", "Documenter", "LabelledArrays", "Pkg", "PkgBenchmark", "Random", "ReferenceTests", "ReverseDiff", "SafeTestsets", "Test", "Zygote", "OhMyThreads", "RuntimeGeneratedFunctions"]
92-
93-
[sources]
94-
WeakCacheSets = {url="https://github.com/JuliaCollections/WeakCacheSets.jl"}
95-
Moshi = {url="https://github.com/AayushSabharwal/Moshi.jl", rev="as/mutable-adt"}

src/SymbolicUtils.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ import TaskLocalValues: TaskLocalValue
3030
using WeakValueDicts: WeakValueDict
3131
using WeakCacheSets: WeakCacheSet, getkey!
3232
using Base: RefValue
33+
import MacroTools
3334

3435
# include("WeakCacheSets.jl")
3536

src/types.jl

Lines changed: 57 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#### Symbolic
44
#--------------------
55
abstract type Symbolic{T} end
6+
using Setfield: MacroTools
67

78
#################### SafeReal #########################
89
export SafeReal, LiteralReal
@@ -404,6 +405,53 @@ end
404405
using Base.ScopedValues
405406

406407
const SV_COMPARE = ScopedValue{Int}()
408+
const COMPARE_TYPE = TaskLocalValue{Int}(Returns(0))
409+
410+
macro manually_scope(val, expr, is_forced = false)
411+
@assert Meta.isexpr(val, :call)
412+
@assert val.args[1] == :(=>)
413+
414+
var_name = val.args[2]
415+
new_val = val.args[3]
416+
old_name = gensym(:old_val)
417+
cur_name = gensym(:cur_val)
418+
retval_name = gensym(:retval)
419+
close_expr = :($var_name[] = $old_name)
420+
interpolated_expr = MacroTools.postwalk(expr) do ex
421+
if Meta.isexpr(ex, :return)
422+
return Expr(:block, close_expr, ex)
423+
elseif Meta.isexpr(ex, :$) && length(ex.args) == 1 && ex.args[1] == :$
424+
return cur_name
425+
else
426+
return ex
427+
end
428+
end
429+
basic_result = quote
430+
$cur_name = $var_name[] = $new_val
431+
$retval_name = begin
432+
$interpolated_expr
433+
end
434+
$close_expr
435+
$retval_name
436+
end
437+
is_forced && return quote
438+
$old_name = $var_name[]
439+
$basic_result
440+
end |> esc
441+
442+
return quote
443+
$old_name = $var_name[]
444+
if $iszero($old_name)
445+
$basic_result
446+
else
447+
$cur_name = $old_name
448+
$retval_name = begin
449+
$interpolated_expr
450+
end
451+
end
452+
$retval_name
453+
end |> esc
454+
end
407455

408456
function isequal_symdict(a::Dict, b::Dict, val)
409457
if val == 2
@@ -413,11 +461,11 @@ function isequal_symdict(a::Dict, b::Dict, val)
413461
for (k, v) in a
414462
k2 = nothing
415463
v2 = nothing
416-
@with SV_COMPARE => 2 begin
464+
@manually_scope COMPARE_TYPE => 2 begin
417465
k2 = getkey(b, k, nothing)
418466
k2 === nothing && return false
419467
v2 = b[k2]
420-
end
468+
end true
421469
v == v2 && isequal(k, k2) || return false
422470
end
423471
return true
@@ -462,13 +510,8 @@ function Base.isequal(a::BSImpl.Type, b::BSImpl.Type)
462510
Tb = MData.variant_type(b)
463511
Ta === Tb || return false
464512

465-
val = ScopedValues.get(SV_COMPARE)
466-
if val === nothing
467-
@with SV_COMPARE => 1 begin
468-
isequal_bsimpl(a, b, 1)
469-
end
470-
else
471-
isequal_bsimpl(a, b, something(val))
513+
@manually_scope COMPARE_TYPE => 1 begin
514+
isequal_bsimpl(a, b, $$)
472515
end
473516
end
474517

@@ -477,12 +520,7 @@ function Base.isequal(a::BasicSymbolic, b::BasicSymbolic)
477520
typeof(a) === typeof(b) || return false
478521

479522

480-
val = ScopedValues.get(SV_COMPARE)
481-
if val === nothing
482-
@with SV_COMPARE => 2 begin
483-
isequal(_unwrap_internal(a), _unwrap_internal(b))
484-
end
485-
else
523+
@manually_scope COMPARE_TYPE => 2 begin
486524
isequal(_unwrap_internal(a), _unwrap_internal(b))
487525
end
488526
end
@@ -492,12 +530,7 @@ end
492530
for T1 in [BasicSymbolic, BSImpl.Type], T2 in [BasicSymbolic, BSImpl.Type]
493531
T1 == T2 && continue
494532
@eval function Base.isequal(a::$T1, b::$T2)
495-
val = ScopedValues.get(SV_COMPARE)
496-
if val === nothing
497-
@with SV_COMPARE => 2 begin
498-
isequal(_unwrap_internal(a), _unwrap_internal(b))
499-
end
500-
else
533+
@manually_scope COMPARE_TYPE => 2 begin
501534
isequal(_unwrap_internal(a), _unwrap_internal(b))
502535
end
503536
end
@@ -639,30 +672,14 @@ function Base.hash(s::BSImpl.Type, h::UInt)
639672
if !iszero(h)
640673
return hash(hash(s, zero(h)), h)::UInt
641674
end
642-
val = ScopedValues.get(SV_COMPARE)
643-
if val === nothing
644-
@with SV_COMPARE => 1 begin
645-
hash_bsimpl(s, h, 1)
646-
end
647-
else
648-
hash_bsimpl(s, h, something(val))
675+
@manually_scope COMPARE_TYPE => 1 begin
676+
hash_bsimpl(s, h, $$)
649677
end
650678
end
651679

652680
Base.@nospecializeinfer function Base.hash(x::BasicSymbolic, h::UInt)
653681
@nospecialize x
654-
val = ScopedValues.get(SV_COMPARE)
655-
if val === nothing
656-
@with SV_COMPARE => 2 begin
657-
if x isa BasicSymbolic{Real}
658-
result = Base.hash(_unwrap_internal(x), h)
659-
elseif x isa BasicSymbolic{Number}
660-
result = Base.hash(_unwrap_internal(x), h)
661-
else
662-
result = Base.hash(_unwrap_internal(x), h)
663-
end
664-
end
665-
else
682+
@manually_scope COMPARE_TYPE => 2 begin
666683
if x isa BasicSymbolic{Real}
667684
result = Base.hash(_unwrap_internal(x), h)
668685
elseif x isa BasicSymbolic{Number}

0 commit comments

Comments
 (0)