Skip to content

Commit 88c0e25

Browse files
authored
Add PartialOpaque lattice element for OpaqueClosure (#39512)
This adds a lattice element for tracking OpaqueClosures in inference, but does not yet do anything with it. The reason I'm separating this out is that just the introduction of the lattice element raises some tricky issues. In particular, the lattice element refers back to the OpaqueClosure method, which we currently don't support in the serializer. I played with several ways of adding support for that, but in the end it all ended up super complicated for questionable benefit, so in this PR, CodeInstances that get inferred to `PartialOpaque` get omitted during serialization (i.e. they will be reinfered upon loading the .ji).
1 parent 37aea06 commit 88c0e25

20 files changed

+301
-63
lines changed

base/boot.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,7 @@ eval(Core, :(CodeInstance(mi::MethodInstance, @nospecialize(rettype), @nospecial
424424
mi, rettype, inferred_const, inferred, const_flags, min_world, max_world)))
425425
eval(Core, :(Const(@nospecialize(v)) = $(Expr(:new, :Const, :v))))
426426
eval(Core, :(PartialStruct(@nospecialize(typ), fields::Array{Any, 1}) = $(Expr(:new, :PartialStruct, :typ, :fields))))
427+
eval(Core, :(PartialOpaque(@nospecialize(typ), @nospecialize(env), isva::Bool, parent::MethodInstance, source::Method) = $(Expr(:new, :PartialOpaque, :typ, :env, :isva, :parent, :source))))
427428
eval(Core, :(MethodMatch(@nospecialize(spec_types), sparams::SimpleVector, method::Method, fully_covers::Bool) =
428429
$(Expr(:new, :MethodMatch, :spec_types, :sparams, :method, :fully_covers))))
429430

base/compiler/abstractinterpretation.jl

Lines changed: 61 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,7 @@ function const_prop_profitable(@nospecialize(arg))
227227
const_prop_profitable(b) && return true
228228
end
229229
end
230+
isa(arg, PartialOpaque) && return true
230231
isa(arg, Const) || return true
231232
val = arg.val
232233
# don't consider mutable values or Strings useful constants
@@ -268,7 +269,7 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter, @nosp
268269
# see if any or all of the arguments are constant and propagating constants may be worthwhile
269270
for a in argtypes
270271
a = widenconditional(a)
271-
if allconst && !isa(a, Const) && !isconstType(a) && !isa(a, PartialStruct)
272+
if allconst && !isa(a, Const) && !isconstType(a) && !isa(a, PartialStruct) && !isa(a, PartialOpaque)
272273
allconst = false
273274
end
274275
if !haveconst && has_nontrivial_const_info(a) && const_prop_profitable(a)
@@ -688,7 +689,7 @@ end
688689
function abstract_apply(interp::AbstractInterpreter, @nospecialize(itft), @nospecialize(aft), aargtypes::Vector{Any}, sv::InferenceState,
689690
max_methods::Int = InferenceParams(interp).MAX_METHODS)
690691
aftw = widenconst(aft)
691-
if !isa(aft, Const) && (!isType(aftw) || has_free_typevars(aftw))
692+
if !isa(aft, Const) && !isa(aft, PartialOpaque) && (!isType(aftw) || has_free_typevars(aftw))
692693
if !isconcretetype(aftw) || (aftw <: Builtin)
693694
add_remark!(interp, sv, "Core._apply_iterate called on a function of a non-concrete type")
694695
# bail now, since it seems unlikely that abstract_call will be able to do any better after splitting
@@ -1057,6 +1058,20 @@ function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f),
10571058
return abstract_call_gf_by_type(interp, f, argtypes, atype, sv, max_methods)
10581059
end
10591060

1061+
function abstract_call_opaque_closure(interp::AbstractInterpreter, closure::PartialOpaque, argtypes::Vector{Any}, sv::InferenceState)
1062+
return CallMeta(Any, nothing)
1063+
end
1064+
1065+
function most_general_argtypes(closure::PartialOpaque)
1066+
ret = Any[]
1067+
cc = widenconst(closure)
1068+
argt = unwrap_unionall(cc).parameters[1]
1069+
if !isa(argt, DataType) || argt.name !== typename(Tuple)
1070+
argt = Tuple
1071+
end
1072+
return most_general_argtypes(closure.source, argt, closure.isva)
1073+
end
1074+
10601075
# call where the function is any lattice element
10611076
function abstract_call(interp::AbstractInterpreter, fargs::Union{Nothing,Vector{Any}}, argtypes::Vector{Any},
10621077
sv::InferenceState, max_methods::Int = InferenceParams(interp).MAX_METHODS)
@@ -1068,10 +1083,14 @@ function abstract_call(interp::AbstractInterpreter, fargs::Union{Nothing,Vector{
10681083
f = ft.parameters[1]
10691084
elseif isa(ft, DataType) && isdefined(ft, :instance)
10701085
f = ft.instance
1086+
elseif isa(ft, PartialOpaque)
1087+
return abstract_call_opaque_closure(interp, ft, argtypes, sv)
1088+
elseif isa(unwrap_unionall(ft), DataType) && unwrap_unionall(ft).name === typename(Core.OpaqueClosure)
1089+
return CallMeta(rewrap_unionall(unwrap_unionall(ft).parameters[2], ft), false)
10711090
else
10721091
# non-constant function, but the number of arguments is known
10731092
# and the ft is not a Builtin or IntrinsicFunction
1074-
if typeintersect(widenconst(ft), Builtin) != Union{}
1093+
if typeintersect(widenconst(ft), Union{Builtin, Core.OpaqueClosure}) != Union{}
10751094
add_remark!(interp, sv, "Could not identify method table for call")
10761095
return CallMeta(Any, false)
10771096
end
@@ -1173,25 +1192,34 @@ function abstract_eval_value(interp::AbstractInterpreter, @nospecialize(e), vtyp
11731192
end
11741193
end
11751194

1195+
function collect_argtypes(interp::AbstractInterpreter, ea::Vector{Any}, vtypes::VarTable, sv::InferenceState)
1196+
n = length(ea)
1197+
argtypes = Vector{Any}(undef, n)
1198+
@inbounds for i = 1:n
1199+
ai = abstract_eval_value(interp, ea[i], vtypes, sv)
1200+
if bail_out_statement(interp, ai, sv)
1201+
return Bottom
1202+
end
1203+
argtypes[i] = ai
1204+
end
1205+
return argtypes
1206+
end
1207+
11761208
function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e), vtypes::VarTable, sv::InferenceState)
11771209
if !isa(e, Expr)
11781210
return abstract_eval_special_value(interp, e, vtypes, sv)
11791211
end
11801212
e = e::Expr
11811213
if e.head === :call
11821214
ea = e.args
1183-
n = length(ea)
1184-
argtypes = Vector{Any}(undef, n)
1185-
@inbounds for i = 1:n
1186-
ai = abstract_eval_value(interp, ea[i], vtypes, sv)
1187-
if bail_out_statement(interp, ai, sv)
1188-
return Bottom
1189-
end
1190-
argtypes[i] = ai
1215+
argtypes = collect_argtypes(interp, ea, vtypes, sv)
1216+
if argtypes === Bottom
1217+
t = Bottom
1218+
else
1219+
callinfo = abstract_call(interp, ea, argtypes, sv)
1220+
sv.stmt_info[sv.currpc] = callinfo.info
1221+
t = callinfo.rt
11911222
end
1192-
callinfo = abstract_call(interp, ea, argtypes, sv)
1193-
sv.stmt_info[sv.currpc] = callinfo.info
1194-
t = callinfo.rt
11951223
elseif e.head === :new
11961224
t = instanceof_tfunc(abstract_eval_value(interp, e.args[1], vtypes, sv))[1]
11971225
if isconcretetype(t) && !t.mutable
@@ -1242,6 +1270,24 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
12421270
t = PartialStruct(t, at.fields)
12431271
end
12441272
end
1273+
elseif e.head === :new_opaque_closure
1274+
t = Union{}
1275+
if length(e.args) >= 5
1276+
ea = e.args
1277+
argtypes = collect_argtypes(interp, ea, vtypes, sv)
1278+
if argtypes === Bottom
1279+
t = Bottom
1280+
else
1281+
t = _opaque_closure_tfunc(argtypes[1], argtypes[2], argtypes[3],
1282+
argtypes[4], argtypes[5], argtypes[6:end], sv.linfo)
1283+
if isa(t, PartialOpaque)
1284+
# Infer this now so that the specialization is available to
1285+
# optimization.
1286+
abstract_call_opaque_closure(interp, t,
1287+
most_general_argtypes(t), sv)
1288+
end
1289+
end
1290+
end
12451291
elseif e.head === :foreigncall
12461292
abstract_eval_value(interp, e.args[1], vtypes, sv)
12471293
t = sp_type_rewrap(e.args[2], sv.linfo, true)
@@ -1404,7 +1450,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
14041450
elseif isa(stmt, ReturnNode)
14051451
pc´ = n + 1
14061452
rt = widenconditional(abstract_eval_value(interp, stmt.val, s[pc], frame))
1407-
if !isa(rt, Const) && !isa(rt, Type) && !isa(rt, PartialStruct)
1453+
if !isa(rt, Const) && !isa(rt, Type) && !isa(rt, PartialStruct) && !isa(rt, PartialOpaque)
14081454
# only propagate information we know we can store
14091455
# and is valid inter-procedurally
14101456
rt = widenconst(rt)

base/compiler/compiler.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ using Core.Intrinsics, Core.IR
66

77
import Core: print, println, show, write, unsafe_write, stdout, stderr,
88
_apply_iterate, svec, apply_type, Builtin, IntrinsicFunction,
9-
MethodInstance, CodeInstance, MethodMatch
9+
MethodInstance, CodeInstance, MethodMatch, PartialOpaque
1010

1111
const getproperty = Core.getfield
1212
const setproperty! = Core.setfield!

base/compiler/inferenceresult.jl

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -45,15 +45,20 @@ function matching_cache_argtypes(linfo::MethodInstance, given_argtypes::Vector)
4545
return cache_argtypes, overridden_by_const
4646
end
4747

48-
function matching_cache_argtypes(linfo::MethodInstance, ::Nothing)
49-
toplevel = !isa(linfo.def, Method)
50-
linfo_argtypes = Any[unwrap_unionall(linfo.specTypes).parameters...]
51-
nargs::Int = toplevel ? 0 : linfo.def.nargs
48+
function most_general_argtypes(method::Union{Method, Nothing}, @nospecialize(specTypes),
49+
isva::Bool)
50+
toplevel = method === nothing
51+
linfo_argtypes = Any[unwrap_unionall(specTypes).parameters...]
52+
nargs::Int = toplevel ? 0 : method.nargs
53+
if !toplevel && method.is_for_opaque_closure
54+
# For opaque closure, the closure environment is processed elsewhere
55+
nargs -= 1
56+
end
5257
cache_argtypes = Vector{Any}(undef, nargs)
5358
# First, if we're dealing with a varargs method, then we set the last element of `args`
5459
# to the appropriate `Tuple` type or `PartialStruct` instance.
55-
if !toplevel && linfo.def.isva
56-
if linfo.specTypes == Tuple
60+
if !toplevel && isva
61+
if specTypes == Tuple
5762
if nargs > 1
5863
linfo_argtypes = svec(Any[Any for i = 1:(nargs - 1)]..., Tuple.parameters[1])
5964
end
@@ -63,7 +68,7 @@ function matching_cache_argtypes(linfo::MethodInstance, ::Nothing)
6368
if nargs > linfo_argtypes_length
6469
va = linfo_argtypes[linfo_argtypes_length]
6570
if isvarargtype(va)
66-
new_va = rewrap_unionall(unconstrain_vararg_length(va), linfo.specTypes)
71+
new_va = rewrap_unionall(unconstrain_vararg_length(va), specTypes)
6772
vargtype_elements = Any[new_va]
6873
vargtype = Tuple{new_va}
6974
else
@@ -74,7 +79,7 @@ function matching_cache_argtypes(linfo::MethodInstance, ::Nothing)
7479
vargtype_elements = Any[]
7580
for p in linfo_argtypes[nargs:linfo_argtypes_length]
7681
p = isvarargtype(p) ? unconstrain_vararg_length(p) : p
77-
push!(vargtype_elements, rewrap(p, linfo.specTypes))
82+
push!(vargtype_elements, rewrap(p, specTypes))
7883
end
7984
for i in 1:length(vargtype_elements)
8085
atyp = vargtype_elements[i]
@@ -115,7 +120,7 @@ function matching_cache_argtypes(linfo::MethodInstance, ::Nothing)
115120
elseif isconstType(atyp)
116121
atyp = Const(atyp.parameters[1])
117122
else
118-
atyp = rewrap(atyp, linfo.specTypes)
123+
atyp = rewrap(atyp, specTypes)
119124
end
120125
i == n && (lastatype = atyp)
121126
cache_argtypes[i] = atyp
@@ -126,6 +131,13 @@ function matching_cache_argtypes(linfo::MethodInstance, ::Nothing)
126131
else
127132
@assert nargs == 0 "invalid specialization of method" # wrong number of arguments
128133
end
134+
cache_argtypes
135+
end
136+
137+
function matching_cache_argtypes(linfo::MethodInstance, ::Nothing)
138+
mthd = isa(linfo.def, Method) ? linfo.def::Method : nothing
139+
cache_argtypes = most_general_argtypes(mthd, linfo.specTypes, isa(mthd, Method) ?
140+
mthd.isva : false)
129141
return cache_argtypes, falses(length(cache_argtypes))
130142
end
131143

base/compiler/tfuncs.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1362,6 +1362,27 @@ add_tfunc(arrayref, 3, INT_INF, arrayref_tfunc, 20)
13621362
add_tfunc(const_arrayref, 3, INT_INF, arrayref_tfunc, 20)
13631363
add_tfunc(arrayset, 4, INT_INF, (@nospecialize(boundscheck), @nospecialize(a), @nospecialize(v), @nospecialize i...)->a, 20)
13641364

1365+
function _opaque_closure_tfunc(@nospecialize(arg), @nospecialize(isva),
1366+
@nospecialize(lb), @nospecialize(ub), @nospecialize(source), env::Vector{Any},
1367+
linfo::MethodInstance)
1368+
1369+
argt, argt_exact = instanceof_tfunc(arg)
1370+
lbt, lb_exact = instanceof_tfunc(lb)
1371+
if !lb_exact
1372+
lbt = Union{}
1373+
end
1374+
1375+
ubt, ub_exact = instanceof_tfunc(ub)
1376+
1377+
t = (argt_exact ? Core.OpaqueClosure{argt, T} : Core.OpaqueClosure{<:argt, T}) where T
1378+
t = lbt == ubt ? t{ubt} : (t{T} where lbt <: T <: ubt)
1379+
1380+
(isa(source, Const) && isa(source.val, Method)) || return t
1381+
(isa(isva, Const) && isa(isva.val, Bool)) || return t
1382+
1383+
return PartialOpaque(t, tuple_tfunc(env), isva.val, linfo, source.val)
1384+
end
1385+
13651386
function array_type_undefable(@nospecialize(a))
13661387
if isa(a, Union)
13671388
return array_type_undefable(a.a) || array_type_undefable(a.b)

base/compiler/typeinfer.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,9 @@ function CodeInstance(result::InferenceResult, @nospecialize(inferred_result::An
289289
if isa(result_type, Const)
290290
rettype_const = result_type.val
291291
const_flags = 0x2
292+
elseif isa(result_type, PartialOpaque)
293+
rettype_const = result_type
294+
const_flags = 0x2
292295
elseif isconstType(result_type)
293296
rettype_const = result_type.parameters[1]
294297
const_flags = 0x2
@@ -773,6 +776,8 @@ function typeinf_edge(interp::AbstractInterpreter, method::Method, @nospecialize
773776
if isdefined(code, :rettype_const)
774777
if isa(code.rettype_const, Vector{Any}) && !(Vector{Any} <: code.rettype)
775778
return PartialStruct(code.rettype, code.rettype_const), mi
779+
elseif code.rettype <: Core.OpaqueClosure && isa(code.rettype_const, PartialOpaque)
780+
return code.rettype_const, mi
776781
else
777782
return Const(code.rettype_const), mi
778783
end

base/compiler/typelattice.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,14 @@ function ⊑(@nospecialize(a), @nospecialize(b))
185185
end
186186
return false
187187
end
188+
if isa(a, PartialOpaque)
189+
if isa(b, PartialOpaque)
190+
(a.parent === b.parent && a.source === b.source) || return false
191+
return (widenconst(a) <: widenconst(b)) &&
192+
(a.env, b.env)
193+
end
194+
return widenconst(a) b
195+
end
188196
if isa(a, Const)
189197
if isa(b, Const)
190198
return a.val === b.val
@@ -223,6 +231,13 @@ function is_lattice_equal(@nospecialize(a), @nospecialize(b))
223231
isa(b, PartialStruct) && return false
224232
a isa Const && return false
225233
b isa Const && return false
234+
if isa(a, PartialOpaque)
235+
isa(b, PartialOpaque) || return false
236+
widenconst(a) == widenconst(b) || return false
237+
a.source === b.source || return false
238+
a.parent === b.parent || return false
239+
return is_lattice_equal(a.env, b.env)
240+
end
226241
return a b && b a
227242
end
228243

@@ -240,6 +255,7 @@ end
240255
widenconst(m::MaybeUndef) = widenconst(m.typ)
241256
widenconst(c::PartialTypeVar) = TypeVar
242257
widenconst(t::PartialStruct) = t.typ
258+
widenconst(t::PartialOpaque) = t.typ
243259
widenconst(t::Type) = t
244260
widenconst(t::TypeVar) = t
245261
widenconst(t::Core.TypeofVararg) = t

base/compiler/typelimits.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,15 @@ function tmerge(@nospecialize(typea), @nospecialize(typeb))
368368
return anyconst ? PartialStruct(widenconst(typea), fields) :
369369
widenconst(typea)
370370
end
371+
if isa(typea, PartialOpaque) && isa(typeb, PartialOpaque) && widenconst(typea) == widenconst(typeb)
372+
if !(typea.source === typeb.source &&
373+
typea.isva === typeb.isva &&
374+
typea.parent === typeb.parent)
375+
return widenconst(typea)
376+
end
377+
return PartialOpaque(typea.typ, tmerge(typea.env, typeb.env),
378+
typea.isva, typea.parent, typea.source)
379+
end
371380
# no special type-inference lattice, join the types
372381
typea, typeb = widenconst(typea), widenconst(typeb)
373382
if !isa(typea, Type) || !isa(typeb, Type)

base/compiler/typeutils.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ end
3434

3535
function has_nontrivial_const_info(@nospecialize t)
3636
isa(t, PartialStruct) && return true
37+
isa(t, PartialOpaque) && return true
3738
isa(t, Const) || return false
3839
val = t.val
3940
return !isdefined(typeof(val), :instance) && !(isa(val, Type) && hasuniquerep(val))

base/compiler/utilities.jl

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,11 +104,20 @@ function invoke_api(li::CodeInstance)
104104
return ccall(:jl_invoke_api, Cint, (Any,), li)
105105
end
106106

107-
function get_staged(li::MethodInstance)
108-
may_invoke_generator(li) || return nothing
107+
function has_opaque_closure(c::CodeInfo)
108+
for i = 1:length(c.code)
109+
stmt = c.code[i]
110+
(isa(stmt, Expr) && stmt.head === :new_opaque_closure) && return true
111+
end
112+
return false
113+
end
114+
115+
function get_staged(mi::MethodInstance)
116+
may_invoke_generator(mi) || return nothing
109117
try
110118
# user code might throw errors – ignore them
111-
return ccall(:jl_code_for_staged, Any, (Any,), li)::CodeInfo
119+
ci = ccall(:jl_code_for_staged, Any, (Any,), mi)::CodeInfo
120+
return ci
112121
catch
113122
return nothing
114123
end

0 commit comments

Comments
 (0)