Skip to content

Commit d6d5208

Browse files
authored
use ReturnNode, GotoIfNot, and Argument more consistently in IR (#36318)
1 parent 5142abf commit d6d5208

36 files changed

+354
-296
lines changed

base/boot.jl

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,15 @@
103103
# label::Int
104104
#end
105105

106+
#struct GotoIfNot
107+
# cond::Any
108+
# dest::Int
109+
#end
110+
111+
#struct ReturnNode
112+
# val::Any
113+
#end
114+
106115
#struct PiNode
107116
# val
108117
# typ
@@ -368,6 +377,10 @@ _new(:GotoNode, :Int)
368377
_new(:NewvarNode, :SlotNumber)
369378
_new(:QuoteNode, :Any)
370379
_new(:SSAValue, :Int)
380+
_new(:Argument, :Int)
381+
_new(:ReturnNode, :Any)
382+
eval(Core, :(ReturnNode() = $(Expr(:new, :ReturnNode)))) # unassigned val indicates unreachable
383+
eval(Core, :(GotoIfNot(@nospecialize(cond), dest::Int) = $(Expr(:new, :GotoIfNot, :cond, :dest))))
371384
eval(Core, :(LineNumberNode(l::Int) = $(Expr(:new, :LineNumberNode, :l, nothing))))
372385
eval(Core, :(LineNumberNode(l::Int, @nospecialize(f)) = $(Expr(:new, :LineNumberNode, :l, :f))))
373386
LineNumberNode(l::Int, f::String) = LineNumberNode(l, Symbol(f))
@@ -453,12 +466,12 @@ Symbol(s::Symbol) = s
453466

454467
# module providing the IR object model
455468
module IR
456-
export CodeInfo, MethodInstance, CodeInstance, GotoNode,
457-
NewvarNode, SSAValue, Slot, SlotNumber, TypedSlot,
469+
export CodeInfo, MethodInstance, CodeInstance, GotoNode, GotoIfNot, ReturnNode,
470+
NewvarNode, SSAValue, Slot, SlotNumber, TypedSlot, Argument,
458471
PiNode, PhiNode, PhiCNode, UpsilonNode, LineInfoNode
459472

460-
import Core: CodeInfo, MethodInstance, CodeInstance, GotoNode,
461-
NewvarNode, SSAValue, Slot, SlotNumber, TypedSlot,
473+
import Core: CodeInfo, MethodInstance, CodeInstance, GotoNode, GotoIfNot, ReturnNode,
474+
NewvarNode, SSAValue, Slot, SlotNumber, TypedSlot, Argument,
462475
PiNode, PhiNode, PhiCNode, UpsilonNode, LineInfoNode
463476

464477
end

base/compiler/abstractinterpretation.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1174,13 +1174,13 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
11741174
changes[sn] = VarState(Bottom, true)
11751175
elseif isa(stmt, GotoNode)
11761176
pc´ = (stmt::GotoNode).label
1177-
elseif hd === :gotoifnot
1178-
condt = abstract_eval(interp, stmt.args[1], s[pc], frame)
1177+
elseif isa(stmt, GotoIfNot)
1178+
condt = abstract_eval(interp, stmt.cond, s[pc], frame)
11791179
if condt === Bottom
11801180
break
11811181
end
11821182
condval = maybe_extract_const_bool(condt)
1183-
l = stmt.args[2]::Int
1183+
l = stmt.dest::Int
11841184
# constant conditions
11851185
if condval === true
11861186
elseif condval === false
@@ -1207,9 +1207,9 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
12071207
s[l] = newstate_else
12081208
end
12091209
end
1210-
elseif hd === :return
1210+
elseif isa(stmt, ReturnNode)
12111211
pc´ = n + 1
1212-
rt = widenconditional(abstract_eval(interp, stmt.args[1], s[pc], frame))
1212+
rt = widenconditional(abstract_eval(interp, stmt.val, s[pc], frame))
12131213
if !isa(rt, Const) && !isa(rt, Type) && !isa(rt, PartialStruct)
12141214
# only propagate information we know we can store
12151215
# and is valid inter-procedurally

base/compiler/optimize.jl

Lines changed: 13 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -339,12 +339,6 @@ function statement_cost(ex::Expr, line::Int, src::CodeInfo, sptypes::Vector{Any}
339339
# prevent inlining.
340340
extyp = line == -1 ? Any : src.ssavaluetypes[line]
341341
return extyp === Union{} ? 0 : 20
342-
elseif head === :return
343-
a = ex.args[1]
344-
if a isa Expr
345-
return statement_cost(a, -1, src, sptypes, slottypes, params)
346-
end
347-
return 0
348342
elseif head === :(=)
349343
if ex.args[1] isa GlobalRef
350344
cost = 20
@@ -364,12 +358,6 @@ function statement_cost(ex::Expr, line::Int, src::CodeInfo, sptypes::Vector{Any}
364358
# since these aren't usually performance-sensitive functions,
365359
# and llvm is more likely to miscompile them when these functions get large
366360
return typemax(Int)
367-
elseif head === :gotoifnot
368-
target = ex.args[2]::Int
369-
# loops are generally always expensive
370-
# but assume that forward jumps are already counted for from
371-
# summing the cost of the not-taken branch
372-
return target < line ? 40 : 0
373361
end
374362
return 0
375363
end
@@ -386,6 +374,8 @@ function inline_worthy(body::Array{Any,1}, src::CodeInfo, sptypes::Vector{Any},
386374
# but assume that forward jumps are already counted for from
387375
# summing the cost of the not-taken branch
388376
thiscost = stmt.label < line ? 40 : 0
377+
elseif stmt isa GotoIfNot
378+
thiscost = stmt.dest < line ? 40 : 0
389379
else
390380
continue
391381
end
@@ -423,20 +413,23 @@ function renumber_ir_elements!(body::Vector{Any}, ssachangemap::Vector{Int}, lab
423413
el = body[i]
424414
if isa(el, GotoNode)
425415
body[i] = GotoNode(el.label + labelchangemap[el.label])
416+
elseif isa(el, GotoIfNot)
417+
cond = el.cond
418+
if isa(cond, SSAValue)
419+
cond = SSAValue(cond.id + ssachangemap[cond.id])
420+
end
421+
body[i] = GotoIfNot(cond, el.dest + labelchangemap[el.dest])
422+
elseif isa(el, ReturnNode)
423+
if isdefined(el, :val) && isa(el.val, SSAValue)
424+
body[i] = ReturnNode(SSAValue(el.val.id + ssachangemap[el.val.id]))
425+
end
426426
elseif isa(el, SSAValue)
427427
body[i] = SSAValue(el.id + ssachangemap[el.id])
428428
elseif isa(el, Expr)
429429
if el.head === :(=) && el.args[2] isa Expr
430430
el = el.args[2]::Expr
431431
end
432-
if el.head === :gotoifnot
433-
cond = el.args[1]
434-
if isa(cond, SSAValue)
435-
el.args[1] = SSAValue(cond.id + ssachangemap[cond.id])
436-
end
437-
tgt = el.args[2]::Int
438-
el.args[2] = tgt + labelchangemap[tgt]
439-
elseif el.head === :enter
432+
if el.head === :enter
440433
tgt = el.args[1]::Int
441434
el.args[1] = tgt + labelchangemap[tgt]
442435
elseif !is_meta_expr_head(el.head)

base/compiler/ssair/driver.jl

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,6 @@ include("compiler/ssair/verify.jl")
2020
include("compiler/ssair/legacy.jl")
2121
#@isdefined(Base) && include("compiler/ssair/show.jl")
2222

23-
function normalize_expr(stmt::Expr)
24-
if stmt.head === :gotoifnot
25-
return GotoIfNot(stmt.args[1], stmt.args[2]::Int)
26-
elseif stmt.head === :return
27-
return (length(stmt.args) == 0) ? ReturnNode(nothing) : ReturnNode(stmt.args[1])
28-
elseif stmt.head === :unreachable
29-
return ReturnNode()
30-
else
31-
return stmt
32-
end
33-
end
34-
3523
function normalize(@nospecialize(stmt), meta::Vector{Any})
3624
if isa(stmt, Expr)
3725
if stmt.head === :meta
@@ -40,10 +28,6 @@ function normalize(@nospecialize(stmt), meta::Vector{Any})
4028
push!(meta, stmt)
4129
end
4230
return nothing
43-
elseif stmt.head === :line
44-
return nothing # deprecated - we shouldn't encounter this
45-
else
46-
return normalize_expr(stmt)
4731
end
4832
end
4933
return stmt
@@ -72,7 +56,7 @@ function convert_to_ircode(ci::CodeInfo, code::Vector{Any}, coverage::Bool, narg
7256
prevloc = codeloc
7357
end
7458
if code[idx] isa Expr && ci.ssavaluetypes[idx] === Union{}
75-
if !(idx < length(code) && isexpr(code[idx + 1], :unreachable))
59+
if !(idx < length(code) && isa(code[idx + 1], ReturnNode) && !isdefined((code[idx + 1]::ReturnNode), :val))
7660
# insert unreachable in the same basic block after the current instruction (splitting it)
7761
insert!(code, idx + 1, ReturnNode())
7862
insert!(ci.codelocs, idx + 1, ci.codelocs[idx])

base/compiler/ssair/ir.jl

Lines changed: 2 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -4,23 +4,6 @@
44
@eval Core.UpsilonNode() = $(Expr(:new, Core.UpsilonNode))
55
Core.PhiNode() = Core.PhiNode(Any[], Any[])
66

7-
struct Argument
8-
n::Int
9-
end
10-
11-
struct GotoIfNot
12-
cond::Any
13-
dest::Int
14-
GotoIfNot(@nospecialize(cond), dest::Int) = new(cond, dest)
15-
end
16-
17-
struct ReturnNode
18-
val::Any
19-
ReturnNode(@nospecialize(val)) = new(val)
20-
# unassigned val indicates unreachable
21-
ReturnNode() = new()
22-
end
23-
247
"""
258
Like UnitRange{Int}, but can handle the `last` field, being temporarily
269
< first (this can happen during compacting)
@@ -85,14 +68,6 @@ function basic_blocks_starts(stmts::Vector{Any})
8568
push!(jump_dests, idx+1)
8669
# The catch block is a jump dest
8770
push!(jump_dests, stmt.args[1]::Int)
88-
elseif stmt.head === :gotoifnot
89-
# also tolerate expr form of IR
90-
push!(jump_dests, idx+1)
91-
push!(jump_dests, stmt.args[2]::Int)
92-
elseif stmt.head === :return
93-
# also tolerate expr form of IR
94-
# This is a fake dest to force the next stmt to start a bb
95-
idx < length(stmts) && push!(jump_dests, idx+1)
9671
end
9772
end
9873
if isa(stmt, PhiNode)
@@ -129,7 +104,7 @@ function compute_basic_blocks(stmts::Vector{Any})
129104
# Compute successors/predecessors
130105
for (num, b) in enumerate(blocks)
131106
terminator = stmts[last(b.stmts)]
132-
if isa(terminator, ReturnNode) || isexpr(terminator, :return)
107+
if isa(terminator, ReturnNode)
133108
# return never has any successors
134109
continue
135110
end
@@ -161,15 +136,6 @@ function compute_basic_blocks(stmts::Vector{Any})
161136
push!(blocks[block′].preds, num)
162137
push!(blocks[block′].preds, 0)
163138
push!(b.succs, block′)
164-
elseif terminator.head === :gotoifnot
165-
block′ = block_for_inst(basic_block_index, terminator.args[2]::Int)
166-
if block′ == num + 1
167-
# This GotoIfNot acts like a noop - treat it as such.
168-
# We will drop it during SSA renaming
169-
else
170-
push!(blocks[block′].preds, num)
171-
push!(b.succs, block′)
172-
end
173139
end
174140
end
175141
# statement fall-through
@@ -396,8 +362,7 @@ function is_relevant_expr(e::Expr)
396362
:gc_preserve_begin, :gc_preserve_end,
397363
:foreigncall, :isdefined, :copyast,
398364
:undefcheck, :throw_undef_if_not,
399-
:cfunction, :method, :pop_exception,
400-
#=legacy IR format support=# :gotoifnot, :return)
365+
:cfunction, :method, :pop_exception)
401366
end
402367

403368
function setindex!(x::UseRef, @nospecialize(v))

base/compiler/ssair/legacy.jl

Lines changed: 1 addition & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -12,22 +12,9 @@ end
1212

1313
function inflate_ir(ci::CodeInfo, sptypes::Vector{Any}, argtypes::Vector{Any})
1414
code = copy_exprargs(ci.code) # TODO: this is a huge hot-spot
15-
for i = 1:length(code)
16-
if isa(code[i], Expr)
17-
code[i] = normalize_expr(code[i])
18-
end
19-
end
2015
cfg = compute_basic_blocks(code)
2116
for i = 1:length(code)
2217
stmt = code[i]
23-
urs = userefs(stmt)
24-
for op in urs
25-
val = op[]
26-
if isa(val, SlotNumber)
27-
op[] = Argument(val.id)
28-
end
29-
end
30-
stmt = urs[]
3118
# Translate statement edges to bb_edges
3219
if isa(stmt, GotoNode)
3320
code[i] = GotoNode(block_for_inst(cfg, stmt.label))
@@ -67,26 +54,12 @@ function replace_code_newstyle!(ci::CodeInfo, ir::IRCode, nargs::Int)
6754
# (and undo normalization for now)
6855
for i = 1:length(ci.code)
6956
stmt = ci.code[i]
70-
urs = userefs(stmt)
71-
for op in urs
72-
val = op[]
73-
if isa(val, Argument)
74-
op[] = SlotNumber(val.n)
75-
end
76-
end
77-
stmt = urs[]
7857
if isa(stmt, GotoNode)
7958
stmt = GotoNode(first(ir.cfg.blocks[stmt.label].stmts))
8059
elseif isa(stmt, GotoIfNot)
81-
stmt = Expr(:gotoifnot, stmt.cond, first(ir.cfg.blocks[stmt.dest].stmts))
60+
stmt = GotoIfNot(stmt.cond, first(ir.cfg.blocks[stmt.dest].stmts))
8261
elseif isa(stmt, PhiNode)
8362
stmt = PhiNode(Any[last(ir.cfg.blocks[edge::Int].stmts) for edge in stmt.edges], stmt.values)
84-
elseif isa(stmt, ReturnNode)
85-
if isdefined(stmt, :val)
86-
stmt = Expr(:return, stmt.val)
87-
else
88-
stmt = Expr(:unreachable)
89-
end
9063
elseif isa(stmt, Expr) && stmt.head === :enter
9164
stmt.args[1] = first(ir.cfg.blocks[stmt.args[1]::Int].stmts)
9265
end

base/compiler/ssair/show.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ end
141141

142142
function should_print_ssa_type(@nospecialize node)
143143
if isa(node, Expr)
144-
return !(node.head in (:gc_preserve_begin, :gc_preserve_end, :meta, :return, :enter, :leave))
144+
return !(node.head in (:gc_preserve_begin, :gc_preserve_end, :meta, :enter, :leave))
145145
end
146146
return !isa(node, PiNode) && !isa(node, GotoIfNot) &&
147147
!isa(node, GotoNode) && !isa(node, ReturnNode) &&
@@ -722,9 +722,7 @@ function show_ir_stmt(io::IO, code::CodeInfo, idx::Int, line_info_preprinter, li
722722
print(io, inlining_indent, " ")
723723
# convert statement index to labels, as expected by print_stmt
724724
if stmt isa Expr
725-
if stmt.head === :gotoifnot && length(stmt.args) == 2 && stmt.args[2] isa Int
726-
stmt = GotoIfNot(stmt.args[1], block_for_inst(cfg, stmt.args[2]::Int))
727-
elseif stmt.head === :enter && length(stmt.args) == 1 && stmt.args[1] isa Int
725+
if stmt.head === :enter && length(stmt.args) == 1 && stmt.args[1] isa Int
728726
stmt = Expr(:enter, block_for_inst(cfg, stmt.args[1]::Int))
729727
end
730728
elseif isa(stmt, GotoIfNot)

base/compiler/ssair/slot2ssa.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ function lift_defuse(cfg::CFG, defuse)
4343
end
4444
end
4545

46-
@inline slot_id(s) = isa(s, SlotNumber) ? (s::SlotNumber).id : (s::TypedSlot).id
4746
function scan_slot_def_use(nargs::Int, ci::CodeInfo, code::Vector{Any})
4847
nslots = length(ci.slotflags)
4948
result = SlotInfo[SlotInfo() for i = 1:nslots]
@@ -821,7 +820,7 @@ function construct_ssa!(ci::CodeInfo, ir::IRCode, domtree::DomTree, defuse, narg
821820
elseif isexpr(stmt, :enter)
822821
new_code[idx] = Expr(:enter, block_for_inst(cfg, stmt.args[1]))
823822
ssavalmap[idx] = SSAValue(idx) # Slot to store token for pop_exception
824-
elseif isexpr(stmt, :leave) || isexpr(stmt, :(=)) || isexpr(stmt, :return) ||
823+
elseif isexpr(stmt, :leave) || isexpr(stmt, :(=)) || isa(stmt, ReturnNode) ||
825824
isexpr(stmt, :meta) || isa(stmt, NewvarNode)
826825
new_code[idx] = stmt
827826
else

base/compiler/typeinfer.jl

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,15 @@ function annotate_slot_load!(e::Expr, vtypes::VarTable, sv::InferenceState, unde
251251
end
252252
end
253253

254+
function annotate_slot_load(@nospecialize(e), vtypes::VarTable, sv::InferenceState, undefs::Array{Bool,1})
255+
if isa(e, Expr)
256+
annotate_slot_load!(e, vtypes, sv, undefs)
257+
elseif isa(e, Slot)
258+
return visit_slot_load!(e, vtypes, sv, undefs)
259+
end
260+
return e
261+
end
262+
254263
function visit_slot_load!(sl::Slot, vtypes::VarTable, sv::InferenceState, undefs::Array{Bool,1})
255264
id = slot_id(sl)
256265
s = vtypes[id]
@@ -330,13 +339,12 @@ function type_annotate!(sv::InferenceState)
330339
body = src.code::Array{Any,1}
331340
nexpr = length(body)
332341

333-
# replace gotoifnot with its condition if the branch target is unreachable
342+
# replace GotoIfNot with its condition if the branch target is unreachable
334343
for i = 1:nexpr
335344
expr = body[i]
336-
if isa(expr, Expr) && expr.head === :gotoifnot
337-
tgt = expr.args[2]::Int
338-
if !isa(states[tgt], VarTable)
339-
body[i] = expr.args[1]
345+
if isa(expr, GotoIfNot)
346+
if !isa(states[expr.dest], VarTable)
347+
body[i] = expr.cond
340348
end
341349
end
342350
end
@@ -353,6 +361,10 @@ function type_annotate!(sv::InferenceState)
353361
# st_i === nothing => unreached statement (see issue #7836)
354362
if isa(expr, Expr)
355363
annotate_slot_load!(expr, st_i, sv, undefs)
364+
elseif isa(expr, ReturnNode) && isdefined(expr, :val)
365+
body[i] = ReturnNode(annotate_slot_load(expr.val, st_i, sv, undefs))
366+
elseif isa(expr, GotoIfNot)
367+
body[i] = GotoIfNot(annotate_slot_load(expr.cond, st_i, sv, undefs), expr.dest)
356368
elseif isa(expr, Slot)
357369
body[i] = visit_slot_load!(expr, st_i, sv, undefs)
358370
end
@@ -549,7 +561,7 @@ function typeinf_ext(interp::AbstractInterpreter, mi::MethodInstance)
549561
if invoke_api(code) == 2
550562
i == 2 && ccall(:jl_typeinf_end, Cvoid, ())
551563
tree = ccall(:jl_new_code_info_uninit, Ref{CodeInfo}, ())
552-
tree.code = Any[ Expr(:return, quoted(code.rettype_const)) ]
564+
tree.code = Any[ ReturnNode(quoted(code.rettype_const)) ]
553565
nargs = Int(method.nargs)
554566
tree.slotnames = ccall(:jl_uncompress_argnames, Vector{Symbol}, (Any,), method.slot_syms)
555567
tree.slotflags = fill(0x00, nargs)

0 commit comments

Comments
 (0)