Skip to content

Commit 590a384

Browse files
committed
optimizer: fully support inlining of union-split, partially constant-prop' callsite (#43347)
Makes full use of constant-propagation, by addressing this [TODO](https://github.com/JuliaLang/julia/blob/00734c5fd045316a00d287ca2c0ec1a2eef6e4d1/base/compiler/ssair/inlining.jl#L1212). Here is a performance improvement from #43287: ```julia ulia> using BenchmarkTools julia> X = rand(ComplexF32, 64, 64); julia> dst = reinterpret(reshape, Float32, X); julia> src = copy(dst); julia> @Btime copyto!($dst, $src); 50.819 μs (1 allocation: 32 bytes) # v1.6.4 41.081 μs (0 allocations: 0 bytes) # this commit ``` fixes #43287
1 parent 85fc5c9 commit 590a384

File tree

4 files changed

+133
-115
lines changed

4 files changed

+133
-115
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
156156
# by constant analysis, but let's create `ConstCallInfo` if there has been any successful
157157
# constant propagation happened since other consumers may be interested in this
158158
if any_const_result && seen == napplicable
159+
@assert napplicable == nmatches(info) == length(const_results)
159160
info = ConstCallInfo(info, const_results)
160161
end
161162

base/compiler/ssair/inlining.jl

Lines changed: 98 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -675,24 +675,17 @@ function rewrite_apply_exprargs!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::
675675
new_stmt = Expr(:call, argexprs[2], def, state...)
676676
state1 = insert_node!(ir, idx, NewInstruction(new_stmt, call.rt))
677677
new_sig = with_atype(call_sig(ir, new_stmt)::Signature)
678-
info = call.info
679-
handled = false
680-
if isa(info, ConstCallInfo)
681-
if maybe_handle_const_call!(
682-
ir, state1.id, new_stmt, info, new_sig,
683-
istate, false, todo)
684-
handled = true
685-
else
686-
info = info.call
687-
end
688-
end
689-
if !handled && (isa(info, MethodMatchInfo) || isa(info, UnionSplitInfo))
690-
info = isa(info, MethodMatchInfo) ?
691-
MethodMatchInfo[info] : info.matches
678+
new_info = call.info
679+
if isa(new_info, ConstCallInfo)
680+
handle_const_call!(
681+
ir, state1.id, new_stmt, new_info,
682+
new_sig, istate, todo)
683+
elseif isa(new_info, MethodMatchInfo) || isa(new_info, UnionSplitInfo)
684+
new_infos = isa(new_info, MethodMatchInfo) ? MethodMatchInfo[new_info] : new_info.matches
692685
# See if we can inline this call to `iterate`
693686
analyze_single_call!(
694687
ir, todo, state1.id, new_stmt,
695-
new_sig, info, istate)
688+
new_sig, new_infos, istate)
696689
end
697690
if i != length(thisarginfo.each)
698691
valT = getfield_tfunc(call.rt, Const(1))
@@ -910,7 +903,9 @@ function iterate(split::UnionSplitSignature, state::Vector{Int}...)
910903
return (sig, state)
911904
end
912905

913-
function handle_single_case!(ir::IRCode, stmt::Expr, idx::Int, @nospecialize(case), isinvoke::Bool, todo::Vector{Pair{Int, Any}})
906+
function handle_single_case!(
907+
ir::IRCode, stmt::Expr, idx::Int, @nospecialize(case),
908+
todo::Vector{Pair{Int, Any}}, isinvoke::Bool = false)
914909
if isa(case, ConstantCase)
915910
ir[SSAValue(idx)] = case.val
916911
elseif isa(case, MethodInstance)
@@ -1086,13 +1081,13 @@ function inline_invoke!(ir::IRCode, idx::Int, sig::Signature, (; match, result):
10861081
validate_sparams(mi.sparam_vals) || return nothing
10871082
if argtypes_to_type(atypes) <: mi.def.sig
10881083
state.mi_cache !== nothing && (item = resolve_todo(item, state))
1089-
handle_single_case!(ir, stmt, idx, item, true, todo)
1084+
handle_single_case!(ir, stmt, idx, item, todo, true)
10901085
return nothing
10911086
end
10921087
end
10931088

10941089
result = analyze_method!(match, atypes, state)
1095-
handle_single_case!(ir, stmt, idx, result, true, todo)
1090+
handle_single_case!(ir, stmt, idx, result, todo, true)
10961091
return nothing
10971092
end
10981093

@@ -1200,49 +1195,39 @@ function process_simple!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int, sta
12001195
return sig
12011196
end
12021197

1203-
# TODO inline non-`isdispatchtuple`, union-split callsites
1198+
# TODO inline non-`isdispatchtuple`, union-split callsites?
12041199
function analyze_single_call!(
12051200
ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int, @nospecialize(stmt),
1206-
(; atypes, atype)::Signature, infos::Vector{MethodMatchInfo}, state::InliningState)
1201+
sig::Signature, infos::Vector{MethodMatchInfo}, state::InliningState)
1202+
(; atypes, atype) = sig
12071203
cases = InliningCase[]
12081204
local signature_union = Bottom
12091205
local only_method = nothing # keep track of whether there is one matching method
1210-
local meth
1206+
local meth::MethodLookupResult
12111207
local fully_covered = true
12121208
for i in 1:length(infos)
1213-
info = infos[i]
1214-
meth = info.results
1209+
meth = infos[i].results
12151210
if meth.ambig
12161211
# Too many applicable methods
12171212
# Or there is a (partial?) ambiguity
1218-
return
1213+
return nothing
12191214
elseif length(meth) == 0
12201215
# No applicable methods; try next union split
12211216
continue
1222-
elseif length(meth) == 1 && only_method !== false
1223-
if only_method === nothing
1224-
only_method = meth[1].method
1225-
elseif only_method !== meth[1].method
1217+
else
1218+
if length(meth) == 1 && only_method !== false
1219+
if only_method === nothing
1220+
only_method = meth[1].method
1221+
elseif only_method !== meth[1].method
1222+
only_method = false
1223+
end
1224+
else
12261225
only_method = false
12271226
end
1228-
else
1229-
only_method = false
12301227
end
12311228
for match in meth
1232-
spec_types = match.spec_types
1233-
signature_union = Union{signature_union, spec_types}
1234-
if !isdispatchtuple(spec_types)
1235-
fully_covered = false
1236-
continue
1237-
end
1238-
item = analyze_method!(match, atypes, state)
1239-
if item === nothing
1240-
fully_covered = false
1241-
continue
1242-
elseif _any(case->case.sig === spec_types, cases)
1243-
continue
1244-
end
1245-
push!(cases, InliningCase(spec_types, item))
1229+
signature_union = Union{signature_union, match.spec_types}
1230+
fully_covered &= handle_match!(match, atypes, state, cases)
12461231
end
12471232
end
12481233

@@ -1253,9 +1238,8 @@ function analyze_single_call!(
12531238
if length(infos) > 1
12541239
(metharg, methsp) = ccall(:jl_type_intersection_with_env, Any, (Any, Any),
12551240
atype, only_method.sig)::SimpleVector
1256-
match = MethodMatch(metharg, methsp, only_method, true)
1241+
match = MethodMatch(metharg, methsp::SimpleVector, only_method, true)
12571242
else
1258-
meth = meth::MethodLookupResult
12591243
@assert length(meth) == 1
12601244
match = meth[1]
12611245
end
@@ -1268,46 +1252,41 @@ function analyze_single_call!(
12681252
fully_covered = false
12691253
end
12701254

1271-
# If we only have one case and that case is fully covered, we may either
1272-
# be able to do the inlining now (for constant cases), or push it directly
1273-
# onto the todo list
1274-
if fully_covered && length(cases) == 1
1275-
handle_single_case!(ir, stmt, idx, cases[1].item, false, todo)
1276-
elseif length(cases) > 0
1277-
push!(todo, idx=>UnionSplit(fully_covered, atype, cases))
1278-
end
1279-
return nothing
1255+
handle_cases!(ir, idx, stmt, sig, cases, fully_covered, todo)
12801256
end
12811257

1282-
# try to create `InliningCase`s using constant-prop'ed results
1283-
# currently it works only when constant-prop' succeeded for all (union-split) signatures
1284-
# TODO use any of constant-prop'ed results, and leave the other unhandled cases to later
1285-
# TODO this function contains a lot of duplications with `analyze_single_call!`, factor them out
1286-
function maybe_handle_const_call!(
1287-
ir::IRCode, idx::Int, stmt::Expr, (; results)::ConstCallInfo, (; atypes, atype)::Signature,
1288-
state::InliningState, isinvoke::Bool, todo::Vector{Pair{Int, Any}})
1289-
cases = InliningCase[] # TODO avoid this allocation for single cases ?
1258+
# similar to `analyze_single_call!`, but with constant results
1259+
function handle_const_call!(
1260+
ir::IRCode, idx::Int, stmt::Expr, cinfo::ConstCallInfo,
1261+
sig::Signature, state::InliningState, todo::Vector{Pair{Int, Any}})
1262+
(; atypes, atype) = sig
1263+
(; call, results) = cinfo
1264+
infos = isa(call, MethodMatchInfo) ? MethodMatchInfo[call] : call.matches
1265+
cases = InliningCase[]
12901266
local fully_covered = true
12911267
local signature_union = Bottom
1292-
for result in results
1293-
isa(result, InferenceResult) || return false
1294-
(; mi) = item = InliningTodo(result, atypes)
1295-
spec_types = mi.specTypes
1296-
signature_union = Union{signature_union, spec_types}
1297-
if !isdispatchtuple(spec_types)
1298-
fully_covered = false
1299-
continue
1300-
end
1301-
if !validate_sparams(mi.sparam_vals)
1302-
fully_covered = false
1268+
local j = 0
1269+
for i in 1:length(infos)
1270+
meth = infos[i].results
1271+
if meth.ambig
1272+
# Too many applicable methods
1273+
# Or there is a (partial?) ambiguity
1274+
return nothing
1275+
elseif length(meth) == 0
1276+
# No applicable methods; try next union split
13031277
continue
13041278
end
1305-
state.mi_cache !== nothing && (item = resolve_todo(item, state))
1306-
if item === nothing
1307-
fully_covered = false
1308-
continue
1279+
for match in meth
1280+
j += 1
1281+
result = results[j]
1282+
if result === nothing
1283+
signature_union = Union{signature_union, match.spec_types}
1284+
fully_covered &= handle_match!(match, atypes, state, cases)
1285+
else
1286+
signature_union = Union{signature_union, result.linfo.specTypes}
1287+
fully_covered &= handle_const_result!(result, atypes, state, cases)
1288+
end
13091289
end
1310-
push!(cases, InliningCase(spec_types, item))
13111290
end
13121291

13131292
# if the signature is fully covered and there is only one applicable method,
@@ -1316,25 +1295,54 @@ function maybe_handle_const_call!(
13161295
if length(cases) == 0 && length(results) == 1
13171296
(; mi) = item = InliningTodo(results[1]::InferenceResult, atypes)
13181297
state.mi_cache !== nothing && (item = resolve_todo(item, state))
1319-
validate_sparams(mi.sparam_vals) || return true
1320-
item === nothing && return true
1298+
validate_sparams(mi.sparam_vals) || return nothing
1299+
item === nothing && return nothing
13211300
push!(cases, InliningCase(mi.specTypes, item))
13221301
fully_covered = true
13231302
end
13241303
else
13251304
fully_covered = false
13261305
end
13271306

1307+
handle_cases!(ir, idx, stmt, sig, cases, fully_covered, todo)
1308+
end
1309+
1310+
function handle_match!(
1311+
match::MethodMatch, argtypes::Vector{Any}, state::InliningState,
1312+
cases::Vector{InliningCase})
1313+
spec_types = match.spec_types
1314+
isdispatchtuple(spec_types) || return false
1315+
item = analyze_method!(match, argtypes, state)
1316+
item === nothing && return false
1317+
_any(case->case.sig === spec_types, cases) && return true
1318+
push!(cases, InliningCase(spec_types, item))
1319+
return true
1320+
end
1321+
1322+
function handle_const_result!(
1323+
result::InferenceResult, argtypes::Vector{Any}, state::InliningState,
1324+
cases::Vector{InliningCase})
1325+
(; mi) = item = InliningTodo(result, argtypes)
1326+
spec_types = mi.specTypes
1327+
isdispatchtuple(spec_types) || return false
1328+
validate_sparams(mi.sparam_vals) || return false
1329+
state.mi_cache !== nothing && (item = resolve_todo(item, state))
1330+
item === nothing && return false
1331+
push!(cases, InliningCase(spec_types, item))
1332+
return true
1333+
end
1334+
1335+
function handle_cases!(ir::IRCode, idx::Int, stmt::Expr, sig::Signature,
1336+
cases::Vector{InliningCase}, fully_covered::Bool, todo::Vector{Pair{Int, Any}})
13281337
# If we only have one case and that case is fully covered, we may either
13291338
# be able to do the inlining now (for constant cases), or push it directly
13301339
# onto the todo list
13311340
if fully_covered && length(cases) == 1
1332-
handle_single_case!(ir, stmt, idx, cases[1].item, isinvoke, todo)
1341+
handle_single_case!(ir, stmt, idx, cases[1].item, todo)
13331342
elseif length(cases) > 0
1334-
isinvoke && rewrite_invoke_exprargs!(stmt)
1335-
push!(todo, idx=>UnionSplit(fully_covered, atype, cases))
1343+
push!(todo, idx=>UnionSplit(fully_covered, sig.atype, cases))
13361344
end
1337-
return true
1345+
return nothing
13381346
end
13391347

13401348
function handle_const_opaque_closure_call!(
@@ -1346,7 +1354,7 @@ function handle_const_opaque_closure_call!(
13461354
isdispatchtuple(item.mi.specTypes) || return
13471355
validate_sparams(item.mi.sparam_vals) || return
13481356
state.mi_cache !== nothing && (item = resolve_todo(item, state))
1349-
handle_single_case!(ir, stmt, idx, item, false, todo)
1357+
handle_single_case!(ir, stmt, idx, item, todo)
13501358
return nothing
13511359
end
13521360

@@ -1371,9 +1379,8 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState)
13711379
ir.stmts[idx][:flag] |= IR_FLAG_EFFECT_FREE
13721380
info = info.info
13731381
end
1374-
1375-
# Inference determined this couldn't be analyzed. Don't question it.
13761382
if info === false
1383+
# Inference determined this couldn't be analyzed. Don't question it.
13771384
continue
13781385
end
13791386

@@ -1386,16 +1393,15 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState)
13861393
sig, state, todo)
13871394
continue
13881395
else
1389-
maybe_handle_const_call!(
1396+
handle_const_call!(
13901397
ir, idx, stmt, info, sig,
1391-
state, sig.f === Core.invoke, todo) && continue
1398+
state, todo)
13921399
end
1393-
info = info.call # cascade to the non-constant handling
13941400
end
13951401

13961402
if isa(info, OpaqueClosureCallInfo)
13971403
item = analyze_method!(info.match, sig.atypes, state)
1398-
handle_single_case!(ir, stmt, idx, item, false, todo)
1404+
handle_single_case!(ir, stmt, idx, item, todo)
13991405
continue
14001406
end
14011407

base/compiler/stmtinfo.jl

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,27 @@ struct UnionSplitInfo
4040
matches::Vector{MethodMatchInfo}
4141
end
4242

43+
nmatches(info::MethodMatchInfo) = length(info.results)
44+
function nmatches(info::UnionSplitInfo)
45+
n = 0
46+
for mminfo in info.matches
47+
n += nmatches(mminfo)
48+
end
49+
return n
50+
end
51+
52+
"""
53+
info::ConstCallInfo
54+
55+
The precision of this call was improved using constant information.
56+
In addition to the original call information `info.call`, this info also keeps
57+
the inference results with constant information `info.results::Vector{Union{Nothing,InferenceResult}}`.
58+
"""
59+
struct ConstCallInfo
60+
call::Any
61+
results::Vector{Union{Nothing,InferenceResult}}
62+
end
63+
4364
"""
4465
struct CallMeta
4566
@@ -88,18 +109,6 @@ struct UnionSplitApplyCallInfo
88109
infos::Vector{ApplyCallInfo}
89110
end
90111

91-
"""
92-
struct ConstCallInfo
93-
94-
Precision for this call was improved using constant information. This info
95-
keeps a reference to the result that was used (or created for these)
96-
constant information.
97-
"""
98-
struct ConstCallInfo
99-
call::Any
100-
results::Vector{Union{Nothing,InferenceResult}}
101-
end
102-
103112
"""
104113
struct InvokeCallInfo
105114

test/compiler/inline.jl

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -438,17 +438,19 @@ end
438438
import Base: @constprop
439439

440440
# test union-split callsite with successful and unsuccessful constant-prop' results
441-
@constprop :aggressive @inline f42840(xs, a::Int) = xs[a] # should be successful, and inlined
442-
@constprop :none @noinline f42840(xs::AbstractVector, a::Int) = xs[a] # should be unsuccessful, but still statically resolved
443-
let src = code_typed1((Union{Tuple{Int,Int,Int}, Vector{Int}},)) do xs
444-
f42840(xs, 2)
445-
end
446-
@test count(src.code) do @nospecialize x
447-
iscall((src, getfield), x) # `(xs::Tuple{Int,Int,Int})[a::Const(2)]` => `getfield(xs, 2)`
448-
end == 1
449-
@test count(src.code) do @nospecialize x
450-
isinvoke(:f42840, x)
451-
end == 1
441+
# (also for https://github.com/JuliaLang/julia/issues/43287)
442+
@constprop :aggressive @inline f42840(cond::Bool, xs::Tuple, a::Int) = # should be successful, and inlined with constant prop' result
443+
cond ? xs[a] : length(xs)
444+
@constprop :none @noinline f42840(::Bool, xs::AbstractVector, a::Int) = # should be unsuccessful, but still statically resolved
445+
xs[a]
446+
let src = code_typed((Union{Tuple{Int,Int,Int}, Vector{Int}},)) do xs
447+
f42840(true, xs, 2)
448+
end |> only |> first
449+
# `f43287(true, xs::Tuple{Int,Int,Int}, 2)` => `getfield(xs, 2)`
450+
# `f43287(true, xs::Vector{Int}, 2)` => `:invoke f43287(true, xs, 2)`
451+
@test count(x->iscall((src, getfield),x), src.code) == 1
452+
@test count(x->isinvoke(:length, x), src.code) == 0
453+
@test count(x->isinvoke(:f42840, x), src.code) == 1
452454
end
453455
# a bit weird, but should handle this kind of case as well
454456
@constprop :aggressive @noinline g42840(xs, a::Int) = xs[a] # should be successful, but only statically resolved

0 commit comments

Comments
 (0)