Skip to content

Commit 78dc224

Browse files
committed
more type-stable type-inference (#41697)
(this PR is the final output of my demo at [our workshop](https://github.com/aviatesk/juliacon2021-workshop-pkgdev)) This PR eliminated much of runtime dispatches within our type inference routine, that are reported by the following JET analysis: ```julia using JETTest const CC = Core.Compiler function function_filter(@nospecialize(ft)) ft === typeof(CC.isprimitivetype) && return false ft === typeof(CC.ismutabletype) && return false ft === typeof(CC.isbitstype) && return false ft === typeof(CC.widenconst) && return false ft === typeof(CC.widenconditional) && return false ft === typeof(CC.widenwrappedconditional) && return false ft === typeof(CC.maybe_extract_const_bool) && return false ft === typeof(CC.ignorelimited) && return false return true end function frame_filter((; linfo) = sv) meth = linfo.def isa(meth, Method) || return true return occursin("compiler/", string(meth.file)) end report_dispatch(CC.typeinf, (CC.NativeInterpreter, CC.InferenceState); function_filter, frame_filter) ``` > on master ``` ═════ 137 possible errors found ═════ ... ``` > on this PR ``` ═════ 51 possible errors found ═════ ... ``` And it seems like this PR makes JIT slightly faster: > on master ```julia ~/julia/julia master ❯ ./usr/bin/julia -e '@time using Plots; @time plot(rand(10,3));' 3.659865 seconds (7.19 M allocations: 497.982 MiB, 3.94% gc time, 0.39% compilation time) 2.696410 seconds (3.62 M allocations: 202.905 MiB, 7.49% gc time, 56.39% compilation time) ``` > on this PR ```julia ~/julia/julia avi/jetdemo* 7s ❯ ./usr/bin/julia -e '@time using Plots; @time plot(rand(10,3));' 3.396974 seconds (7.16 M allocations: 491.442 MiB, 4.80% gc time, 0.28% compilation time) 2.591130 seconds (3.48 M allocations: 196.026 MiB, 7.29% gc time, 56.72% compilation time) ``` cherry-picked from 795935f
1 parent 1232010 commit 78dc224

File tree

9 files changed

+90
-73
lines changed

9 files changed

+90
-73
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 62 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,12 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
8585
push!(edges, edge)
8686
end
8787
this_argtypes = isa(matches, MethodMatches) ? argtypes : matches.applicable_argtypes[i]
88-
const_rt, const_result = abstract_call_method_with_const_args(interp, result, f, this_argtypes, match, sv, false)
89-
if const_rt !== rt && const_rt rt
90-
rt = const_rt
88+
const_result = abstract_call_method_with_const_args(interp, result, f, this_argtypes, match, sv, false)
89+
if const_result !== nothing
90+
const_rt, const_result = const_result
91+
if const_rt !== rt && const_rt rt
92+
rt = const_rt
93+
end
9194
end
9295
push!(const_results, const_result)
9396
if const_result !== nothing
@@ -107,9 +110,12 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
107110
# try constant propagation with argtypes for this match
108111
# this is in preparation for inlining, or improving the return result
109112
this_argtypes = isa(matches, MethodMatches) ? argtypes : matches.applicable_argtypes[i]
110-
const_this_rt, const_result = abstract_call_method_with_const_args(interp, result, f, this_argtypes, match, sv, false)
111-
if const_this_rt !== this_rt && const_this_rt this_rt
112-
this_rt = const_this_rt
113+
const_result = abstract_call_method_with_const_args(interp, result, f, this_argtypes, match, sv, false)
114+
if const_result !== nothing
115+
const_this_rt, const_result = const_result
116+
if const_this_rt !== this_rt && const_this_rt this_rt
117+
this_rt = const_this_rt
118+
end
113119
end
114120
push!(const_results, const_result)
115121
if const_result !== nothing
@@ -520,33 +526,35 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter, resul
520526
@nospecialize(f), argtypes::Vector{Any}, match::MethodMatch,
521527
sv::InferenceState, va_override::Bool)
522528
mi = maybe_get_const_prop_profitable(interp, result, f, argtypes, match, sv)
523-
mi === nothing && return Any, nothing
529+
mi === nothing && return nothing
524530
# try constant prop'
525531
inf_cache = get_inference_cache(interp)
526532
inf_result = cache_lookup(mi, argtypes, inf_cache)
527533
if inf_result === nothing
528534
# if there might be a cycle, check to make sure we don't end up
529535
# calling ourselves here.
530-
if result.edgecycle && _any(InfStackUnwind(sv)) do infstate
531-
# if the type complexity limiting didn't decide to limit the call signature (`result.edgelimited = false`)
532-
# we can relax the cycle detection by comparing `MethodInstance`s and allow inference to
533-
# propagate different constant elements if the recursion is finite over the lattice
534-
return (result.edgelimited ? match.method === infstate.linfo.def : mi === infstate.linfo) &&
535-
any(infstate.result.overridden_by_const)
536+
let result = result # prevent capturing
537+
if result.edgecycle && _any(InfStackUnwind(sv)) do infstate
538+
# if the type complexity limiting didn't decide to limit the call signature (`result.edgelimited = false`)
539+
# we can relax the cycle detection by comparing `MethodInstance`s and allow inference to
540+
# propagate different constant elements if the recursion is finite over the lattice
541+
return (result.edgelimited ? match.method === infstate.linfo.def : mi === infstate.linfo) &&
542+
any(infstate.result.overridden_by_const)
543+
end
544+
add_remark!(interp, sv, "[constprop] Edge cycle encountered")
545+
return nothing
536546
end
537-
add_remark!(interp, sv, "[constprop] Edge cycle encountered")
538-
return Any, nothing
539547
end
540548
inf_result = InferenceResult(mi, argtypes, va_override)
541549
frame = InferenceState(inf_result, #=cache=#false, interp)
542-
frame === nothing && return Any, nothing # this is probably a bad generated function (unsound), but just ignore it
550+
frame === nothing && return nothing # this is probably a bad generated function (unsound), but just ignore it
543551
frame.parent = sv
544552
push!(inf_cache, inf_result)
545-
typeinf(interp, frame) || return Any, nothing
553+
typeinf(interp, frame) || return nothing
546554
end
547555
result = inf_result.result
548556
# if constant inference hits a cycle, just bail out
549-
isa(result, InferenceState) && return Any, nothing
557+
isa(result, InferenceState) && return nothing
550558
add_backedge!(mi, sv)
551559
return result, inf_result
552560
end
@@ -1178,7 +1186,8 @@ function abstract_invoke(interp::AbstractInterpreter, argtypes::Vector{Any}, sv:
11781186
nargtype === Bottom && return CallMeta(Bottom, false)
11791187
nargtype isa DataType || return CallMeta(Any, false) # other cases are not implemented below
11801188
isdispatchelem(ft) || return CallMeta(Any, false) # check that we might not have a subtype of `ft` at runtime, before doing supertype lookup below
1181-
types = rewrap_unionall(Tuple{ft, unwrap_unionall(types).parameters...}, types)
1189+
ft = ft::DataType
1190+
types = rewrap_unionall(Tuple{ft, unwrap_unionall(types).parameters...}, types)::Type
11821191
nargtype = Tuple{ft, nargtype.parameters...}
11831192
argtype = Tuple{ft, argtype.parameters...}
11841193
result = findsup(types, method_table(interp))
@@ -1200,12 +1209,14 @@ function abstract_invoke(interp::AbstractInterpreter, argtypes::Vector{Any}, sv:
12001209
# t, a = ti.parameters[i], argtypes′[i]
12011210
# argtypes′[i] = t ⊑ a ? t : a
12021211
# end
1203-
const_rt, const_result = abstract_call_method_with_const_args(interp, result, argtype_to_function(ft′), argtypes′, match, sv, false)
1204-
if const_rt !== rt && const_rt rt
1205-
return CallMeta(collect_limitations!(const_rt, sv), InvokeCallInfo(match, const_result))
1206-
else
1207-
return CallMeta(collect_limitations!(rt, sv), InvokeCallInfo(match, nothing))
1212+
const_result = abstract_call_method_with_const_args(interp, result, argtype_to_function(ft′), argtypes′, match, sv, false)
1213+
if const_result !== nothing
1214+
const_rt, const_result = const_result
1215+
if const_rt !== rt && const_rt rt
1216+
return CallMeta(collect_limitations!(const_rt, sv), InvokeCallInfo(match, const_result))
1217+
end
12081218
end
1219+
return CallMeta(collect_limitations!(rt, sv), InvokeCallInfo(match, nothing))
12091220
end
12101221

12111222
# call where the function is known exactly
@@ -1307,19 +1318,20 @@ end
13071318
function abstract_call_opaque_closure(interp::AbstractInterpreter, closure::PartialOpaque, argtypes::Vector{Any}, sv::InferenceState)
13081319
pushfirst!(argtypes, closure.env)
13091320
sig = argtypes_to_type(argtypes)
1310-
(; rt, edge) = result = abstract_call_method(interp, closure.source::Method, sig, Core.svec(), false, sv)
1321+
(; rt, edge) = result = abstract_call_method(interp, closure.source, sig, Core.svec(), false, sv)
13111322
edge !== nothing && add_backedge!(edge, sv)
13121323
tt = closure.typ
1313-
sigT = unwrap_unionall(tt).parameters[1]
1314-
match = MethodMatch(sig, Core.svec(), closure.source::Method, sig <: rewrap_unionall(sigT, tt))
1324+
sigT = (unwrap_unionall(tt)::DataType).parameters[1]
1325+
match = MethodMatch(sig, Core.svec(), closure.source, sig <: rewrap_unionall(sigT, tt))
13151326
info = OpaqueClosureCallInfo(match)
13161327
if !result.edgecycle
1317-
const_rettype, const_result = abstract_call_method_with_const_args(interp, result, closure, argtypes,
1328+
const_result = abstract_call_method_with_const_args(interp, result, closure, argtypes,
13181329
match, sv, closure.isva)
1319-
if const_rettype rt
1320-
rt = const_rettype
1321-
end
13221330
if const_result !== nothing
1331+
const_rettype, const_result = const_result
1332+
if const_rettype rt
1333+
rt = const_rettype
1334+
end
13231335
info = ConstCallInfo(info, Union{Nothing,InferenceResult}[const_result])
13241336
end
13251337
end
@@ -1329,7 +1341,7 @@ end
13291341
function most_general_argtypes(closure::PartialOpaque)
13301342
ret = Any[]
13311343
cc = widenconst(closure)
1332-
argt = unwrap_unionall(cc).parameters[1]
1344+
argt = (unwrap_unionall(cc)::DataType).parameters[1]
13331345
if !isa(argt, DataType) || argt.name !== typename(Tuple)
13341346
argt = Tuple
13351347
end
@@ -1344,8 +1356,8 @@ function abstract_call(interp::AbstractInterpreter, fargs::Union{Nothing,Vector{
13441356
f = argtype_to_function(ft)
13451357
if isa(ft, PartialOpaque)
13461358
return abstract_call_opaque_closure(interp, ft, argtypes[2:end], sv)
1347-
elseif isa(unwrap_unionall(ft), DataType) && unwrap_unionall(ft).name === typename(Core.OpaqueClosure)
1348-
return CallMeta(rewrap_unionall(unwrap_unionall(ft).parameters[2], ft), false)
1359+
elseif (uft = unwrap_unionall(ft); isa(uft, DataType) && uft.name === typename(Core.OpaqueClosure))
1360+
return CallMeta(rewrap_unionall((uft::DataType).parameters[2], ft), false)
13491361
elseif f === nothing
13501362
# non-constant function, but the number of arguments is known
13511363
# and the ft is not a Builtin or IntrinsicFunction
@@ -1541,12 +1553,12 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
15411553
if length(e.args) == 2 && isconcretetype(t) && !ismutabletype(t)
15421554
at = abstract_eval_value(interp, e.args[2], vtypes, sv)
15431555
n = fieldcount(t)
1544-
if isa(at, Const) && isa(at.val, Tuple) && n == length(at.val) &&
1545-
let t = t; _all(i->getfield(at.val, i) isa fieldtype(t, i), 1:n); end
1556+
if isa(at, Const) && isa(at.val, Tuple) && n == length(at.val::Tuple) &&
1557+
let t = t; _all(i->getfield(at.val::Tuple, i) isa fieldtype(t, i), 1:n); end
15461558
t = Const(ccall(:jl_new_structt, Any, (Any, Any), t, at.val))
1547-
elseif isa(at, PartialStruct) && at Tuple && n == length(at.fields) &&
1548-
let t = t, at = at; _all(i->at.fields[i] fieldtype(t, i), 1:n); end
1549-
t = PartialStruct(t, at.fields)
1559+
elseif isa(at, PartialStruct) && at Tuple && n == length(at.fields::Vector{Any}) &&
1560+
let t = t, at = at; _all(i->(at.fields::Vector{Any})[i] fieldtype(t, i), 1:n); end
1561+
t = PartialStruct(t, at.fields::Vector{Any})
15501562
end
15511563
end
15521564
elseif ehead === :new_opaque_closure
@@ -1594,7 +1606,7 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
15941606
sym = e.args[1]
15951607
t = Bool
15961608
if isa(sym, SlotNumber)
1597-
vtyp = vtypes[slot_id(sym)]
1609+
vtyp = vtypes[slot_id(sym)]::VarState
15981610
if vtyp.typ === Bottom
15991611
t = Const(false) # never assigned previously
16001612
elseif !vtyp.undef
@@ -1609,7 +1621,7 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
16091621
t = Const(true)
16101622
end
16111623
elseif isa(sym, Expr) && sym.head === :static_parameter
1612-
n = sym.args[1]
1624+
n = sym.args[1]::Int
16131625
if 1 <= n <= length(sv.sptypes)
16141626
spty = sv.sptypes[n]
16151627
if isa(spty, Const)
@@ -1644,7 +1656,7 @@ function abstract_eval_global(M::Module, s::Symbol)
16441656
end
16451657

16461658
function abstract_eval_ssavalue(s::SSAValue, src::CodeInfo)
1647-
typ = src.ssavaluetypes[s.id]
1659+
typ = (src.ssavaluetypes::Vector{Any})[s.id]
16481660
if typ === NOT_FOUND
16491661
return Bottom
16501662
end
@@ -1732,6 +1744,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
17321744
isva = isa(def, Method) && def.isva
17331745
nslots = nargs - isva
17341746
slottypes = frame.slottypes
1747+
ssavaluetypes = frame.src.ssavaluetypes::Vector{Any}
17351748
while frame.pc´´ <= n
17361749
# make progress on the active ip set
17371750
local pc::Int = frame.pc´´
@@ -1832,7 +1845,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
18321845
for (caller, caller_pc) in frame.cycle_backedges
18331846
# notify backedges of updated type information
18341847
typeassert(caller.stmt_types[caller_pc], VarTable) # we must have visited this statement before
1835-
if !(caller.src.ssavaluetypes[caller_pc] === Any)
1848+
if !((caller.src.ssavaluetypes::Vector{Any})[caller_pc] === Any)
18361849
# no reason to revisit if that call-site doesn't affect the final result
18371850
if caller_pc < caller.pc´´
18381851
caller.pc´´ = caller_pc
@@ -1842,6 +1855,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
18421855
end
18431856
end
18441857
elseif hd === :enter
1858+
stmt = stmt::Expr
18451859
l = stmt.args[1]::Int
18461860
# propagate type info to exception handler
18471861
old = states[l]
@@ -1857,16 +1871,18 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
18571871
elseif hd === :leave
18581872
else
18591873
if hd === :(=)
1874+
stmt = stmt::Expr
18601875
t = abstract_eval_statement(interp, stmt.args[2], changes, frame)
18611876
if t === Bottom
18621877
break
18631878
end
1864-
frame.src.ssavaluetypes[pc] = t
1879+
ssavaluetypes[pc] = t
18651880
lhs = stmt.args[1]
18661881
if isa(lhs, SlotNumber)
18671882
changes = StateUpdate(lhs, VarState(t, false), changes, false)
18681883
end
18691884
elseif hd === :method
1885+
stmt = stmt::Expr
18701886
fname = stmt.args[1]
18711887
if isa(fname, SlotNumber)
18721888
changes = StateUpdate(fname, VarState(Any, false), changes, false)
@@ -1881,7 +1897,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
18811897
if !isempty(frame.ssavalue_uses[pc])
18821898
record_ssa_assign(pc, t, frame)
18831899
else
1884-
frame.src.ssavaluetypes[pc] = t
1900+
ssavaluetypes[pc] = t
18851901
end
18861902
end
18871903
if isa(changes, StateUpdate)
@@ -1908,7 +1924,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
19081924

19091925
if t === nothing
19101926
# mark other reached expressions as `Any` to indicate they don't throw
1911-
frame.src.ssavaluetypes[pc] = Any
1927+
ssavaluetypes[pc] = Any
19121928
end
19131929

19141930
pc´ > n && break # can't proceed with the fast-path fall-through

base/compiler/inferencestate.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ function sptypes_from_meth_instance(linfo::MethodInstance)
265265
while temp isa UnionAll
266266
temp = temp.body
267267
end
268-
sigtypes = temp.parameters
268+
sigtypes = (temp::DataType).parameters
269269
for j = 1:length(sigtypes)
270270
tj = sigtypes[j]
271271
if isType(tj) && tj.parameters[1] === Pi

base/compiler/ssair/legacy.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ function replace_code_newstyle!(ci::CodeInfo, ir::IRCode, nargs::Int)
4747
for metanode in ir.meta
4848
push!(ci.code, metanode)
4949
push!(ci.codelocs, 1)
50-
push!(ci.ssavaluetypes, Any)
50+
push!(ci.ssavaluetypes::Vector{Any}, Any)
5151
push!(ci.ssaflags, 0x00)
5252
end
5353
# Translate BB Edges to statement edges

base/compiler/ssair/passes.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1064,7 +1064,7 @@ function type_lift_pass!(ir::IRCode)
10641064
if haskey(processed, id)
10651065
val = processed[id]
10661066
else
1067-
push!(worklist, (id, up_id, new_phi, i))
1067+
push!(worklist, (id, up_id, new_phi::SSAValue, i))
10681068
continue
10691069
end
10701070
else

base/compiler/ssair/slot2ssa.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -871,7 +871,7 @@ function construct_ssa!(ci::CodeInfo, ir::IRCode, domtree::DomTree, defuse, narg
871871
changed = false
872872
for new_idx in type_refine_phi
873873
node = new_nodes.stmts[new_idx]
874-
new_typ = recompute_type(node[:inst], ci, ir, ir.sptypes, slottypes)
874+
new_typ = recompute_type(node[:inst]::Union{PhiNode,PhiCNode}, ci, ir, ir.sptypes, slottypes)
875875
if !(node[:type] new_typ) || !(new_typ node[:type])
876876
node[:type] = new_typ
877877
changed = true

base/compiler/tfuncs.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1627,7 +1627,7 @@ function builtin_tfunction(interp::AbstractInterpreter, @nospecialize(f), argtyp
16271627
if length(argtypes) - 1 == tf[2]
16281628
argtypes = argtypes[1:end-1]
16291629
else
1630-
vatype = argtypes[end]
1630+
vatype = argtypes[end]::Core.TypeofVararg
16311631
argtypes = argtypes[1:end-1]
16321632
while length(argtypes) < tf[1]
16331633
push!(argtypes, unwrapva(vatype))
@@ -1733,7 +1733,7 @@ function return_type_tfunc(interp::AbstractInterpreter, argtypes::Vector{Any}, s
17331733
aft = argtypes[2]
17341734
if isa(aft, Const) || (isType(aft) && !has_free_typevars(aft)) ||
17351735
(isconcretetype(aft) && !(aft <: Builtin))
1736-
af_argtype = isa(tt, Const) ? tt.val : tt.parameters[1]
1736+
af_argtype = isa(tt, Const) ? tt.val : (tt::DataType).parameters[1]
17371737
if isa(af_argtype, DataType) && af_argtype <: Tuple
17381738
argtypes_vec = Any[aft, af_argtype.parameters...]
17391739
if contains_is(argtypes_vec, Union{})

0 commit comments

Comments
 (0)