Skip to content

Commit a713a8a

Browse files
vtjnashstaticfloat
authored andcommitted
inference: make Limited tracking part of the type lattice (#39116)
This helps refine our knowledge of the `[limited]` flag setting, which previously would always exclude a result from the cache when hitting a cycle. However, we really only need to exclude a result if the result might be dependent on that flag setting. That makes this formally part of the lattice, though can be annoying to work with yet another wrapper, so we try to add/remove it late/early to propagate it when necessary. (cherry picked from commit 5f10eb9)
1 parent 05f0546 commit a713a8a

File tree

6 files changed

+235
-106
lines changed

6 files changed

+235
-106
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 51 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,11 @@ const _REF_NAME = Ref.body.name
1212
# logic #
1313
#########
1414

15-
# see if the inference result might affect the final answer
16-
call_result_unused(frame::InferenceState, pc::LineNum=frame.currpc) =
17-
isexpr(frame.src.code[frame.currpc], :call) && isempty(frame.ssavalue_uses[pc])
15+
# See if the inference result of the current statement's result value might affect
16+
# the final answer for the method (aside from optimization potential and exceptions).
17+
# To do that, we need to check both for slot assignment and SSA usage.
18+
call_result_unused(frame::InferenceState) =
19+
isexpr(frame.src.code[frame.currpc], :call) && isempty(frame.ssavalue_uses[frame.currpc])
1820

1921
# check if this return type is improvable (i.e. whether it's possible that with
2022
# more information, we might get a more precise type)
@@ -192,6 +194,16 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
192194
end
193195
end
194196
#print("=> ", rettype, "\n")
197+
if rettype isa LimitedAccuracy
198+
union!(sv.pclimitations, rettype.causes)
199+
rettype = rettype.typ
200+
end
201+
if !isempty(sv.pclimitations) # remove self, if present
202+
delete!(sv.pclimitations, sv)
203+
for caller in sv.callers_in_cycle
204+
delete!(sv.pclimitations, caller)
205+
end
206+
end
195207
return CallMeta(rettype, info)
196208
end
197209

@@ -313,7 +325,6 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter, @nosp
313325
inf_result = InferenceResult(mi, argtypes)
314326
frame = InferenceState(inf_result, #=cache=#false, interp)
315327
frame === nothing && return Any # this is probably a bad generated function (unsound), but just ignore it
316-
frame.limited = true
317328
frame.parent = sv
318329
push!(inf_cache, inf_result)
319330
typeinf(interp, frame) || return Any
@@ -394,7 +405,7 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp
394405
parent = parent::InferenceState
395406
parent_method2 = parent.src.method_for_inference_limit_heuristics # limit only if user token match
396407
parent_method2 isa Method || (parent_method2 = nothing) # Union{Method, Nothing}
397-
if (parent.cached || parent.limited) && parent.linfo.def === sv.linfo.def && sv_method2 === parent_method2
408+
if (parent.cached || parent.parent !== nothing) && parent.linfo.def === sv.linfo.def && sv_method2 === parent_method2
398409
topmost = infstate
399410
edgecycle = true
400411
end
@@ -443,7 +454,8 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp
443454
# (non-typically, this means that we lose the ability to detect a guaranteed StackOverflow in some cases)
444455
return Any, true, nothing
445456
end
446-
poison_callstack(sv, topmost::InferenceState, true)
457+
topmost = topmost::InferenceState
458+
poison_callstack(sv, topmost.parent === nothing ? topmost : topmost.parent)
447459
sig = newsig
448460
sparams = svec()
449461
end
@@ -1129,7 +1141,12 @@ function abstract_eval_value(interp::AbstractInterpreter, @nospecialize(e), vtyp
11291141
if isa(e, Expr)
11301142
return abstract_eval_value_expr(interp, e, vtypes, sv)
11311143
else
1132-
return abstract_eval_special_value(interp, e, vtypes, sv)
1144+
typ = abstract_eval_special_value(interp, e, vtypes, sv)
1145+
if typ isa LimitedAccuracy
1146+
union!(sv.pclimitations, typ.causes)
1147+
typ = typ.typ
1148+
end
1149+
return typ
11331150
end
11341151
end
11351152

@@ -1252,13 +1269,21 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
12521269
end
12531270
end
12541271
else
1255-
return abstract_eval_value_expr(interp, e, vtypes, sv)
1272+
t = abstract_eval_value_expr(interp, e, vtypes, sv)
12561273
end
12571274
@assert !isa(t, TypeVar)
12581275
if isa(t, DataType) && isdefined(t, :instance)
12591276
# replace singleton types with their equivalent Const object
12601277
t = Const(t.instance)
12611278
end
1279+
if !isempty(sv.pclimitations)
1280+
if t isa Const || t === Union{}
1281+
empty!(sv.pclimitations)
1282+
else
1283+
t = LimitedAccuracy(t, sv.pclimitations)
1284+
sv.pclimitations = IdSet{InferenceState}()
1285+
end
1286+
end
12621287
return t
12631288
end
12641289

@@ -1313,10 +1338,18 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
13131338
elseif isa(stmt, GotoIfNot)
13141339
condt = abstract_eval_value(interp, stmt.cond, s[pc], frame)
13151340
if condt === Bottom
1341+
empty!(frame.pclimitations)
13161342
break
13171343
end
13181344
condval = maybe_extract_const_bool(condt)
13191345
l = stmt.dest::Int
1346+
if !isempty(frame.pclimitations)
1347+
# we can't model the possible effect of control
1348+
# dependencies on the return value, so we propagate it
1349+
# directly to all the return values (unless we error first)
1350+
condval isa Bool || union!(frame.limitations, frame.pclimitations)
1351+
empty!(frame.pclimitations)
1352+
end
13201353
# constant conditions
13211354
if condval === true
13221355
elseif condval === false
@@ -1351,6 +1384,14 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
13511384
# and is valid inter-procedurally
13521385
rt = widenconst(rt)
13531386
end
1387+
# copy limitations to return value
1388+
if !isempty(frame.pclimitations)
1389+
union!(frame.limitations, frame.pclimitations)
1390+
empty!(frame.pclimitations)
1391+
end
1392+
if !isempty(frame.limitations)
1393+
rt = LimitedAccuracy(rt, copy(frame.limitations))
1394+
end
13541395
if tchanged(rt, frame.bestguess)
13551396
# new (wider) return type for frame
13561397
frame.bestguess = tmerge(frame.bestguess, rt)
@@ -1425,6 +1466,8 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
14251466
end
14261467
end
14271468

1469+
@assert isempty(frame.pclimitations) "unhandled LimitedAccuracy"
1470+
14281471
if t === nothing
14291472
# mark other reached expressions as `Any` to indicate they don't throw
14301473
frame.src.ssavaluetypes[pc] = Any

base/compiler/inferencestate.jl

Lines changed: 4 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ mutable struct InferenceState
1010
slottypes::Vector{Any}
1111
mod::Module
1212
currpc::LineNum
13+
pclimitations::IdSet{InferenceState} # causes of precision restrictions (LimitedAccuracy) on currpc ssavalue
14+
limitations::IdSet{InferenceState} # causes of precision restrictions (LimitedAccuracy) on return
1315

1416
# info on the state of inference and the linfo
1517
src::CodeInfo
@@ -39,7 +41,6 @@ mutable struct InferenceState
3941

4042
# TODO: move these to InferenceResult / Params?
4143
cached::Bool
42-
limited::Bool
4344
inferred::Bool
4445
dont_work_on_me::Bool
4546

@@ -105,6 +106,7 @@ mutable struct InferenceState
105106
frame = new(
106107
InferenceParams(interp), result, linfo,
107108
sp, slottypes, inmodule, 0,
109+
IdSet{InferenceState}(), IdSet{InferenceState}(),
108110
src, get_world_counter(interp), valid_worlds,
109111
nargs, s_types, s_edges, stmt_info,
110112
Union{}, W, 1, n,
@@ -113,7 +115,7 @@ mutable struct InferenceState
113115
Vector{Tuple{InferenceState,LineNum}}(), # cycle_backedges
114116
Vector{InferenceState}(), # callers_in_cycle
115117
#=parent=#nothing,
116-
cached, false, false, false,
118+
cached, false, false,
117119
CachedMethodTable(method_table(interp)),
118120
interp)
119121
result.result = frame
@@ -261,37 +263,13 @@ function add_mt_backedge!(mt::Core.MethodTable, @nospecialize(typ), caller::Infe
261263
nothing
262264
end
263265

264-
function poison_callstack(infstate::InferenceState, topmost::InferenceState, poison_topmost::Bool)
265-
poison_topmost && (topmost = topmost.parent)
266-
while !(infstate === topmost)
267-
if call_result_unused(infstate)
268-
# If we won't propagate the result any further (since it's typically unused),
269-
# it's OK that we keep and cache the "limited" result in the parents
270-
# (non-typically, this means that we lose the ability to detect a guaranteed StackOverflow in some cases)
271-
# TODO: we might be able to halt progress much more strongly here,
272-
# since now we know we won't be able to keep anything much that we learned.
273-
# We were mainly only here to compute the calling convention return type,
274-
# but in most situations now, we are unlikely to be able to use that information.
275-
break
276-
end
277-
infstate.limited = true
278-
for infstate_cycle in infstate.callers_in_cycle
279-
infstate_cycle.limited = true
280-
end
281-
infstate = infstate.parent
282-
infstate === nothing && return
283-
end
284-
end
285-
286266
function print_callstack(sv::InferenceState)
287267
while sv !== nothing
288268
print(sv.linfo)
289-
sv.limited && print(" [limited]")
290269
!sv.cached && print(" [uncached]")
291270
println()
292271
for cycle in sv.callers_in_cycle
293272
print(' ', cycle.linfo)
294-
cycle.limited && print(" [limited]")
295273
println()
296274
end
297275
sv = sv.parent

base/compiler/tfuncs.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1586,10 +1586,14 @@ function return_type_tfunc(interp::AbstractInterpreter, argtypes::Vector{Any}, s
15861586
# output was computed to be constant
15871587
return Const(typeof(rt.val))
15881588
else
1589+
inaccurate = nothing
1590+
rt isa LimitedAccuracy && (inaccurate = rt.causes; rt = rt.typ)
15891591
rt = widenconst(rt)
15901592
if hasuniquerep(rt) || rt === Bottom
15911593
# output type was known for certain
15921594
return Const(rt)
1595+
elseif inaccurate !== nothing
1596+
return LimitedAccuracy(Type{<:rt}, inaccurate)
15931597
elseif (isa(tt, Const) || isconstType(tt)) &&
15941598
(isa(aft, Const) || isconstType(aft))
15951599
# input arguments were known for certain

0 commit comments

Comments
 (0)