Skip to content

Commit f820a6a

Browse files
committed
optimizer: fully support inlining of union-split, partially constant-prop' callsite
Makes full use of constant-propagation, by addressing this [TODO](https://github.com/JuliaLang/julia/blob/00734c5fd045316a00d287ca2c0ec1a2eef6e4d1/base/compiler/ssair/inlining.jl#L1212). Here is a performance improvement from #43287: ```julia ulia> using BenchmarkTools julia> X = rand(ComplexF32, 64, 64); julia> dst = reinterpret(reshape, Float32, X); julia> src = copy(dst); julia> @Btime copyto!($dst, $src); 50.819 μs (1 allocation: 32 bytes) # v1.6.4 41.081 μs (0 allocations: 0 bytes) # this commit ``` fixes #43287
1 parent 00734c5 commit f820a6a

File tree

4 files changed

+113
-90
lines changed

4 files changed

+113
-90
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
139139
# by constant analysis, but let's create `ConstCallInfo` if there has been any successful
140140
# constant propagation happened since other consumers may be interested in this
141141
if any_const_result && seen == napplicable
142+
@assert napplicable == nmatches(info) == length(const_results)
142143
info = ConstCallInfo(info, const_results)
143144
end
144145

base/compiler/ssair/inlining.jl

Lines changed: 82 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -675,19 +675,16 @@ function rewrite_apply_exprargs!(
675675
new_sig = with_atype(call_sig(ir, new_stmt)::Signature)
676676
new_info = call.info
677677
if isa(new_info, ConstCallInfo)
678-
maybe_handle_const_call!(
678+
handle_const_call!(
679679
ir, state1.id, new_stmt, new_info, flag,
680-
new_sig, istate, todo) && @goto analyzed
681-
new_info = new_info.call # cascade to the non-constant handling
682-
end
683-
if isa(new_info, MethodMatchInfo) || isa(new_info, UnionSplitInfo)
680+
new_sig, istate, todo)
681+
elseif isa(new_info, MethodMatchInfo) || isa(new_info, UnionSplitInfo)
684682
new_infos = isa(new_info, MethodMatchInfo) ? MethodMatchInfo[new_info] : new_info.matches
685683
# See if we can inline this call to `iterate`
686684
analyze_single_call!(
687685
ir, state1.id, new_stmt, new_infos, flag,
688686
new_sig, istate, todo)
689687
end
690-
@label analyzed
691688
if i != length(thisarginfo.each)
692689
valT = getfield_tfunc(call.rt, Const(1))
693690
val_extracted = insert_node!(ir, idx, NewInstruction(
@@ -1129,139 +1126,150 @@ function process_simple!(ir::IRCode, idx::Int, state::InliningState, todo::Vecto
11291126
return stmt, sig
11301127
end
11311128

1132-
# TODO inline non-`isdispatchtuple`, union-split callsites
1129+
# TODO inline non-`isdispatchtuple`, union-split callsites?
11331130
function analyze_single_call!(
11341131
ir::IRCode, idx::Int, stmt::Expr, infos::Vector{MethodMatchInfo}, flag::UInt8,
11351132
sig::Signature, state::InliningState, todo::Vector{Pair{Int, Any}})
11361133
(; argtypes, atype) = sig
11371134
cases = InliningCase[]
11381135
local signature_union = Bottom
11391136
local only_method = nothing # keep track of whether there is one matching method
1140-
local meth
1137+
local meth::MethodLookupResult
11411138
local fully_covered = true
11421139
for i in 1:length(infos)
1143-
info = infos[i]
1144-
meth = info.results
1140+
meth = infos[i].results
11451141
if meth.ambig
11461142
# Too many applicable methods
11471143
# Or there is a (partial?) ambiguity
1148-
return
1144+
return nothing
11491145
elseif length(meth) == 0
11501146
# No applicable methods; try next union split
11511147
continue
1152-
elseif length(meth) == 1 && only_method !== false
1153-
if only_method === nothing
1154-
only_method = meth[1].method
1155-
elseif only_method !== meth[1].method
1148+
else
1149+
if length(meth) == 1 && only_method !== false
1150+
if only_method === nothing
1151+
only_method = meth[1].method
1152+
elseif only_method !== meth[1].method
1153+
only_method = false
1154+
end
1155+
else
11561156
only_method = false
11571157
end
1158-
else
1159-
only_method = false
11601158
end
11611159
for match in meth
1162-
spec_types = match.spec_types
1163-
signature_union = Union{signature_union, spec_types}
1164-
if !isdispatchtuple(spec_types)
1165-
fully_covered = false
1166-
continue
1167-
end
1168-
item = analyze_method!(match, argtypes, flag, state)
1169-
if item === nothing
1170-
fully_covered = false
1171-
continue
1172-
elseif _any(case->case.sig === spec_types, cases)
1173-
continue
1174-
end
1175-
push!(cases, InliningCase(spec_types, item))
1160+
signature_union = Union{signature_union, match.spec_types}
1161+
fully_covered &= handle_match!(match, argtypes, flag, state, cases)
11761162
end
11771163
end
11781164

1179-
# if the signature is fully or mostly covered and there is only one applicable method,
1165+
# if the signature is fully covered and there is only one applicable method,
11801166
# we can try to inline it even if the signature is not a dispatch tuple
11811167
if length(cases) == 0 && only_method isa Method
11821168
if length(infos) > 1
11831169
(metharg, methsp) = ccall(:jl_type_intersection_with_env, Any, (Any, Any),
11841170
atype, only_method.sig)::SimpleVector
11851171
match = MethodMatch(metharg, methsp::SimpleVector, only_method, true)
11861172
else
1187-
meth = meth::MethodLookupResult
11881173
@assert length(meth) == 1
11891174
match = meth[1]
11901175
end
11911176
item = analyze_method!(match, argtypes, flag, state)
1192-
item === nothing && return
1177+
item === nothing && return nothing
11931178
push!(cases, InliningCase(match.spec_types, item))
11941179
fully_covered = match.fully_covers
11951180
else
11961181
fully_covered &= atype <: signature_union
11971182
end
11981183

1199-
# If we only have one case and that case is fully covered, we may either
1200-
# be able to do the inlining now (for constant cases), or push it directly
1201-
# onto the todo list
1202-
if fully_covered && length(cases) == 1
1203-
handle_single_case!(ir, idx, stmt, cases[1].item, todo)
1204-
elseif length(cases) > 0
1205-
push!(todo, idx=>UnionSplit(fully_covered, atype, cases))
1206-
end
1207-
return nothing
1184+
handle_cases!(ir, idx, stmt, sig, cases, fully_covered, todo)
12081185
end
12091186

1210-
# try to create `InliningCase`s using constant-prop'ed results
1211-
# currently it works only when constant-prop' succeeded for all (union-split) signatures
1212-
# TODO use any of constant-prop'ed results, and leave the other unhandled cases to later
1213-
# TODO this function contains a lot of duplications with `analyze_single_call!`, factor them out
1214-
function maybe_handle_const_call!(
1215-
ir::IRCode, idx::Int, stmt::Expr, info::ConstCallInfo, flag::UInt8,
1187+
# similar to `analyze_single_call!`, but with constant results
1188+
function handle_const_call!(
1189+
ir::IRCode, idx::Int, stmt::Expr, cinfo::ConstCallInfo, flag::UInt8,
12161190
sig::Signature, state::InliningState, todo::Vector{Pair{Int, Any}})
12171191
(; argtypes, atype) = sig
1218-
results = info.results
1219-
cases = InliningCase[] # TODO avoid this allocation for single cases ?
1192+
(; call, results) = cinfo
1193+
infos = isa(call, MethodMatchInfo) ? MethodMatchInfo[call] : call.matches
1194+
cases = InliningCase[]
12201195
local fully_covered = true
12211196
local signature_union = Bottom
1222-
for result in results
1223-
isa(result, InferenceResult) || return false
1224-
(; mi) = item = InliningTodo(result, argtypes)
1225-
spec_types = mi.specTypes
1226-
signature_union = Union{signature_union, spec_types}
1227-
if !isdispatchtuple(spec_types)
1228-
fully_covered = false
1229-
continue
1230-
end
1231-
if !validate_sparams(mi.sparam_vals)
1232-
fully_covered = false
1197+
local j = 0
1198+
for i in 1:length(infos)
1199+
meth = infos[i].results
1200+
if meth.ambig
1201+
# Too many applicable methods
1202+
# Or there is a (partial?) ambiguity
1203+
return nothing
1204+
elseif length(meth) == 0
1205+
# No applicable methods; try next union split
12331206
continue
12341207
end
1235-
state.mi_cache !== nothing && (item = resolve_todo(item, state, flag))
1236-
if item === nothing
1237-
fully_covered = false
1238-
continue
1208+
for match in meth
1209+
j += 1
1210+
result = results[j]
1211+
if result === nothing
1212+
signature_union = Union{signature_union, match.spec_types}
1213+
fully_covered &= handle_match!(match, argtypes, flag, state, cases)
1214+
else
1215+
signature_union = Union{signature_union, result.linfo.specTypes}
1216+
fully_covered &= handle_const_result!(result, argtypes, flag, state, cases)
1217+
end
12391218
end
1240-
push!(cases, InliningCase(spec_types, item))
12411219
end
12421220

12431221
# if the signature is fully covered and there is only one applicable method,
12441222
# we can try to inline it even if the signature is not a dispatch tuple
12451223
if length(cases) == 0 && length(results) == 1
12461224
(; mi) = item = InliningTodo(results[1]::InferenceResult, argtypes)
12471225
state.mi_cache !== nothing && (item = resolve_todo(item, state, flag))
1248-
validate_sparams(mi.sparam_vals) || return true
1249-
item === nothing && return true
1226+
validate_sparams(mi.sparam_vals) || return nothing
1227+
item === nothing && return nothing
12501228
push!(cases, InliningCase(mi.specTypes, item))
12511229
fully_covered = atype <: mi.specTypes
12521230
else
12531231
fully_covered &= atype <: signature_union
12541232
end
12551233

1234+
handle_cases!(ir, idx, stmt, sig, cases, fully_covered, todo)
1235+
end
1236+
1237+
function handle_match!(
1238+
match::MethodMatch, argtypes::Vector{Any}, flag::UInt8, state::InliningState,
1239+
cases::Vector{InliningCase})
1240+
spec_types = match.spec_types
1241+
isdispatchtuple(spec_types) || return false
1242+
item = analyze_method!(match, argtypes, flag, state)
1243+
item === nothing && return false
1244+
_any(case->case.sig === spec_types, cases) && return true
1245+
push!(cases, InliningCase(spec_types, item))
1246+
return true
1247+
end
1248+
1249+
function handle_const_result!(
1250+
result::InferenceResult, argtypes::Vector{Any}, flag::UInt8, state::InliningState,
1251+
cases::Vector{InliningCase})
1252+
(; mi) = item = InliningTodo(result, argtypes)
1253+
spec_types = mi.specTypes
1254+
isdispatchtuple(spec_types) || return false
1255+
validate_sparams(mi.sparam_vals) || return false
1256+
state.mi_cache !== nothing && (item = resolve_todo(item, state, flag))
1257+
item === nothing && return false
1258+
push!(cases, InliningCase(spec_types, item))
1259+
return true
1260+
end
1261+
1262+
function handle_cases!(ir::IRCode, idx::Int, stmt::Expr, sig::Signature,
1263+
cases::Vector{InliningCase}, fully_covered::Bool, todo::Vector{Pair{Int, Any}})
12561264
# If we only have one case and that case is fully covered, we may either
12571265
# be able to do the inlining now (for constant cases), or push it directly
12581266
# onto the todo list
12591267
if fully_covered && length(cases) == 1
12601268
handle_single_case!(ir, idx, stmt, cases[1].item, todo)
12611269
elseif length(cases) > 0
1262-
push!(todo, idx=>UnionSplit(fully_covered, atype, cases))
1270+
push!(todo, idx=>UnionSplit(fully_covered, sig.atype, cases))
12631271
end
1264-
return true
1272+
return nothing
12651273
end
12661274

12671275
function handle_const_opaque_closure_call!(
@@ -1327,10 +1335,10 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState)
13271335
# if inference arrived here with constant-prop'ed result(s),
13281336
# we can perform a specialized analysis for just this case
13291337
if isa(info, ConstCallInfo)
1330-
maybe_handle_const_call!(
1338+
handle_const_call!(
13311339
ir, idx, stmt, info, flag,
1332-
sig, state, todo) && continue
1333-
info = info.call # cascade to the non-constant handling
1340+
sig, state, todo)
1341+
continue
13341342
end
13351343

13361344
# Ok, now figure out what method to call

base/compiler/stmtinfo.jl

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,27 @@ struct UnionSplitInfo
3838
matches::Vector{MethodMatchInfo}
3939
end
4040

41+
nmatches(info::MethodMatchInfo) = length(info.results)
42+
function nmatches(info::UnionSplitInfo)
43+
n = 0
44+
for mminfo in info.matches
45+
n += nmatches(mminfo)
46+
end
47+
return n
48+
end
49+
50+
"""
51+
info::ConstCallInfo
52+
53+
The precision of this call was improved using constant information.
54+
In addition to the original call information `info.call`, this info also keeps
55+
the inference results with constant information `info.results::Vector{Union{Nothing,InferenceResult}}`.
56+
"""
57+
struct ConstCallInfo
58+
call::Union{MethodMatchInfo,UnionSplitInfo}
59+
results::Vector{Union{Nothing,InferenceResult}}
60+
end
61+
4162
"""
4263
info::MethodResultPure
4364
@@ -92,18 +113,6 @@ struct UnionSplitApplyCallInfo
92113
infos::Vector{ApplyCallInfo}
93114
end
94115

95-
"""
96-
info::ConstCallInfo
97-
98-
The precision of this call was improved using constant information.
99-
In addition to the original call information `info.call`, this info also keeps
100-
the inference results with constant information `info.results::Vector{Union{Nothing,InferenceResult}}`.
101-
"""
102-
struct ConstCallInfo
103-
call::Union{MethodMatchInfo,UnionSplitInfo}
104-
results::Vector{Union{Nothing,InferenceResult}}
105-
end
106-
107116
"""
108117
info::InvokeCallInfo
109118

test/compiler/inline.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -758,13 +758,18 @@ end
758758
import Base: @constprop
759759

760760
# test union-split callsite with successful and unsuccessful constant-prop' results
761-
@constprop :aggressive @inline f42840(xs, a::Int) = xs[a] # should be successful, and inlined
762-
@constprop :none @noinline f42840(xs::AbstractVector, a::Int) = xs[a] # should be unsuccessful, but still statically resolved
761+
# (also for https://github.com/JuliaLang/julia/issues/43287)
762+
@constprop :aggressive @inline f42840(cond::Bool, xs::Tuple, a::Int) = # should be successful, and inlined with constant prop' result
763+
cond ? xs[a] : @noinline(length(xs))
764+
@constprop :none @noinline f42840(::Bool, xs::AbstractVector, a::Int) = # should be unsuccessful, but still statically resolved
765+
xs[a]
763766
let src = code_typed((Union{Tuple{Int,Int,Int}, Vector{Int}},)) do xs
764-
f42840(xs, 2)
767+
f42840(true, xs, 2)
765768
end |> only |> first
766-
# `(xs::Tuple{Int,Int,Int})[a::Const(2)]` => `getfield(xs, 2)`
769+
# `f43287(true, xs::Tuple{Int,Int,Int}, 2)` => `getfield(xs, 2)`
770+
# `f43287(true, xs::Vector{Int}, 2)` => `:invoke f43287(true, xs, 2)`
767771
@test count(iscall((src, getfield)), src.code) == 1
772+
@test count(isinvoke(:length), src.code) == 0
768773
@test count(isinvoke(:f42840), src.code) == 1
769774
end
770775
# a bit weird, but should handle this kind of case as well

0 commit comments

Comments
 (0)