Skip to content

Commit 406f5b4

Browse files
authored
Make EnterNode save/restore dynamic scope (#52309)
As discussed in #51352, this gives `EnterNode` the ability to set (and restore on leave or catch edge) jl_current_task->scope. Manual modifications of the task field after the task has started are considered undefined behavior. In addition, we gain a new intrinsic to access current_task->scope and both inference and the optimizer will forward scopes from EnterNodes to this intrinsic (non-interprocedurally). Together with #51993 this is sufficient to fully optimize ScopedValues (non-interprocedurally at least).
1 parent d27ed8f commit 406f5b4

23 files changed

+248
-62
lines changed

base/boot.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,7 @@ eval(Core, quote
460460
ReturnNode() = $(Expr(:new, :ReturnNode)) # unassigned val indicates unreachable
461461
GotoIfNot(@nospecialize(cond), dest::Int) = $(Expr(:new, :GotoIfNot, :cond, :dest))
462462
EnterNode(dest::Int) = $(Expr(:new, :EnterNode, :dest))
463+
EnterNode(dest::Int, @nospecialize(scope)) = $(Expr(:new, :EnterNode, :dest, :scope))
463464
LineNumberNode(l::Int) = $(Expr(:new, :LineNumberNode, :l, nothing))
464465
function LineNumberNode(l::Int, @nospecialize(f))
465466
isa(f, String) && (f = Symbol(f))
@@ -966,7 +967,8 @@ arraysize(a::Array, i::Int) = sle_int(i, nfields(a.size)) ? getfield(a.size, i)
966967
export arrayref, arrayset, arraysize, const_arrayref
967968

968969
# For convenience
969-
EnterNode(old::EnterNode, new_dest::Int) = EnterNode(new_dest)
970+
EnterNode(old::EnterNode, new_dest::Int) = isdefined(old, :scope) ?
971+
EnterNode(new_dest, old.scope) : EnterNode(new_dest)
970972

971973
include(Core, "optimized_generics.jl")
972974

base/compiler/abstractinterpretation.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3270,6 +3270,19 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
32703270
elseif isa(stmt, EnterNode)
32713271
ssavaluetypes[currpc] = Any
32723272
add_curr_ssaflag!(frame, IR_FLAG_NOTHROW)
3273+
if isdefined(stmt, :scope)
3274+
scopet = abstract_eval_value(interp, stmt.scope, currstate, frame)
3275+
handler = frame.handlers[frame.handler_at[frame.currpc+1][1]]
3276+
@assert handler.scopet !== nothing
3277+
if !(𝕃ᵢ, scopet, handler.scopet)
3278+
handler.scopet = tmerge(𝕃ᵢ, scopet, handler.scopet)
3279+
if isdefined(handler, :scope_uses)
3280+
for bb in handler.scope_uses
3281+
push!(W, bb)
3282+
end
3283+
end
3284+
end
3285+
end
32733286
@goto fallthrough
32743287
elseif isexpr(stmt, :leave)
32753288
ssavaluetypes[currpc] = Any

base/compiler/inferencestate.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,8 +205,10 @@ const CACHE_MODE_VOLATILE = 0x01 << 2 # not cached, optimization allowed
205205

206206
mutable struct TryCatchFrame
207207
exct
208+
scopet
208209
const enter_idx::Int
209-
TryCatchFrame(@nospecialize(exct), enter_idx::Int) = new(exct, enter_idx)
210+
scope_uses::Vector{Int}
211+
TryCatchFrame(@nospecialize(exct), @nospecialize(scopet), enter_idx::Int) = new(exct, scopet, enter_idx)
210212
end
211213

212214
mutable struct InferenceState
@@ -364,7 +366,7 @@ function compute_trycatch(code::Vector{Any}, ip::BitSet)
364366
stmt = code[pc]
365367
if isa(stmt, EnterNode)
366368
l = stmt.catch_dest
367-
push!(handlers, TryCatchFrame(Bottom, pc))
369+
push!(handlers, TryCatchFrame(Bottom, isdefined(stmt, :scope) ? Bottom : nothing, pc))
368370
handler_id = length(handlers)
369371
handler_at[pc + 1] = (handler_id, 0)
370372
push!(ip, pc + 1)

base/compiler/ssair/ir.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1414,6 +1414,7 @@ function process_node!(compact::IncrementalCompact, result_idx::Int, inst::Instr
14141414
result_idx += 1
14151415
end
14161416
elseif cfg_transforms_enabled && isa(stmt, EnterNode)
1417+
stmt = renumber_ssa2!(stmt, ssa_rename, used_ssas, new_new_used_ssas, late_fixup, result_idx, do_rename_ssa, mark_refined!)::EnterNode
14171418
label = bb_rename_succ[stmt.catch_dest]
14181419
@assert label > 0
14191420
ssa_rename[idx] = SSAValue(result_idx)

base/compiler/ssair/passes.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1069,6 +1069,29 @@ function fold_ifelse!(compact::IncrementalCompact, idx::Int, stmt::Expr)
10691069
return false
10701070
end
10711071

1072+
function fold_current_scope!(compact::IncrementalCompact, idx::Int, stmt::Expr, lazydomtree::LazyDomtree)
1073+
domtree = get!(lazydomtree)
1074+
1075+
# The frontend enforces the invariant that any :enter dominates its active
1076+
# region, so all we have to do here is walk the domtree to find it.
1077+
dombb = block_for_inst(compact, SSAValue(idx))
1078+
1079+
local bbterminator
1080+
while true
1081+
dombb = domtree.idoms_bb[dombb]
1082+
1083+
# Did not find any dominating :enter - scope is inherited from the outside
1084+
dombb == 0 && return nothing
1085+
1086+
bbterminator = compact[SSAValue(last(compact.cfg_transform.result_bbs[dombb].stmts))][:stmt]
1087+
isa(bbterminator, EnterNode) || continue
1088+
isdefined(bbterminator, :scope) || continue
1089+
compact[idx] = bbterminator.scope
1090+
return nothing
1091+
end
1092+
end
1093+
1094+
10721095
# NOTE we use `IdSet{Int}` instead of `BitSet` for in these passes since they work on IR after inlining,
10731096
# which can be very large sometimes, and program counters in question are often very sparse
10741097
const SPCSet = IdSet{Int}
@@ -1201,6 +1224,8 @@ function sroa_pass!(ir::IRCode, inlining::Union{Nothing,InliningState}=nothing)
12011224
elseif is_known_invoke_or_call(stmt, Core.OptimizedGenerics.KeyValue.get, compact)
12021225
2 == (length(stmt.args) - (isexpr(stmt, :invoke) ? 2 : 1)) || continue
12031226
lift_keyvalue_get!(compact, idx, stmt, 𝕃ₒ)
1227+
elseif is_known_call(stmt, Core.current_scope, compact)
1228+
fold_current_scope!(compact, idx, stmt, lazydomtree)
12041229
elseif isexpr(stmt, :new)
12051230
refine_new_effects!(𝕃ₒ, compact, idx, stmt)
12061231
end

base/compiler/ssair/show.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,10 @@ function print_stmt(io::IO, idx::Int, @nospecialize(stmt), used::BitSet, maxleng
6969
# given control flow information, we prefer to print these with the basic block #, instead of the ssa %
7070
elseif isa(stmt, EnterNode)
7171
print(io, "enter #", stmt.catch_dest, "")
72+
if isdefined(stmt, :scope)
73+
print(io, " with scope ")
74+
show_unquoted(io, stmt.scope, indent)
75+
end
7276
elseif stmt isa GotoNode
7377
print(io, "goto #", stmt.label)
7478
elseif stmt isa PhiNode

base/compiler/ssair/verify.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
function maybe_show_ir(ir::IRCode)
44
if isdefined(Core, :Main)
5-
Core.Main.Base.display(ir)
5+
invokelatest(Core.Main.Base.display, ir)
66
end
77
end
88

base/compiler/tfuncs.jl

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2488,6 +2488,19 @@ function builtin_effects(𝕃::AbstractLattice, @nospecialize(f::Builtin), argty
24882488
return Effects(EFFECTS_TOTAL;
24892489
consistent = (isa(setting, Const) && setting.val === :conditional) ? ALWAYS_TRUE : ALWAYS_FALSE,
24902490
nothrow = compilerbarrier_nothrow(setting, nothing))
2491+
elseif f === Core.current_scope
2492+
nothrow = true
2493+
if length(argtypes) != 0
2494+
if length(argtypes) != 1 || !isvarargtype(argtypes[1])
2495+
return EFFECTS_THROWS
2496+
end
2497+
nothrow = false
2498+
end
2499+
return Effects(EFFECTS_TOTAL;
2500+
consistent = ALWAYS_FALSE,
2501+
notaskstate = false,
2502+
nothrow
2503+
)
24912504
else
24922505
if contains_is(_CONSISTENT_BUILTINS, f)
24932506
consistent = ALWAYS_TRUE
@@ -2554,6 +2567,32 @@ function memoryop_noub(@nospecialize(f), argtypes::Vector{Any})
25542567
return false
25552568
end
25562569

2570+
function current_scope_tfunc(interp::AbstractInterpreter, sv::InferenceState)
2571+
pc = sv.currpc
2572+
while true
2573+
handleridx = sv.handler_at[pc][1]
2574+
if handleridx == 0
2575+
# No local scope available - inherited from the outside
2576+
return Any
2577+
end
2578+
pchandler = sv.handlers[handleridx]
2579+
# Remember that we looked at this handler, so we get re-scheduled
2580+
# if the scope information changes
2581+
isdefined(pchandler, :scope_uses) || (pchandler.scope_uses = Int[])
2582+
pcbb = block_for_inst(sv.cfg, pc)
2583+
if findfirst(==(pcbb), pchandler.scope_uses) === nothing
2584+
push!(pchandler.scope_uses, pcbb)
2585+
end
2586+
scope = pchandler.scopet
2587+
if scope !== nothing
2588+
# Found the scope - forward it
2589+
return scope
2590+
end
2591+
pc = pchandler.enter_idx
2592+
end
2593+
end
2594+
current_scope_tfunc(interp::AbstractInterpreter, sv) = Any
2595+
25572596
"""
25582597
builtin_nothrow(𝕃::AbstractLattice, f::Builtin, argtypes::Vector{Any}, rt) -> Bool
25592598
@@ -2568,9 +2607,6 @@ end
25682607
function builtin_tfunction(interp::AbstractInterpreter, @nospecialize(f), argtypes::Vector{Any},
25692608
sv::Union{AbsIntState, Nothing})
25702609
𝕃ᵢ = typeinf_lattice(interp)
2571-
if f === tuple
2572-
return tuple_tfunc(𝕃ᵢ, argtypes)
2573-
end
25742610
if isa(f, IntrinsicFunction)
25752611
if is_pure_intrinsic_infer(f) && all(@nospecialize(a) -> isa(a, Const), argtypes)
25762612
argvals = anymap(@nospecialize(a) -> (a::Const).val, argtypes)
@@ -2596,6 +2632,16 @@ function builtin_tfunction(interp::AbstractInterpreter, @nospecialize(f), argtyp
25962632
end
25972633
tf = T_IFUNC[iidx]
25982634
else
2635+
if f === tuple
2636+
return tuple_tfunc(𝕃ᵢ, argtypes)
2637+
elseif f === Core.current_scope
2638+
if length(argtypes) != 0
2639+
if length(argtypes) != 1 || !isvarargtype(argtypes[1])
2640+
return Bottom
2641+
end
2642+
end
2643+
return current_scope_tfunc(interp, sv)
2644+
end
25992645
fidx = find_tfunc(f)
26002646
if fidx === nothing
26012647
# unknown/unhandled builtin function

base/compiler/validation.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ const VALID_EXPR_HEADS = IdDict{Symbol,UnitRange{Int}}(
1313
:new => 1:typemax(Int),
1414
:splatnew => 2:2,
1515
:the_exception => 0:0,
16-
:enter => 1:1,
16+
:enter => 1:2,
1717
:leave => 1:typemax(Int),
1818
:pop_exception => 1:1,
1919
:inbounds => 1:1,
@@ -160,6 +160,13 @@ function validate_code!(errors::Vector{InvalidCodeError}, c::CodeInfo, is_top_le
160160
push!(errors, InvalidCodeError(INVALID_CALL_ARG, x.cond))
161161
end
162162
validate_val!(x.cond)
163+
elseif isa(x, EnterNode)
164+
if isdefined(x, :scope)
165+
if !is_valid_argument(x.scope)
166+
push!(errors, InvalidCodeError(INVALID_CALL_ARG, x.scope))
167+
end
168+
validate_val!(x.scope)
169+
end
163170
elseif isa(x, ReturnNode)
164171
if isdefined(x, :val)
165172
if !is_valid_return(x.val)

base/scopedvalues.jl

Lines changed: 12 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -81,13 +81,6 @@ function Scope(scope, pair1::Pair{<:ScopedValue}, pair2::Pair{<:ScopedValue}, pa
8181
end
8282
Scope(::Nothing) = nothing
8383

84-
"""
85-
current_scope()::Union{Nothing, Scope}
86-
87-
Return the current dynamic scope.
88-
"""
89-
current_scope() = current_task().scope::Union{Nothing, Scope}
90-
9184
function Base.show(io::IO, scope::Scope)
9285
print(io, Scope, "(")
9386
first = true
@@ -116,8 +109,7 @@ return `nothing`. Otherwise returns `Some{T}` with the current
116109
value.
117110
"""
118111
function get(val::ScopedValue{T}) where {T}
119-
# Inline current_scope to avoid doing the type assertion twice.
120-
scope = current_task().scope
112+
scope = Core.current_scope()::Union{Scope, Nothing}
121113
if scope === nothing
122114
isassigned(val) && return Some{T}(val.default)
123115
return nothing
@@ -151,25 +143,6 @@ function Base.show(io::IO, val::ScopedValue)
151143
print(io, ')')
152144
end
153145

154-
"""
155-
with(f, (var::ScopedValue{T} => val::T)...)
156-
157-
Execute `f` in a new scope with `var` set to `val`.
158-
"""
159-
function with(f, pair::Pair{<:ScopedValue}, rest::Pair{<:ScopedValue}...)
160-
@nospecialize
161-
ct = Base.current_task()
162-
current_scope = ct.scope::Union{Nothing, Scope}
163-
ct.scope = Scope(current_scope, pair, rest...)
164-
try
165-
return f()
166-
finally
167-
ct.scope = current_scope
168-
end
169-
end
170-
171-
with(@nospecialize(f)) = f()
172-
173146
"""
174147
@with vars... expr
175148
@@ -187,18 +160,18 @@ macro with(exprs...)
187160
else
188161
error("@with expects at least one argument")
189162
end
190-
for expr in exprs
191-
if expr.head !== :call || first(expr.args) !== :(=>)
192-
error("@with expects arguments of the form `A => 2` got $expr")
193-
end
194-
end
195163
exprs = map(esc, exprs)
196-
quote
197-
ct = $(Base.current_task)()
198-
current_scope = ct.scope::$(Union{Nothing, Scope})
199-
ct.scope = $(Scope)(current_scope, $(exprs...))
200-
$(Expr(:tryfinally, esc(ex), :(ct.scope = current_scope)))
201-
end
164+
Expr(:tryfinally, esc(ex), :(), :(Scope(Core.current_scope()::Union{Nothing, Scope}, $(exprs...))))
202165
end
203166

167+
"""
168+
with(f, (var::ScopedValue{T} => val::T)...)
169+
170+
Execute `f` in a new scope with `var` set to `val`.
171+
"""
172+
function with(f, pair::Pair{<:ScopedValue}, rest::Pair{<:ScopedValue}...)
173+
@with(pair, rest..., f())
174+
end
175+
with(@nospecialize(f)) = f()
176+
204177
end # module ScopedValues

0 commit comments

Comments
 (0)