Skip to content

Commit 998951e

Browse files
committed
Slightly refactor recursion detection NFC.
I'm about to suggest some pretty invasive changes to the recursion detection code, so I figured as a first step, I'd refactor it to make it a bit easier to understand, since I always get confused by this code,
1 parent 1006d5f commit 998951e

File tree

2 files changed

+54
-55
lines changed

2 files changed

+54
-55
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 31 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,7 @@ function add_call_backedges!(interp::AbstractInterpreter,
310310
end
311311

312312
const RECURSION_UNUSED_MSG = "Bounded recursion detected with unused result. Annotated return type may be wider than true result."
313+
const RECURSION_MSG = "Bounded recursion detected. Call was widened to force convergence."
313314

314315
function abstract_call_method(interp::AbstractInterpreter, method::Method, @nospecialize(sig), sparams::SimpleVector, hardlimit::Bool, sv::InferenceState)
315316
if method.name === :depwarn && isdefined(Main, :Base) && method.module === Main.Base
@@ -321,18 +322,22 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp
321322
# look through the parents list to see if there's a call to the same method
322323
# and from the same method.
323324
# Returns the topmost occurrence of that repeated edge.
324-
cyclei = 0
325-
infstate = sv
326325
edgecycle = false
327326
# The `method_for_inference_heuristics` will expand the given method's generator if
328327
# necessary in order to retrieve this field from the generated `CodeInfo`, if it exists.
329328
# The other `CodeInfo`s we inspect will already have this field inflated, so we just
330329
# access it directly instead (to avoid regeneration).
331-
method2 = method_for_inference_heuristics(method, sig, sparams) # Union{Method, Nothing}
330+
callee_method2 = method_for_inference_heuristics(method, sig, sparams) # Union{Method, Nothing}
332331
sv_method2 = sv.src.method_for_inference_limit_heuristics # limit only if user token match
333332
sv_method2 isa Method || (sv_method2 = nothing) # Union{Method, Nothing}
334-
while !(infstate === nothing)
335-
infstate = infstate::InferenceState
333+
334+
function matches_sv(parent::InferenceState)
335+
parent_method2 = parent.src.method_for_inference_limit_heuristics # limit only if user token match
336+
parent_method2 isa Method || (parent_method2 = nothing) # Union{Method, Nothing}
337+
return parent.linfo.def === sv.linfo.def && sv_method2 === parent_method2
338+
end
339+
340+
for infstate in InfStackUnwind(sv)
336341
if method === infstate.linfo.def
337342
if infstate.linfo.specTypes == sig
338343
# avoid widening when detecting self-recursion
@@ -349,52 +354,33 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp
349354
edgecycle = true
350355
break
351356
end
357+
topmost === nothing || continue
352358
inf_method2 = infstate.src.method_for_inference_limit_heuristics # limit only if user token match
353359
inf_method2 isa Method || (inf_method2 = nothing) # Union{Method, Nothing}
354-
if topmost === nothing && method2 === inf_method2
355-
if hardlimit
356-
topmost = infstate
357-
edgecycle = true
358-
else
360+
if callee_method2 === inf_method2
361+
if !hardlimit
359362
# if this is a soft limit,
360363
# also inspect the parent of this edge,
361364
# to see if they are the same Method as sv
362365
# in which case we'll need to ensure it is convergent
363366
# otherwise, we don't
364-
for parent in infstate.callers_in_cycle
365-
# check in the cycle list first
366-
# all items in here are mutual parents of all others
367-
parent_method2 = parent.src.method_for_inference_limit_heuristics # limit only if user token match
368-
parent_method2 isa Method || (parent_method2 = nothing) # Union{Method, Nothing}
369-
if parent.linfo.def === sv.linfo.def && sv_method2 === parent_method2
370-
topmost = infstate
371-
edgecycle = true
372-
break
373-
end
374-
end
375-
let parent = infstate.parent
376-
# then check the parent link
377-
if topmost === nothing && parent !== nothing
367+
368+
# check in the cycle list first
369+
# all items in here are mutual parents of all others
370+
if !_any(matches_sv, infstate.callers_in_cycle)
371+
let parent = infstate.parent
372+
parent !== nothing || continue
378373
parent = parent::InferenceState
379-
parent_method2 = parent.src.method_for_inference_limit_heuristics # limit only if user token match
380-
parent_method2 isa Method || (parent_method2 = nothing) # Union{Method, Nothing}
381-
if (parent.cached || parent.parent !== nothing) && parent.linfo.def === sv.linfo.def && sv_method2 === parent_method2
382-
topmost = infstate
383-
edgecycle = true
384-
end
374+
(parent.cached || parent.parent !== nothing) || continue
375+
matches_sv(parent) || continue
385376
end
386377
end
387378
end
379+
380+
topmost = infstate
381+
edgecycle = true
388382
end
389383
end
390-
# iterate through the cycle before walking to the parent
391-
if cyclei < length(infstate.callers_in_cycle)
392-
cyclei += 1
393-
infstate = infstate.callers_in_cycle[cyclei]
394-
else
395-
cyclei = 0
396-
infstate = infstate.parent
397-
end
398384
end
399385

400386
if !(topmost === nothing)
@@ -427,6 +413,7 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp
427413
# (non-typically, this means that we lose the ability to detect a guaranteed StackOverflow in some cases)
428414
return Any, true, nothing
429415
end
416+
add_remark!(interp, sv, RECURSION_MSG)
430417
topmost = topmost::InferenceState
431418
parentframe = topmost.parent
432419
poison_callstack(sv, parentframe === nothing ? topmost : parentframe)
@@ -478,24 +465,13 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter, @nosp
478465
inf_cache = get_inference_cache(interp)
479466
inf_result = cache_lookup(mi, argtypes, inf_cache)
480467
if inf_result === nothing
481-
if edgecycle
482-
# if there might be a cycle, check to make sure we don't end up
483-
# calling ourselves here.
484-
infstate = sv
485-
cyclei = 0
486-
while !(infstate === nothing)
487-
if match.method === infstate.linfo.def && any(infstate.result.overridden_by_const)
488-
add_remark!(interp, sv, "[constprop] Edge cycle encountered")
489-
return Any, nothing
490-
end
491-
if cyclei < length(infstate.callers_in_cycle)
492-
cyclei += 1
493-
infstate = infstate.callers_in_cycle[cyclei]
494-
else
495-
cyclei = 0
496-
infstate = infstate.parent
497-
end
468+
# if there might be a cycle, check to make sure we don't end up
469+
# calling ourselves here.
470+
if edgecycle && _any(InfStackUnwind(sv)) do infstate
471+
return match.method === infstate.linfo.def && any(infstate.result.overridden_by_const)
498472
end
473+
add_remark!(interp, sv, "[constprop] Edge cycle encountered")
474+
return Any, nothing
499475
end
500476
inf_result = InferenceResult(mi, argtypes, va_override)
501477
frame = InferenceState(inf_result, #=cache=#false, interp)

base/compiler/inferencestate.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,29 @@ mutable struct InferenceState
124124
end
125125
end
126126

127+
"""
128+
Iterate through all callers of the given InferenceState in the abstract
129+
interpretation stack (including the given InferenceState itself), vising
130+
children before their parents (i.e. ascending the tree from the given
131+
InferenceState). Note that cycles may be visited in any order.
132+
"""
133+
struct InfStackUnwind
134+
inf::InferenceState
135+
end
136+
iterate(unw::InfStackUnwind) = (unw.inf, (unw.inf, 0))
137+
function iterate(unw::InfStackUnwind, (infstate, cyclei)::Tuple{InferenceState, Int})
138+
# iterate through the cycle before walking to the parent
139+
if cyclei < length(infstate.callers_in_cycle)
140+
cyclei += 1
141+
infstate = infstate.callers_in_cycle[cyclei]
142+
else
143+
cyclei = 0
144+
infstate = infstate.parent
145+
end
146+
infstate === nothing && return nothing
147+
(infstate::InferenceState, (infstate, cyclei))
148+
end
149+
127150
method_table(interp::AbstractInterpreter, sv::InferenceState) = sv.method_table
128151

129152
function InferenceState(result::InferenceResult, cached::Bool, interp::AbstractInterpreter)

0 commit comments

Comments
 (0)