Skip to content

Commit bfb05d5

Browse files
vtjnashaviatesk
authored andcommitted
inference: propagate variable changes to all exception frames (#42081)
* inference: propagate variable changes to all exception frames Fix #42022 * Update test/compiler/inference.jl * Update test/compiler/inference.jl Co-authored-by: Shuhei Kadowaki <40514306+aviatesk@users.noreply.github.com> * fixup! inference: propagate variable changes to all exception frames Co-authored-by: Shuhei Kadowaki <40514306+aviatesk@users.noreply.github.com> (cherry picked from commit e83b317)
1 parent dc18194 commit bfb05d5

File tree

3 files changed

+156
-33
lines changed

3 files changed

+156
-33
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1706,18 +1706,16 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
17061706
slottypes = frame.slottypes
17071707
while frame.pc´´ <= n
17081708
# make progress on the active ip set
1709-
local pc::Int = frame.pc´´ # current program-counter
1709+
local pc::Int = frame.pc´´
17101710
while true # inner loop optimizes the common case where it can run straight from pc to pc + 1
17111711
#print(pc,": ",s[pc],"\n")
17121712
local pc´::Int = pc + 1 # next program-counter (after executing instruction)
17131713
if pc == frame.pc´´
1714-
# need to update pc´´ to point at the new lowest instruction in W
1715-
min_pc = _bits_findnext(W.bits, pc + 1)
1716-
frame.pc´´ = min_pc == -1 ? n + 1 : min_pc
1714+
# want to update pc´´ to point at the new lowest instruction in W
1715+
frame.pc´´ = pc´
17171716
end
17181717
delete!(W, pc)
17191718
frame.currpc = pc
1720-
frame.cur_hand = frame.handler_at[pc]
17211719
edges = frame.stmt_edges[pc]
17221720
edges === nothing || empty!(edges)
17231721
frame.stmt_info[pc] = nothing
@@ -1759,7 +1757,6 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
17591757
pc´ = l
17601758
else
17611759
# general case
1762-
frame.handler_at[l] = frame.cur_hand
17631760
changes_else = changes
17641761
if isa(condt, Conditional)
17651762
changes_else = conditional_changes(changes_else, condt.elsetype, condt.var)
@@ -1818,7 +1815,6 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
18181815
end
18191816
elseif hd === :enter
18201817
l = stmt.args[1]::Int
1821-
frame.cur_hand = Pair{Any,Any}(l, frame.cur_hand)
18221818
# propagate type info to exception handler
18231819
old = states[l]
18241820
newstate_catch = stupdate!(old, changes)
@@ -1830,11 +1826,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
18301826
states[l] = newstate_catch
18311827
end
18321828
typeassert(states[l], VarTable)
1833-
frame.handler_at[l] = frame.cur_hand
18341829
elseif hd === :leave
1835-
for i = 1:((stmt.args[1])::Int)
1836-
frame.cur_hand = (frame.cur_hand::Pair{Any,Any}).second
1837-
end
18381830
else
18391831
if hd === :(=)
18401832
t = abstract_eval_statement(interp, stmt.args[2], changes, frame)
@@ -1864,16 +1856,22 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
18641856
frame.src.ssavaluetypes[pc] = t
18651857
end
18661858
end
1867-
if frame.cur_hand !== nothing && isa(changes, StateUpdate)
1868-
# propagate new type info to exception handler
1869-
# the handling for Expr(:enter) propagates all changes from before the try/catch
1870-
# so this only needs to propagate any changes
1871-
l = frame.cur_hand.first::Int
1872-
if stupdate1!(states[l]::VarTable, changes::StateUpdate) !== false
1873-
if l < frame.pc´´
1874-
frame.pc´´ = l
1859+
if isa(changes, StateUpdate)
1860+
let cur_hand = frame.handler_at[pc], l, enter
1861+
while cur_hand != 0
1862+
enter = frame.src.code[cur_hand]
1863+
l = (enter::Expr).args[1]::Int
1864+
# propagate new type info to exception handler
1865+
# the handling for Expr(:enter) propagates all changes from before the try/catch
1866+
# so this only needs to propagate any changes
1867+
if stupdate1!(states[l]::VarTable, changes::StateUpdate) !== false
1868+
if l < frame.pc´´
1869+
frame.pc´´ = l
1870+
end
1871+
push!(W, l)
1872+
end
1873+
cur_hand = frame.handler_at[cur_hand]
18751874
end
1876-
push!(W, l)
18771875
end
18781876
end
18791877
end
@@ -1886,7 +1884,6 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
18861884
end
18871885

18881886
pc´ > n && break # can't proceed with the fast-path fall-through
1889-
frame.handler_at[pc´] = frame.cur_hand
18901887
newstate = stupdate!(states[pc´], changes)
18911888
if isa(stmt, GotoNode) && frame.pc´´ < pc´
18921889
# if we are processing a goto node anyways,
@@ -1897,7 +1894,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
18971894
states[pc´] = newstate
18981895
end
18991896
push!(W, pc´)
1900-
pc = frame.pc´´
1897+
break
19011898
elseif newstate !== nothing
19021899
states[pc´] = newstate
19031900
pc = pc´
@@ -1907,6 +1904,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
19071904
break
19081905
end
19091906
end
1907+
frame.pc´´ = _bits_findnext(W.bits, frame.pc´´)::Int # next program-counter
19101908
end
19111909
frame.dont_work_on_me = false
19121910
nothing

base/compiler/inferencestate.jl

Lines changed: 91 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,7 @@ mutable struct InferenceState
2828
pc´´::LineNum
2929
nstmts::Int
3030
# current exception handler info
31-
cur_hand #::Union{Nothing, Pair{LineNum, prev_handler}}
32-
handler_at::Vector{Any}
33-
n_handlers::Int
31+
handler_at::Vector{LineNum}
3432
# ssavalue sparsity and restart info
3533
ssavalue_uses::Vector{BitSet}
3634
throw_blocks::BitSet
@@ -87,12 +85,9 @@ mutable struct InferenceState
8785
throw_blocks = find_throw_blocks(code)
8886

8987
# exception handlers
90-
cur_hand = nothing
91-
handler_at = Any[ nothing for i=1:n ]
92-
n_handlers = 0
93-
94-
W = BitSet()
95-
push!(W, 1) #initial pc to visit
88+
ip = BitSet()
89+
handler_at = compute_trycatch(src.code, ip)
90+
push!(ip, 1)
9691

9792
if !toplevel
9893
meth = linfo.def
@@ -103,14 +98,14 @@ mutable struct InferenceState
10398

10499
valid_worlds = WorldRange(src.min_world,
105100
src.max_world == typemax(UInt) ? get_world_counter() : src.max_world)
101+
106102
frame = new(
107103
InferenceParams(interp), result, linfo,
108104
sp, slottypes, inmodule, 0,
109105
IdSet{InferenceState}(), IdSet{InferenceState}(),
110106
src, get_world_counter(interp), valid_worlds,
111107
nargs, s_types, s_edges, stmt_info,
112-
Union{}, W, 1, n,
113-
cur_hand, handler_at, n_handlers,
108+
Union{}, ip, 1, n, handler_at,
114109
ssavalue_uses, throw_blocks,
115110
Vector{Tuple{InferenceState,LineNum}}(), # cycle_backedges
116111
Vector{InferenceState}(), # callers_in_cycle
@@ -124,6 +119,91 @@ mutable struct InferenceState
124119
end
125120
end
126121

122+
function compute_trycatch(code::Vector{Any}, ip::BitSet)
123+
# The goal initially is to record the frame like this for the state at exit:
124+
# 1: (enter 3) # == 0
125+
# 3: (expr) # == 1
126+
# 3: (leave 1) # == 1
127+
# 4: (expr) # == 0
128+
# then we can find all trys by walking backwards from :enter statements,
129+
# and all catches by looking at the statement after the :enter
130+
n = length(code)
131+
empty!(ip)
132+
ip.offset = 0 # for _bits_findnext
133+
push!(ip, n + 1)
134+
handler_at = fill(0, n)
135+
136+
# start from all :enter statements and record the location of the try
137+
for pc = 1:n
138+
stmt = code[pc]
139+
if isexpr(stmt, :enter)
140+
l = stmt.args[1]::Int
141+
handler_at[pc + 1] = pc
142+
push!(ip, pc + 1)
143+
handler_at[l] = pc
144+
push!(ip, l)
145+
end
146+
end
147+
148+
# now forward those marks to all :leave statements
149+
pc´´ = 0
150+
while true
151+
# make progress on the active ip set
152+
pc = _bits_findnext(ip.bits, pc´´)::Int
153+
pc > n && break
154+
while true # inner loop optimizes the common case where it can run straight from pc to pc + 1
155+
pc´ = pc + 1 # next program-counter (after executing instruction)
156+
if pc == pc´´
157+
pc´´ = pc´
158+
end
159+
delete!(ip, pc)
160+
cur_hand = handler_at[pc]
161+
@assert cur_hand != 0 "unbalanced try/catch"
162+
stmt = code[pc]
163+
if isa(stmt, GotoNode)
164+
pc´ = stmt.label
165+
elseif isa(stmt, GotoIfNot)
166+
l = stmt.dest::Int
167+
if handler_at[l] != cur_hand
168+
@assert handler_at[l] == 0 "unbalanced try/catch"
169+
handler_at[l] = cur_hand
170+
if l < pc´´
171+
pc´´ = l
172+
end
173+
push!(ip, l)
174+
end
175+
elseif isa(stmt, ReturnNode)
176+
@assert !isdefined(stmt, :val) "unbalanced try/catch"
177+
break
178+
elseif isa(stmt, Expr)
179+
head = stmt.head
180+
if head === :enter
181+
cur_hand = pc
182+
elseif head === :leave
183+
l = stmt.args[1]::Int
184+
for i = 1:l
185+
cur_hand = handler_at[cur_hand]
186+
end
187+
cur_hand == 0 && break
188+
end
189+
end
190+
191+
pc´ > n && break # can't proceed with the fast-path fall-through
192+
if handler_at[pc´] != cur_hand
193+
@assert handler_at[pc´] == 0 "unbalanced try/catch"
194+
handler_at[pc´] = cur_hand
195+
elseif !in(pc´, ip)
196+
break # already visited
197+
end
198+
pc = pc´
199+
end
200+
end
201+
202+
@assert first(ip) == n + 1
203+
return handler_at
204+
end
205+
206+
127207
"""
128208
Iterate through all callers of the given InferenceState in the abstract
129209
interpretation stack (including the given InferenceState itself), vising

test/compiler/inference.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3454,3 +3454,48 @@ end
34543454
f41908(x::Complex{T}) where {String<:T<:String} = 1
34553455
g41908() = f41908(Any[1][1])
34563456
@test only(Base.return_types(g41908, ())) <: Int
3457+
3458+
# issue #42022
3459+
let x = Tuple{Int,Any}[
3460+
#= 1=# (0, Expr(:(=), Core.SlotNumber(3), 1))
3461+
#= 2=# (0, Expr(:enter, 18))
3462+
#= 3=# (2, Expr(:(=), Core.SlotNumber(3), 2.0))
3463+
#= 4=# (2, Expr(:enter, 12))
3464+
#= 5=# (4, Expr(:(=), Core.SlotNumber(3), '3'))
3465+
#= 6=# (4, Core.GotoIfNot(Core.SlotNumber(2), 9))
3466+
#= 7=# (4, Expr(:leave, 2))
3467+
#= 8=# (0, Core.ReturnNode(1))
3468+
#= 9=# (4, Expr(:call, GlobalRef(Main, :throw)))
3469+
#=10=# (4, Expr(:leave, 1))
3470+
#=11=# (2, Core.GotoNode(16))
3471+
#=12=# (4, Expr(:leave, 1))
3472+
#=13=# (2, Expr(:(=), Core.SlotNumber(4), Expr(:the_exception)))
3473+
#=14=# (2, Expr(:call, GlobalRef(Main, :rethrow)))
3474+
#=15=# (2, Expr(:pop_exception, Core.SSAValue(4)))
3475+
#=16=# (2, Expr(:leave, 1))
3476+
#=17=# (0, Core.GotoNode(22))
3477+
#=18=# (2, Expr(:leave, 1))
3478+
#=19=# (0, Expr(:(=), Core.SlotNumber(5), Expr(:the_exception)))
3479+
#=20=# (0, nothing)
3480+
#=21=# (0, Expr(:pop_exception, Core.SSAValue(2)))
3481+
#=22=# (0, Core.ReturnNode(Core.SlotNumber(3)))
3482+
]
3483+
handler_at = Core.Compiler.compute_trycatch(last.(x), Core.Compiler.BitSet())
3484+
@test handler_at == first.(x)
3485+
end
3486+
3487+
@test only(Base.return_types((Bool,)) do y
3488+
x = 1
3489+
try
3490+
x = 2.0
3491+
try
3492+
x = '3'
3493+
y ? (return 1) : throw()
3494+
catch ex1
3495+
rethrow()
3496+
end
3497+
catch ex2
3498+
nothing
3499+
end
3500+
return x
3501+
end) === Union{Int, Float64, Char}

0 commit comments

Comments
 (0)