Skip to content

Commit 749b000

Browse files
committed
optimizer: fully support inlining of union-split, partially constant-prop' callsite (#43347)
Makes full use of constant-propagation, by addressing this [TODO](https://github.com/JuliaLang/julia/blob/00734c5fd045316a00d287ca2c0ec1a2eef6e4d1/base/compiler/ssair/inlining.jl#L1212). Here is a performance improvement from #43287:
```julia
julia> using BenchmarkTools
julia> X = rand(ComplexF32, 64, 64);
julia> dst = reinterpret(reshape, Float32, X);
julia> src = copy(dst);
julia> @btime copyto!($dst, $src);
  50.819 μs (1 allocation: 32 bytes) # v1.6.4
  41.081 μs (0 allocations: 0 bytes) # this commit
```
fixes #43287
1 parent 85fc5c9 commit 749b000

File tree

4 files changed

+121
-105
lines changed

4 files changed

+121
-105
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
156156
# by constant analysis, but let's create `ConstCallInfo` if there has been any successful
157157
# constant propagation happened since other consumers may be interested in this
158158
if any_const_result && seen == napplicable
159+
@assert napplicable == nmatches(info) == length(const_results)
159160
info = ConstCallInfo(info, const_results)
160161
end
161162

base/compiler/ssair/inlining.jl

Lines changed: 86 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -675,24 +675,17 @@ function rewrite_apply_exprargs!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::
675675
new_stmt = Expr(:call, argexprs[2], def, state...)
676676
state1 = insert_node!(ir, idx, NewInstruction(new_stmt, call.rt))
677677
new_sig = with_atype(call_sig(ir, new_stmt)::Signature)
678-
info = call.info
679-
handled = false
680-
if isa(info, ConstCallInfo)
681-
if maybe_handle_const_call!(
682-
ir, state1.id, new_stmt, info, new_sig,
683-
istate, false, todo)
684-
handled = true
685-
else
686-
info = info.call
687-
end
688-
end
689-
if !handled && (isa(info, MethodMatchInfo) || isa(info, UnionSplitInfo))
690-
info = isa(info, MethodMatchInfo) ?
691-
MethodMatchInfo[info] : info.matches
678+
new_info = call.info
679+
if isa(new_info, ConstCallInfo)
680+
handle_const_call!(
681+
ir, state1.id, new_stmt, new_info,
682+
new_sig, istate, todo)
683+
elseif isa(new_info, MethodMatchInfo) || isa(new_info, UnionSplitInfo)
684+
new_infos = isa(new_info, MethodMatchInfo) ? MethodMatchInfo[new_info] : new_info.matches
692685
# See if we can inline this call to `iterate`
693686
analyze_single_call!(
694687
ir, todo, state1.id, new_stmt,
695-
new_sig, info, istate)
688+
new_sig, new_infos, istate)
696689
end
697690
if i != length(thisarginfo.each)
698691
valT = getfield_tfunc(call.rt, Const(1))
@@ -1200,49 +1193,38 @@ function process_simple!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int, sta
12001193
return sig
12011194
end
12021195

1203-
# TODO inline non-`isdispatchtuple`, union-split callsites
1196+
# TODO inline non-`isdispatchtuple`, union-split callsites?
12041197
function analyze_single_call!(
12051198
ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int, @nospecialize(stmt),
12061199
(; atypes, atype)::Signature, infos::Vector{MethodMatchInfo}, state::InliningState)
12071200
cases = InliningCase[]
12081201
local signature_union = Bottom
12091202
local only_method = nothing # keep track of whether there is one matching method
1210-
local meth
1203+
local meth::MethodLookupResult
12111204
local fully_covered = true
12121205
for i in 1:length(infos)
1213-
info = infos[i]
1214-
meth = info.results
1206+
meth = infos[i].results
12151207
if meth.ambig
12161208
# Too many applicable methods
12171209
# Or there is a (partial?) ambiguity
1218-
return
1210+
return nothing
12191211
elseif length(meth) == 0
12201212
# No applicable methods; try next union split
12211213
continue
1222-
elseif length(meth) == 1 && only_method !== false
1223-
if only_method === nothing
1224-
only_method = meth[1].method
1225-
elseif only_method !== meth[1].method
1214+
else
1215+
if length(meth) == 1 && only_method !== false
1216+
if only_method === nothing
1217+
only_method = meth[1].method
1218+
elseif only_method !== meth[1].method
1219+
only_method = false
1220+
end
1221+
else
12261222
only_method = false
12271223
end
1228-
else
1229-
only_method = false
12301224
end
12311225
for match in meth
1232-
spec_types = match.spec_types
1233-
signature_union = Union{signature_union, spec_types}
1234-
if !isdispatchtuple(spec_types)
1235-
fully_covered = false
1236-
continue
1237-
end
1238-
item = analyze_method!(match, atypes, state)
1239-
if item === nothing
1240-
fully_covered = false
1241-
continue
1242-
elseif _any(case->case.sig === spec_types, cases)
1243-
continue
1244-
end
1245-
push!(cases, InliningCase(spec_types, item))
1226+
signature_union = Union{signature_union, match.spec_types}
1227+
fully_covered &= handle_match!(match, atypes, state, cases)
12461228
end
12471229
end
12481230

@@ -1253,9 +1235,8 @@ function analyze_single_call!(
12531235
if length(infos) > 1
12541236
(metharg, methsp) = ccall(:jl_type_intersection_with_env, Any, (Any, Any),
12551237
atype, only_method.sig)::SimpleVector
1256-
match = MethodMatch(metharg, methsp, only_method, true)
1238+
match = MethodMatch(metharg, methsp::SimpleVector, only_method, true)
12571239
else
1258-
meth = meth::MethodLookupResult
12591240
@assert length(meth) == 1
12601241
match = meth[1]
12611242
end
@@ -1268,46 +1249,41 @@ function analyze_single_call!(
12681249
fully_covered = false
12691250
end
12701251

1271-
# If we only have one case and that case is fully covered, we may either
1272-
# be able to do the inlining now (for constant cases), or push it directly
1273-
# onto the todo list
1274-
if fully_covered && length(cases) == 1
1275-
handle_single_case!(ir, stmt, idx, cases[1].item, false, todo)
1276-
elseif length(cases) > 0
1277-
push!(todo, idx=>UnionSplit(fully_covered, atype, cases))
1278-
end
1279-
return nothing
1252+
handle_cases!(ir, idx, stmt, sig, cases, fully_covered, todo)
12801253
end
12811254

1282-
# try to create `InliningCase`s using constant-prop'ed results
1283-
# currently it works only when constant-prop' succeeded for all (union-split) signatures
1284-
# TODO use any of constant-prop'ed results, and leave the other unhandled cases to later
1285-
# TODO this function contains a lot of duplications with `analyze_single_call!`, factor them out
1286-
function maybe_handle_const_call!(
1287-
ir::IRCode, idx::Int, stmt::Expr, (; results)::ConstCallInfo, (; atypes, atype)::Signature,
1288-
state::InliningState, isinvoke::Bool, todo::Vector{Pair{Int, Any}})
1289-
cases = InliningCase[] # TODO avoid this allocation for single cases ?
1255+
# similar to `analyze_single_call!`, but with constant results
1256+
function handle_const_call!(
1257+
ir::IRCode, idx::Int, stmt::Expr, cinfo::ConstCallInfo,
1258+
sig::Signature, state::InliningState, todo::Vector{Pair{Int, Any}})
1259+
(; atypes, atype) = sig
1260+
(; call, results) = cinfo
1261+
infos = isa(call, MethodMatchInfo) ? MethodMatchInfo[call] : call.matches
1262+
cases = InliningCase[]
12901263
local fully_covered = true
12911264
local signature_union = Bottom
1292-
for result in results
1293-
isa(result, InferenceResult) || return false
1294-
(; mi) = item = InliningTodo(result, atypes)
1295-
spec_types = mi.specTypes
1296-
signature_union = Union{signature_union, spec_types}
1297-
if !isdispatchtuple(spec_types)
1298-
fully_covered = false
1299-
continue
1300-
end
1301-
if !validate_sparams(mi.sparam_vals)
1302-
fully_covered = false
1265+
local j = 0
1266+
for i in 1:length(infos)
1267+
meth = infos[i].results
1268+
if meth.ambig
1269+
# Too many applicable methods
1270+
# Or there is a (partial?) ambiguity
1271+
return nothing
1272+
elseif length(meth) == 0
1273+
# No applicable methods; try next union split
13031274
continue
13041275
end
1305-
state.mi_cache !== nothing && (item = resolve_todo(item, state))
1306-
if item === nothing
1307-
fully_covered = false
1308-
continue
1276+
for match in meth
1277+
j += 1
1278+
result = results[j]
1279+
if result === nothing
1280+
signature_union = Union{signature_union, match.spec_types}
1281+
fully_covered &= handle_match!(match, atypes, state, cases)
1282+
else
1283+
signature_union = Union{signature_union, result.linfo.specTypes}
1284+
fully_covered &= handle_const_result!(result, atypes, state, cases)
1285+
end
13091286
end
1310-
push!(cases, InliningCase(spec_types, item))
13111287
end
13121288

13131289
# if the signature is fully covered and there is only one applicable method,
@@ -1316,25 +1292,54 @@ function maybe_handle_const_call!(
13161292
if length(cases) == 0 && length(results) == 1
13171293
(; mi) = item = InliningTodo(results[1]::InferenceResult, atypes)
13181294
state.mi_cache !== nothing && (item = resolve_todo(item, state))
1319-
validate_sparams(mi.sparam_vals) || return true
1320-
item === nothing && return true
1295+
validate_sparams(mi.sparam_vals) || return nothing
1296+
item === nothing && return nothing
13211297
push!(cases, InliningCase(mi.specTypes, item))
13221298
fully_covered = true
13231299
end
13241300
else
13251301
fully_covered = false
13261302
end
13271303

1304+
handle_cases!(ir, idx, stmt, sig, cases, fully_covered, todo)
1305+
end
1306+
1307+
function handle_match!(
1308+
match::MethodMatch, argtypes::Vector{Any}, state::InliningState,
1309+
cases::Vector{InliningCase})
1310+
spec_types = match.spec_types
1311+
isdispatchtuple(spec_types) || return false
1312+
item = analyze_method!(match, argtypes, state)
1313+
item === nothing && return false
1314+
_any(case->case.sig === spec_types, cases) && return true
1315+
push!(cases, InliningCase(spec_types, item))
1316+
return true
1317+
end
1318+
1319+
function handle_const_result!(
1320+
result::InferenceResult, argtypes::Vector{Any}, state::InliningState,
1321+
cases::Vector{InliningCase})
1322+
(; mi) = item = InliningTodo(result, argtypes)
1323+
spec_types = mi.specTypes
1324+
isdispatchtuple(spec_types) || return false
1325+
validate_sparams(mi.sparam_vals) || return false
1326+
state.mi_cache !== nothing && (item = resolve_todo(item, state))
1327+
item === nothing && return false
1328+
push!(cases, InliningCase(spec_types, item))
1329+
return true
1330+
end
1331+
1332+
function handle_cases!(ir::IRCode, idx::Int, stmt::Expr, sig::Signature,
1333+
cases::Vector{InliningCase}, fully_covered::Bool, todo::Vector{Pair{Int, Any}})
13281334
# If we only have one case and that case is fully covered, we may either
13291335
# be able to do the inlining now (for constant cases), or push it directly
13301336
# onto the todo list
13311337
if fully_covered && length(cases) == 1
13321338
handle_single_case!(ir, stmt, idx, cases[1].item, isinvoke, todo)
13331339
elseif length(cases) > 0
1334-
isinvoke && rewrite_invoke_exprargs!(stmt)
1335-
push!(todo, idx=>UnionSplit(fully_covered, atype, cases))
1340+
push!(todo, idx=>UnionSplit(fully_covered, sig.atype, cases))
13361341
end
1337-
return true
1342+
return nothing
13381343
end
13391344

13401345
function handle_const_opaque_closure_call!(
@@ -1371,9 +1376,8 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState)
13711376
ir.stmts[idx][:flag] |= IR_FLAG_EFFECT_FREE
13721377
info = info.info
13731378
end
1374-
1375-
# Inference determined this couldn't be analyzed. Don't question it.
13761379
if info === false
1380+
# Inference determined this couldn't be analyzed. Don't question it.
13771381
continue
13781382
end
13791383

base/compiler/stmtinfo.jl

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,27 @@ struct UnionSplitInfo
4040
matches::Vector{MethodMatchInfo}
4141
end
4242

43+
nmatches(info::MethodMatchInfo) = length(info.results)
44+
function nmatches(info::UnionSplitInfo)
45+
n = 0
46+
for mminfo in info.matches
47+
n += nmatches(mminfo)
48+
end
49+
return n
50+
end
51+
52+
"""
53+
info::ConstCallInfo
54+
55+
The precision of this call was improved using constant information.
56+
In addition to the original call information `info.call`, this info also keeps
57+
the inference results with constant information `info.results::Vector{Union{Nothing,InferenceResult}}`.
58+
"""
59+
struct ConstCallInfo
60+
call::Union{MethodMatchInfo,UnionSplitInfo}
61+
results::Vector{Union{Nothing,InferenceResult}}
62+
end
63+
4364
"""
4465
struct CallMeta
4566
@@ -88,18 +109,6 @@ struct UnionSplitApplyCallInfo
88109
infos::Vector{ApplyCallInfo}
89110
end
90111

91-
"""
92-
struct ConstCallInfo
93-
94-
Precision for this call was improved using constant information. This info
95-
keeps a reference to the result that was used (or created for these)
96-
constant information.
97-
"""
98-
struct ConstCallInfo
99-
call::Any
100-
results::Vector{Union{Nothing,InferenceResult}}
101-
end
102-
103112
"""
104113
struct InvokeCallInfo
105114

test/compiler/inline.jl

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -438,17 +438,19 @@ end
438438
import Base: @constprop
439439

440440
# test union-split callsite with successful and unsuccessful constant-prop' results
441-
@constprop :aggressive @inline f42840(xs, a::Int) = xs[a] # should be successful, and inlined
442-
@constprop :none @noinline f42840(xs::AbstractVector, a::Int) = xs[a] # should be unsuccessful, but still statically resolved
443-
let src = code_typed1((Union{Tuple{Int,Int,Int}, Vector{Int}},)) do xs
444-
f42840(xs, 2)
445-
end
446-
@test count(src.code) do @nospecialize x
447-
iscall((src, getfield), x) # `(xs::Tuple{Int,Int,Int})[a::Const(2)]` => `getfield(xs, 2)`
448-
end == 1
449-
@test count(src.code) do @nospecialize x
450-
isinvoke(:f42840, x)
451-
end == 1
441+
# (also for https://github.com/JuliaLang/julia/issues/43287)
442+
@constprop :aggressive @inline f42840(cond::Bool, xs::Tuple, a::Int) = # should be successful, and inlined with constant prop' result
443+
cond ? xs[a] : @noinline(length(xs))
444+
@constprop :none @noinline f42840(::Bool, xs::AbstractVector, a::Int) = # should be unsuccessful, but still statically resolved
445+
xs[a]
446+
let src = code_typed((Union{Tuple{Int,Int,Int}, Vector{Int}},)) do xs
447+
f42840(true, xs, 2)
448+
end |> only |> first
449+
# `f42840(true, xs::Tuple{Int,Int,Int}, 2)` => `getfield(xs, 2)`
450+
# `f42840(true, xs::Vector{Int}, 2)` => `:invoke f42840(true, xs, 2)`
451+
@test count(iscall((src, getfield)), src.code) == 1
452+
@test count(isinvoke(:length), src.code) == 0
453+
@test count(isinvoke(:f42840), src.code) == 1
452454
end
453455
# a bit weird, but should handle this kind of case as well
454456
@constprop :aggressive @noinline g42840(xs, a::Int) = xs[a] # should be successful, but only statically resolved

0 commit comments

Comments
 (0)