Skip to content

Commit 1bc7f43

Browse files
authored
improve many type stabilities in Core.Compiler.typeinf (#39549)
All of them are detected by JET.jl's self-profiling. The following code will print type-instabilities/type-errors for all code paths reachable from `typeinf(::NativeInterpreter, ::InferenceState)`. ```julia julia> using JET julia> report_call(Core.Compiler.typeinf, (Core.Compiler.NativeInterpreter, Core.Compiler.InferenceState); annotate_types = true) ``` The remaining error reports (e.g. `variable Core.Compiler.string is not defined`) are because of missing functionality on error paths.
1 parent b1fbe7f commit 1bc7f43

File tree

12 files changed

+65
-49
lines changed

12 files changed

+65
-49
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -459,7 +459,7 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp
459459
# Under direct self-recursion, permit much greater use of reducers.
460460
# here we assume that complexity(specTypes) :>= complexity(sig)
461461
comparison = sv.linfo.specTypes
462-
l_comparison = length(unwrap_unionall(comparison).parameters)
462+
l_comparison = length(unwrap_unionall(comparison).parameters)::Int
463463
spec_len = max(spec_len, l_comparison)
464464
else
465465
comparison = method.sig
@@ -700,16 +700,20 @@ function abstract_apply(interp::AbstractInterpreter, @nospecialize(itft), @nospe
700700
res = Union{}
701701
nargs = length(aargtypes)
702702
splitunions = 1 < unionsplitcost(aargtypes) <= InferenceParams(interp).MAX_APPLY_UNION_ENUM
703-
ctypes = Any[Any[aft]]
703+
ctypes = [Any[aft]]
704704
infos = [Union{Nothing, AbstractIterationInfo}[]]
705705
for i = 1:nargs
706-
ctypes´ = []
707-
infos′ = []
706+
ctypes´ = Vector{Any}[]
707+
infos′ = Vector{Union{Nothing, AbstractIterationInfo}}[]
708708
for ti in (splitunions ? uniontypes(aargtypes[i]) : Any[aargtypes[i]])
709709
if !isvarargtype(ti)
710-
cti, info = precise_container_type(interp, itft, ti, sv)
710+
cti_info = precise_container_type(interp, itft, ti, sv)
711+
cti = cti_info[1]::Vector{Any}
712+
info = cti_info[2]::Union{Nothing,AbstractIterationInfo}
711713
else
712-
cti, info = precise_container_type(interp, itft, unwrapva(ti), sv)
714+
cti_info = precise_container_type(interp, itft, unwrapva(ti), sv)
715+
cti = cti_info[1]::Vector{Any}
716+
info = cti_info[2]::Union{Nothing,AbstractIterationInfo}
713717
# We can't represent a repeating sequence of the same types,
714718
# so tmerge everything together to get one type that represents
715719
# everything.
@@ -726,7 +730,7 @@ function abstract_apply(interp::AbstractInterpreter, @nospecialize(itft), @nospe
726730
continue
727731
end
728732
for j = 1:length(ctypes)
729-
ct = ctypes[j]
733+
ct = ctypes[j]::Vector{Any}
730734
if isvarargtype(ct[end])
731735
# This is vararg, we're not gonna be able to do any inling,
732736
# drop the info
@@ -850,7 +854,8 @@ function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, fargs::U
850854
(a3 = argtypes[3]; isa(a3, Const)) && (idx = a3.val; isa(idx, Int)) &&
851855
(a2 = argtypes[2]; a2 Tuple)
852856
# TODO: why doesn't this use the getfield_tfunc?
853-
cti, _ = precise_container_type(interp, iterate, a2, sv)
857+
cti_info = precise_container_type(interp, iterate, a2, sv)
858+
cti = cti_info[1]::Vector{Any}
854859
if 1 <= idx <= length(cti)
855860
rt = unwrapva(cti[idx])
856861
end
@@ -1392,7 +1397,8 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
13921397
delete!(W, pc)
13931398
frame.currpc = pc
13941399
frame.cur_hand = frame.handler_at[pc]
1395-
frame.stmt_edges[pc] === nothing || empty!(frame.stmt_edges[pc])
1400+
edges = frame.stmt_edges[pc]
1401+
edges === nothing || empty!(edges)
13961402
stmt = frame.src.code[pc]
13971403
changes = s[pc]::VarTable
13981404
t = nothing
@@ -1405,7 +1411,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
14051411
elseif isa(stmt, GotoNode)
14061412
pc´ = (stmt::GotoNode).label
14071413
elseif isa(stmt, GotoIfNot)
1408-
condt = abstract_eval_value(interp, stmt.cond, s[pc], frame)
1414+
condt = abstract_eval_value(interp, stmt.cond, changes, frame)
14091415
if condt === Bottom
14101416
empty!(frame.pclimitations)
14111417
end
@@ -1438,7 +1444,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
14381444
end
14391445
end
14401446
newstate_else = stupdate!(s[l], changes_else)
1441-
if newstate_else !== false
1447+
if newstate_else !== nothing
14421448
# add else branch to active IP list
14431449
if l < frame.pc´´
14441450
frame.pc´´ = l
@@ -1449,7 +1455,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
14491455
end
14501456
elseif isa(stmt, ReturnNode)
14511457
pc´ = n + 1
1452-
rt = widenconditional(abstract_eval_value(interp, stmt.val, s[pc], frame))
1458+
rt = widenconditional(abstract_eval_value(interp, stmt.val, changes, frame))
14531459
if !isa(rt, Const) && !isa(rt, Type) && !isa(rt, PartialStruct) && !isa(rt, PartialOpaque)
14541460
# only propagate information we know we can store
14551461
# and is valid inter-procedurally
@@ -1483,9 +1489,8 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
14831489
frame.cur_hand = Pair{Any,Any}(l, frame.cur_hand)
14841490
# propagate type info to exception handler
14851491
old = s[l]
1486-
new = s[pc]::VarTable
1487-
newstate_catch = stupdate!(old, new)
1488-
if newstate_catch !== false
1492+
newstate_catch = stupdate!(old, changes)
1493+
if newstate_catch !== nothing
14891494
if l < frame.pc´´
14901495
frame.pc´´ = l
14911496
end
@@ -1556,12 +1561,12 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
15561561
# (such as a terminator for a loop, if-else, or try block),
15571562
# consider whether we should jump to an older backedge first,
15581563
# to try to traverse the statements in approximate dominator order
1559-
if newstate !== false
1564+
if newstate !== nothing
15601565
s[pc´] = newstate
15611566
end
15621567
push!(W, pc´)
15631568
pc = frame.pc´´
1564-
elseif newstate !== false
1569+
elseif newstate !== nothing
15651570
s[pc´] = newstate
15661571
pc = pc´
15671572
elseif pc´ in W

base/compiler/optimize.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,7 @@ function statement_cost(ex::Expr, line::Int, src::CodeInfo, sptypes::Vector{Any}
371371
end
372372
a = ex.args[2]
373373
if a isa Expr
374-
cost = plus_saturate(cost, statement_cost(a, -1, src, sptypes, slottypes, params, error_path))
374+
cost = plus_saturate(cost, statement_cost(a, -1, src, sptypes, slottypes, union_penalties, params, error_path))
375375
end
376376
return cost
377377
elseif head === :copyast
@@ -392,7 +392,7 @@ function statement_or_branch_cost(@nospecialize(stmt), line::Int, src::CodeInfo,
392392
thiscost = 0
393393
if stmt isa Expr
394394
thiscost = statement_cost(stmt, line, src, sptypes, slottypes, union_penalties, params,
395-
params.unoptimize_throw_blocks && line in throw_blocks)::Int
395+
throw_blocks !== nothing && line in throw_blocks)::Int
396396
elseif stmt isa GotoNode
397397
# loops are generally always expensive
398398
# but assume that forward jumps are already counted for from

base/compiler/ssair/driver.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,14 @@ function convert_to_ircode(ci::CodeInfo, code::Vector{Any}, coverage::Bool, narg
4343
labelmap = coverage ? fill(0, length(code)) : changemap
4444
prevloc = zero(eltype(ci.codelocs))
4545
stmtinfo = sv.stmt_info
46+
ssavaluetypes = ci.ssavaluetypes::Vector{Any}
4647
while idx <= length(code)
4748
codeloc = ci.codelocs[idx]
4849
if coverage && codeloc != prevloc && codeloc != 0
4950
# insert a side-effect instruction before the current instruction in the same basic block
5051
insert!(code, idx, Expr(:code_coverage_effect))
5152
insert!(ci.codelocs, idx, codeloc)
52-
insert!(ci.ssavaluetypes, idx, Nothing)
53+
insert!(ssavaluetypes, idx, Nothing)
5354
insert!(stmtinfo, idx, nothing)
5455
changemap[oldidx] += 1
5556
if oldidx < length(labelmap)
@@ -58,12 +59,12 @@ function convert_to_ircode(ci::CodeInfo, code::Vector{Any}, coverage::Bool, narg
5859
idx += 1
5960
prevloc = codeloc
6061
end
61-
if code[idx] isa Expr && ci.ssavaluetypes[idx] === Union{}
62+
if code[idx] isa Expr && ssavaluetypes[idx] === Union{}
6263
if !(idx < length(code) && isa(code[idx + 1], ReturnNode) && !isdefined((code[idx + 1]::ReturnNode), :val))
6364
# insert unreachable in the same basic block after the current instruction (splitting it)
6465
insert!(code, idx + 1, ReturnNode())
6566
insert!(ci.codelocs, idx + 1, ci.codelocs[idx])
66-
insert!(ci.ssavaluetypes, idx + 1, Union{})
67+
insert!(ssavaluetypes, idx + 1, Union{})
6768
insert!(stmtinfo, idx + 1, nothing)
6869
if oldidx < length(changemap)
6970
changemap[oldidx + 1] += 1

base/compiler/ssair/inlining.jl

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -630,7 +630,7 @@ function rewrite_apply_exprargs!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::
630630
call = thisarginfo.each[i]
631631
new_stmt = Expr(:call, argexprs[2], def, state...)
632632
state1 = insert_node!(ir, idx, call.rt, new_stmt)
633-
new_sig = with_atype(call_sig(ir, new_stmt))
633+
new_sig = with_atype(call_sig(ir, new_stmt)::Signature)
634634
if isa(call.info, MethodMatchInfo) || isa(call.info, UnionSplitInfo)
635635
info = isa(call.info, MethodMatchInfo) ?
636636
MethodMatchInfo[call.info] : call.info.matches
@@ -680,7 +680,7 @@ function resolve_todo(todo::InliningTodo, et::Union{EdgeTracker, Nothing}, cache
680680
spec = todo.spec::DelayedInliningSpec
681681
isconst, src = find_inferred(todo.mi, spec.atypes, caches, spec.stmttype)
682682

683-
if isconst
683+
if isconst && et !== nothing
684684
push!(et, todo.mi)
685685
return ConstantCase(src)
686686
end
@@ -988,9 +988,12 @@ function inline_invoke!(ir::IRCode, idx::Int, sig::Signature, invoke_data::Invok
988988
sig.atype, method.sig)::SimpleVector
989989
methsp = methsp::SimpleVector
990990
match = MethodMatch(metharg, methsp, method, true)
991-
result = analyze_method!(match, sig.atypes, state.et, state.caches, state.params, calltype)
991+
et = state.et
992+
result = analyze_method!(match, sig.atypes, et, state.caches, state.params, calltype)
992993
handle_single_case!(ir, stmt, idx, result, true, todo)
993-
intersect!(state.et, WorldRange(invoke_data.min_valid, invoke_data.max_valid))
994+
if et !== nothing
995+
intersect!(et, WorldRange(invoke_data.min_valid, invoke_data.max_valid))
996+
end
994997
return nothing
995998
end
996999

@@ -1118,6 +1121,7 @@ function analyze_single_call!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int
11181121
sig.atype, only_method.sig)::SimpleVector
11191122
match = MethodMatch(metharg, methsp, only_method, true)
11201123
else
1124+
meth = meth::MethodLookupResult
11211125
@assert length(meth) == 1
11221126
match = meth[1]
11231127
end
@@ -1145,6 +1149,8 @@ end
11451149
function assemble_inline_todo!(ir::IRCode, state::InliningState)
11461150
# todo = (inline_idx, (isva, isinvoke, na), method, spvals, inline_linetable, inline_ir, lie)
11471151
todo = Pair{Int, Any}[]
1152+
et = state.et
1153+
method_table = state.method_table
11481154
for idx in 1:length(ir.stmts)
11491155
r = process_simple!(ir, todo, idx, state)
11501156
r === nothing && continue
@@ -1176,20 +1182,18 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState)
11761182
nu = unionsplitcost(sig.atypes)
11771183
if nu == 1 || nu > state.params.MAX_UNION_SPLITTING
11781184
if !isa(info, MethodMatchInfo)
1179-
if state.method_table === nothing
1180-
continue
1181-
end
1182-
info = recompute_method_matches(sig.atype, state.params, state.et, state.method_table)
1185+
method_table === nothing && continue
1186+
et === nothing && continue
1187+
info = recompute_method_matches(sig.atype, state.params, et, method_table)
11831188
end
11841189
infos = MethodMatchInfo[info]
11851190
else
11861191
if !isa(info, UnionSplitInfo)
1187-
if state.method_table === nothing
1188-
continue
1189-
end
1192+
method_table === nothing && continue
1193+
et === nothing && continue
11901194
infos = MethodMatchInfo[]
11911195
for union_sig in UnionSplitSignature(sig.atypes)
1192-
push!(infos, recompute_method_matches(argtypes_to_type(union_sig), state.params, state.et, state.method_table))
1196+
push!(infos, recompute_method_matches(argtypes_to_type(union_sig), state.params, et, method_table))
11931197
end
11941198
else
11951199
infos = info.matches

base/compiler/ssair/ir.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,13 +139,15 @@ function compute_basic_blocks(stmts::Vector{Any})
139139
return CFG(blocks, basic_block_index)
140140
end
141141

142+
# this function assumes insert position exists
142143
function first_insert_for_bb(code, cfg::CFG, block::Int)
143144
for idx in cfg.blocks[block].stmts
144145
stmt = code[idx]
145146
if !isa(stmt, PhiNode)
146147
return idx
147148
end
148149
end
150+
error("any insert position isn't found")
149151
end
150152

151153
# SSA-indexed nodes
@@ -893,7 +895,7 @@ function kill_edge!(compact::IncrementalCompact, active_bb::Int, from::Int, to::
893895
# Check if the block is now dead
894896
if length(preds) == 0
895897
for succ in copy(compact.result_bbs[compact.bb_rename_succ[to]].succs)
896-
kill_edge!(compact, active_bb, to, findfirst(x->x === succ, compact.bb_rename_pred))
898+
kill_edge!(compact, active_bb, to, findfirst(x->x === succ, compact.bb_rename_pred)::Int)
897899
end
898900
if to < active_bb
899901
# Kill all statements in the block

base/compiler/ssair/slot2ssa.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -764,7 +764,7 @@ function construct_ssa!(ci::CodeInfo, ir::IRCode, domtree::DomTree, defuse, narg
764764
# Having undef_token appear on the RHS is possible if we're on a dead branch.
765765
# Do something reasonable here, by marking the LHS as undef as well.
766766
if val !== undef_token
767-
incoming_vals[id] = SSAValue(make_ssa!(ci, code, idx, id, typ))
767+
incoming_vals[id] = SSAValue(make_ssa!(ci, code, idx, id, typ)::Int)
768768
else
769769
code[idx] = nothing
770770
incoming_vals[id] = undef_token

base/compiler/ssair/verify.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,13 @@ end
1414
function check_op(ir::IRCode, domtree::DomTree, @nospecialize(op), use_bb::Int, use_idx::Int, print::Bool)
1515
if isa(op, SSAValue)
1616
if op.id > length(ir.stmts)
17-
def_bb = block_for_inst(ir.cfg, ir.new_nodes[op.id - length(ir.stmts)].pos)
17+
def_bb = block_for_inst(ir.cfg, ir.new_nodes.info[op.id - length(ir.stmts)].pos)
1818
else
1919
def_bb = block_for_inst(ir.cfg, op.id)
2020
end
2121
if (def_bb == use_bb)
2222
if op.id > length(ir.stmts)
23-
@assert ir.new_nodes[op.id - length(ir.stmts)].pos <= use_idx
23+
@assert ir.new_nodes.info[op.id - length(ir.stmts)].pos <= use_idx
2424
else
2525
if op.id >= use_idx
2626
@verify_error "Def ($(op.id)) does not dominate use ($(use_idx)) in same BB"

base/compiler/typeinfer.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -715,8 +715,8 @@ function merge_call_chain!(parent::InferenceState, ancestor::InferenceState, chi
715715
add_cycle_backedge!(child, parent, parent.currpc)
716716
union_caller_cycle!(ancestor, child)
717717
child = parent
718-
parent = child.parent
719718
child === ancestor && break
719+
parent = child.parent::InferenceState
720720
end
721721
end
722722

base/compiler/typelattice.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ function stupdate!(state::VarTable, changes::StateUpdate)
317317
if !isa(changes.var, Slot)
318318
return stupdate!(state, changes.state)
319319
end
320-
newstate = false
320+
newstate = nothing
321321
changeid = slot_id(changes.var::Slot)
322322
for i = 1:length(state)
323323
if i == changeid
@@ -346,7 +346,7 @@ function stupdate!(state::VarTable, changes::StateUpdate)
346346
end
347347

348348
function stupdate!(state::VarTable, changes::VarTable)
349-
newstate = false
349+
newstate = nothing
350350
for i = 1:length(state)
351351
newtype = changes[i]
352352
oldtype = state[i]
@@ -360,7 +360,7 @@ end
360360

361361
stupdate!(state::Nothing, changes::VarTable) = copy(changes)
362362

363-
stupdate!(state::Nothing, changes::Nothing) = false
363+
stupdate!(state::Nothing, changes::Nothing) = nothing
364364

365365
function stupdate1!(state::VarTable, change::StateUpdate)
366366
if !isa(change.var, Slot)

base/compiler/typeutils.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -204,18 +204,18 @@ end
204204
# unioncomplexity estimates the number of calls to `tmerge` to obtain the given type by
205205
# counting the Union instances, taking also into account those hidden in a Tuple or UnionAll
206206
function unioncomplexity(u::Union)
207-
return unioncomplexity(u.a) + unioncomplexity(u.b) + 1
207+
return unioncomplexity(u.a)::Int + unioncomplexity(u.b)::Int + 1
208208
end
209209
function unioncomplexity(t::DataType)
210210
t.name === Tuple.name || isvarargtype(t) || return 0
211211
c = 0
212212
for ti in t.parameters
213-
c = max(c, unioncomplexity(ti))
213+
c = max(c, unioncomplexity(ti)::Int)
214214
end
215215
return c
216216
end
217-
unioncomplexity(u::UnionAll) = max(unioncomplexity(u.body), unioncomplexity(u.var.ub))
218-
unioncomplexity(t::Core.TypeofVararg) = isdefined(t, :T) ? unioncomplexity(t.T) : 0
217+
unioncomplexity(u::UnionAll) = max(unioncomplexity(u.body)::Int, unioncomplexity(u.var.ub)::Int)
218+
unioncomplexity(t::Core.TypeofVararg) = isdefined(t, :T) ? unioncomplexity(t.T)::Int : 0
219219
unioncomplexity(@nospecialize(x)) = 0
220220

221221
function improvable_via_constant_propagation(@nospecialize(t))

0 commit comments

Comments
 (0)