Skip to content

Commit a7e53d1

Browse files
committed
optimizer: fully support inlining of union-split, partially constant-prop' callsite
Makes full use of constant-propagation, by addressing this [TODO](https://github.com/JuliaLang/julia/blob/00734c5fd045316a00d287ca2c0ec1a2eef6e4d1/base/compiler/ssair/inlining.jl#L1212). Here is a performance improvement from #43287: ```julia ulia> using BenchmarkTools julia> X = rand(ComplexF32, 64, 64); julia> dst = reinterpret(reshape, Float32, X); julia> src = copy(dst); julia> @Btime copyto!($dst, $src); 50.819 μs (1 allocation: 32 bytes) # v1.6.4 41.081 μs (0 allocations: 0 bytes) # this commit ``` fixes #43287
1 parent 9bb0aeb commit a7e53d1

File tree

4 files changed

+113
-90
lines changed

4 files changed

+113
-90
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
139139
# by constant analysis, but let's create `ConstCallInfo` if there has been any successful
140140
# constant propagation happened since other consumers may be interested in this
141141
if any_const_result && seen == napplicable
142+
@assert napplicable == nmatches(info) == length(const_results)
142143
info = ConstCallInfo(info, const_results)
143144
end
144145

base/compiler/ssair/inlining.jl

Lines changed: 82 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -672,19 +672,16 @@ function rewrite_apply_exprargs!(
672672
new_sig = with_atype(call_sig(ir, new_stmt)::Signature)
673673
new_info = call.info
674674
if isa(new_info, ConstCallInfo)
675-
maybe_handle_const_call!(
675+
handle_const_call!(
676676
ir, state1.id, new_stmt, new_info, flag,
677-
new_sig, istate, todo) && @goto analyzed
678-
new_info = new_info.call # cascade to the non-constant handling
679-
end
680-
if isa(new_info, MethodMatchInfo) || isa(new_info, UnionSplitInfo)
677+
new_sig, istate, todo)
678+
elseif isa(new_info, MethodMatchInfo) || isa(new_info, UnionSplitInfo)
681679
new_infos = isa(new_info, MethodMatchInfo) ? MethodMatchInfo[new_info] : new_info.matches
682680
# See if we can inline this call to `iterate`
683681
analyze_single_call!(
684682
ir, state1.id, new_stmt, new_infos, flag,
685683
new_sig, istate, todo)
686684
end
687-
@label analyzed
688685
if i != length(thisarginfo.each)
689686
valT = getfield_tfunc(call.rt, Const(1))
690687
val_extracted = insert_node!(ir, idx, NewInstruction(
@@ -1126,139 +1123,150 @@ function process_simple!(ir::IRCode, idx::Int, state::InliningState, todo::Vecto
11261123
return stmt, sig
11271124
end
11281125

1129-
# TODO inline non-`isdispatchtuple`, union-split callsites
1126+
# TODO inline non-`isdispatchtuple`, union-split callsites?
11301127
function analyze_single_call!(
11311128
ir::IRCode, idx::Int, stmt::Expr, infos::Vector{MethodMatchInfo}, flag::UInt8,
11321129
sig::Signature, state::InliningState, todo::Vector{Pair{Int, Any}})
11331130
(; argtypes, atype) = sig
11341131
cases = InliningCase[]
11351132
local signature_union = Bottom
11361133
local only_method = nothing # keep track of whether there is one matching method
1137-
local meth
1134+
local meth::MethodLookupResult
11381135
local fully_covered = true
11391136
for i in 1:length(infos)
1140-
info = infos[i]
1141-
meth = info.results
1137+
meth = infos[i].results
11421138
if meth.ambig
11431139
# Too many applicable methods
11441140
# Or there is a (partial?) ambiguity
1145-
return
1141+
return nothing
11461142
elseif length(meth) == 0
11471143
# No applicable methods; try next union split
11481144
continue
1149-
elseif length(meth) == 1 && only_method !== false
1150-
if only_method === nothing
1151-
only_method = meth[1].method
1152-
elseif only_method !== meth[1].method
1145+
else
1146+
if length(meth) == 1 && only_method !== false
1147+
if only_method === nothing
1148+
only_method = meth[1].method
1149+
elseif only_method !== meth[1].method
1150+
only_method = false
1151+
end
1152+
else
11531153
only_method = false
11541154
end
1155-
else
1156-
only_method = false
11571155
end
11581156
for match in meth
1159-
spec_types = match.spec_types
1160-
signature_union = Union{signature_union, spec_types}
1161-
if !isdispatchtuple(spec_types)
1162-
fully_covered = false
1163-
continue
1164-
end
1165-
item = analyze_method!(match, argtypes, flag, state)
1166-
if item === nothing
1167-
fully_covered = false
1168-
continue
1169-
elseif _any(case->case.sig === spec_types, cases)
1170-
continue
1171-
end
1172-
push!(cases, InliningCase(spec_types, item))
1157+
signature_union = Union{signature_union, match.spec_types}
1158+
fully_covered &= handle_match!(match, argtypes, flag, state, cases)
11731159
end
11741160
end
11751161

1176-
# if the signature is fully or mostly covered and there is only one applicable method,
1162+
# if the signature is fully covered and there is only one applicable method,
11771163
# we can try to inline it even if the signature is not a dispatch tuple
11781164
if length(cases) == 0 && only_method isa Method
11791165
if length(infos) > 1
11801166
(metharg, methsp) = ccall(:jl_type_intersection_with_env, Any, (Any, Any),
11811167
atype, only_method.sig)::SimpleVector
11821168
match = MethodMatch(metharg, methsp::SimpleVector, only_method, true)
11831169
else
1184-
meth = meth::MethodLookupResult
11851170
@assert length(meth) == 1
11861171
match = meth[1]
11871172
end
11881173
item = analyze_method!(match, argtypes, flag, state)
1189-
item === nothing && return
1174+
item === nothing && return nothing
11901175
push!(cases, InliningCase(match.spec_types, item))
11911176
fully_covered = match.fully_covers
11921177
else
11931178
fully_covered &= atype <: signature_union
11941179
end
11951180

1196-
# If we only have one case and that case is fully covered, we may either
1197-
# be able to do the inlining now (for constant cases), or push it directly
1198-
# onto the todo list
1199-
if fully_covered && length(cases) == 1
1200-
handle_single_case!(ir, idx, stmt, cases[1].item, todo)
1201-
elseif length(cases) > 0
1202-
push!(todo, idx=>UnionSplit(fully_covered, atype, cases))
1203-
end
1204-
return nothing
1181+
handle_cases!(ir, idx, stmt, sig, cases, fully_covered, todo)
12051182
end
12061183

1207-
# try to create `InliningCase`s using constant-prop'ed results
1208-
# currently it works only when constant-prop' succeeded for all (union-split) signatures
1209-
# TODO use any of constant-prop'ed results, and leave the other unhandled cases to later
1210-
# TODO this function contains a lot of duplications with `analyze_single_call!`, factor them out
1211-
function maybe_handle_const_call!(
1212-
ir::IRCode, idx::Int, stmt::Expr, info::ConstCallInfo, flag::UInt8,
1184+
# similar to `analyze_single_call!`, but with constant results
1185+
function handle_const_call!(
1186+
ir::IRCode, idx::Int, stmt::Expr, cinfo::ConstCallInfo, flag::UInt8,
12131187
sig::Signature, state::InliningState, todo::Vector{Pair{Int, Any}})
12141188
(; argtypes, atype) = sig
1215-
results = info.results
1216-
cases = InliningCase[] # TODO avoid this allocation for single cases ?
1189+
(; call, results) = cinfo
1190+
infos = isa(call, MethodMatchInfo) ? MethodMatchInfo[call] : call.matches
1191+
cases = InliningCase[]
12171192
local fully_covered = true
12181193
local signature_union = Bottom
1219-
for result in results
1220-
isa(result, InferenceResult) || return false
1221-
(; mi) = item = InliningTodo(result, argtypes)
1222-
spec_types = mi.specTypes
1223-
signature_union = Union{signature_union, spec_types}
1224-
if !isdispatchtuple(spec_types)
1225-
fully_covered = false
1226-
continue
1227-
end
1228-
if !validate_sparams(mi.sparam_vals)
1229-
fully_covered = false
1194+
local j = 0
1195+
for i in 1:length(infos)
1196+
meth = infos[i].results
1197+
if meth.ambig
1198+
# Too many applicable methods
1199+
# Or there is a (partial?) ambiguity
1200+
return nothing
1201+
elseif length(meth) == 0
1202+
# No applicable methods; try next union split
12301203
continue
12311204
end
1232-
state.mi_cache !== nothing && (item = resolve_todo(item, state, flag))
1233-
if item === nothing
1234-
fully_covered = false
1235-
continue
1205+
for match in meth
1206+
j += 1
1207+
result = results[j]
1208+
if result === nothing
1209+
signature_union = Union{signature_union, match.spec_types}
1210+
fully_covered &= handle_match!(match, argtypes, flag, state, cases)
1211+
else
1212+
signature_union = Union{signature_union, result.linfo.specTypes}
1213+
fully_covered &= handle_const_result!(result, argtypes, flag, state, cases)
1214+
end
12361215
end
1237-
push!(cases, InliningCase(spec_types, item))
12381216
end
12391217

12401218
# if the signature is fully covered and there is only one applicable method,
12411219
# we can try to inline it even if the signature is not a dispatch tuple
12421220
if length(cases) == 0 && length(results) == 1
12431221
(; mi) = item = InliningTodo(results[1]::InferenceResult, argtypes)
12441222
state.mi_cache !== nothing && (item = resolve_todo(item, state, flag))
1245-
validate_sparams(mi.sparam_vals) || return true
1246-
item === nothing && return true
1223+
validate_sparams(mi.sparam_vals) || return nothing
1224+
item === nothing && return nothing
12471225
push!(cases, InliningCase(mi.specTypes, item))
12481226
fully_covered = atype <: mi.specTypes
12491227
else
12501228
fully_covered &= atype <: signature_union
12511229
end
12521230

1231+
handle_cases!(ir, idx, stmt, sig, cases, fully_covered, todo)
1232+
end
1233+
1234+
function handle_match!(
1235+
match::MethodMatch, argtypes::Vector{Any}, flag::UInt8, state::InliningState,
1236+
cases::Vector{InliningCase})
1237+
spec_types = match.spec_types
1238+
isdispatchtuple(spec_types) || return false
1239+
item = analyze_method!(match, argtypes, flag, state)
1240+
item === nothing && return false
1241+
_any(case->case.sig === spec_types, cases) && return true
1242+
push!(cases, InliningCase(spec_types, item))
1243+
return true
1244+
end
1245+
1246+
function handle_const_result!(
1247+
result::InferenceResult, argtypes::Vector{Any}, flag::UInt8, state::InliningState,
1248+
cases::Vector{InliningCase})
1249+
(; mi) = item = InliningTodo(result, argtypes)
1250+
spec_types = mi.specTypes
1251+
isdispatchtuple(spec_types) || return false
1252+
validate_sparams(mi.sparam_vals) || return false
1253+
state.mi_cache !== nothing && (item = resolve_todo(item, state, flag))
1254+
item === nothing && return false
1255+
push!(cases, InliningCase(spec_types, item))
1256+
return true
1257+
end
1258+
1259+
function handle_cases!(ir::IRCode, idx::Int, stmt::Expr, sig::Signature,
1260+
cases::Vector{InliningCase}, fully_covered::Bool, todo::Vector{Pair{Int, Any}})
12531261
# If we only have one case and that case is fully covered, we may either
12541262
# be able to do the inlining now (for constant cases), or push it directly
12551263
# onto the todo list
12561264
if fully_covered && length(cases) == 1
12571265
handle_single_case!(ir, idx, stmt, cases[1].item, todo)
12581266
elseif length(cases) > 0
1259-
push!(todo, idx=>UnionSplit(fully_covered, atype, cases))
1267+
push!(todo, idx=>UnionSplit(fully_covered, sig.atype, cases))
12601268
end
1261-
return true
1269+
return nothing
12621270
end
12631271

12641272
function handle_const_opaque_closure_call!(
@@ -1324,10 +1332,10 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState)
13241332
# if inference arrived here with constant-prop'ed result(s),
13251333
# we can perform a specialized analysis for just this case
13261334
if isa(info, ConstCallInfo)
1327-
maybe_handle_const_call!(
1335+
handle_const_call!(
13281336
ir, idx, stmt, info, flag,
1329-
sig, state, todo) && continue
1330-
info = info.call # cascade to the non-constant handling
1337+
sig, state, todo)
1338+
continue
13311339
end
13321340

13331341
# Ok, now figure out what method to call

base/compiler/stmtinfo.jl

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,27 @@ struct UnionSplitInfo
3838
matches::Vector{MethodMatchInfo}
3939
end
4040

41+
nmatches(info::MethodMatchInfo) = length(info.results)
42+
function nmatches(info::UnionSplitInfo)
43+
n = 0
44+
for mminfo in info.matches
45+
n += nmatches(mminfo)
46+
end
47+
return n
48+
end
49+
50+
"""
51+
info::ConstCallInfo
52+
53+
The precision of this call was improved using constant information.
54+
In addition to the original call information `info.call`, this info also keeps
55+
the inference results with constant information `info.results::Vector{Union{Nothing,InferenceResult}}`.
56+
"""
57+
struct ConstCallInfo
58+
call::Union{MethodMatchInfo,UnionSplitInfo}
59+
results::Vector{Union{Nothing,InferenceResult}}
60+
end
61+
4162
"""
4263
info::MethodResultPure
4364
@@ -92,18 +113,6 @@ struct UnionSplitApplyCallInfo
92113
infos::Vector{ApplyCallInfo}
93114
end
94115

95-
"""
96-
info::ConstCallInfo
97-
98-
The precision of this call was improved using constant information.
99-
In addition to the original call information `info.call`, this info also keeps
100-
the inference results with constant information `info.results::Vector{Union{Nothing,InferenceResult}}`.
101-
"""
102-
struct ConstCallInfo
103-
call::Union{MethodMatchInfo,UnionSplitInfo}
104-
results::Vector{Union{Nothing,InferenceResult}}
105-
end
106-
107116
"""
108117
info::InvokeCallInfo
109118

test/compiler/inline.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -758,13 +758,18 @@ end
758758
import Base: @constprop
759759

760760
# test union-split callsite with successful and unsuccessful constant-prop' results
761-
@constprop :aggressive @inline f42840(xs, a::Int) = xs[a] # should be successful, and inlined
762-
@constprop :none @noinline f42840(xs::AbstractVector, a::Int) = xs[a] # should be unsuccessful, but still statically resolved
761+
# (also for https://github.com/JuliaLang/julia/issues/43287)
762+
@constprop :aggressive @inline f42840(cond::Bool, xs::Tuple, a::Int) = # should be successful, and inlined with constant prop' result
763+
cond ? xs[a] : @noinline(length(xs))
764+
@constprop :none @noinline f42840(::Bool, xs::AbstractVector, a::Int) = # should be unsuccessful, but still statically resolved
765+
xs[a]
763766
let src = code_typed((Union{Tuple{Int,Int,Int}, Vector{Int}},)) do xs
764-
f42840(xs, 2)
767+
f42840(true, xs, 2)
765768
end |> only |> first
766-
# `(xs::Tuple{Int,Int,Int})[a::Const(2)]` => `getfield(xs, 2)`
769+
# `f43287(true, xs::Tuple{Int,Int,Int}, 2)` => `getfield(xs, 2)`
770+
# `f43287(true, xs::Vector{Int}, 2)` => `:invoke f43287(true, xs, 2)`
767771
@test count(iscall((src, getfield)), src.code) == 1
772+
@test count(isinvoke(:length), src.code) == 0
768773
@test count(isinvoke(:f42840), src.code) == 1
769774
end
770775
# a bit weird, but should handle this kind of case as well

0 commit comments

Comments
 (0)