Skip to content

Commit b50595d

Browse files
authored
Merge pull request JuliaLang#40088 from JuliaLang/kf/methodrecurserel
inference: Add a hook for users to be able to specify custom recursion relations
2 parents efad4e3 + 73622ed commit b50595d

File tree

7 files changed

+116
-71
lines changed

7 files changed

+116
-71
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 65 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,7 @@ function add_call_backedges!(interp::AbstractInterpreter,
310310
end
311311

312312
const RECURSION_UNUSED_MSG = "Bounded recursion detected with unused result. Annotated return type may be wider than true result."
313+
const RECURSION_MSG = "Bounded recursion detected. Call was widened to force convergence."
313314

314315
function abstract_call_method(interp::AbstractInterpreter, method::Method, @nospecialize(sig), sparams::SimpleVector, hardlimit::Bool, sv::InferenceState)
315316
if method.name === :depwarn && isdefined(Main, :Base) && method.module === Main.Base
@@ -321,18 +322,57 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp
321322
# look through the parents list to see if there's a call to the same method
322323
# and from the same method.
323324
# Returns the topmost occurrence of that repeated edge.
324-
cyclei = 0
325-
infstate = sv
326325
edgecycle = false
327326
# The `method_for_inference_heuristics` will expand the given method's generator if
328327
# necessary in order to retrieve this field from the generated `CodeInfo`, if it exists.
329328
# The other `CodeInfo`s we inspect will already have this field inflated, so we just
330329
# access it directly instead (to avoid regeneration).
331-
method2 = method_for_inference_heuristics(method, sig, sparams) # Union{Method, Nothing}
330+
callee_method2 = method_for_inference_heuristics(method, sig, sparams) # Union{Method, Nothing}
332331
sv_method2 = sv.src.method_for_inference_limit_heuristics # limit only if user token match
333332
sv_method2 isa Method || (sv_method2 = nothing) # Union{Method, Nothing}
334-
while !(infstate === nothing)
335-
infstate = infstate::InferenceState
333+
334+
function matches_sv(parent::InferenceState)
335+
parent_method2 = parent.src.method_for_inference_limit_heuristics # limit only if user token match
336+
parent_method2 isa Method || (parent_method2 = nothing) # Union{Method, Nothing}
337+
return parent.linfo.def === sv.linfo.def && sv_method2 === parent_method2
338+
end
339+
340+
function edge_matches_sv(frame::InferenceState)
341+
inf_method2 = frame.src.method_for_inference_limit_heuristics # limit only if user token match
342+
inf_method2 isa Method || (inf_method2 = nothing) # Union{Method, Nothing}
343+
if callee_method2 !== inf_method2
344+
return false
345+
end
346+
if !hardlimit
347+
# if this is a soft limit,
348+
# also inspect the parent of this edge,
349+
# to see if they are the same Method as sv
350+
# in which case we'll need to ensure it is convergent
351+
# otherwise, we don't
352+
353+
# check in the cycle list first
354+
# all items in here are mutual parents of all others
355+
if !_any(matches_sv, frame.callers_in_cycle)
356+
let parent = frame.parent
357+
parent !== nothing || return false
358+
parent = parent::InferenceState
359+
(parent.cached || parent.parent !== nothing) || return false
360+
matches_sv(parent) || return false
361+
end
362+
end
363+
364+
# If the method defines a recursion relation, give it a chance
365+
# to tell us that this recursion is actually ok.
366+
if isdefined(method, :recursion_relation)
367+
if Core._apply_pure(method.recursion_relation, Any[method, callee_method2, sig, frame.linfo.specTypes])
368+
return false
369+
end
370+
end
371+
end
372+
return true
373+
end
374+
375+
for infstate in InfStackUnwind(sv)
336376
if method === infstate.linfo.def
337377
if infstate.linfo.specTypes == sig
338378
# avoid widening when detecting self-recursion
@@ -349,60 +389,20 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp
349389
edgecycle = true
350390
break
351391
end
352-
inf_method2 = infstate.src.method_for_inference_limit_heuristics # limit only if user token match
353-
inf_method2 isa Method || (inf_method2 = nothing) # Union{Method, Nothing}
354-
if topmost === nothing && method2 === inf_method2
355-
if hardlimit
356-
topmost = infstate
357-
edgecycle = true
358-
else
359-
# if this is a soft limit,
360-
# also inspect the parent of this edge,
361-
# to see if they are the same Method as sv
362-
# in which case we'll need to ensure it is convergent
363-
# otherwise, we don't
364-
for parent in infstate.callers_in_cycle
365-
# check in the cycle list first
366-
# all items in here are mutual parents of all others
367-
parent_method2 = parent.src.method_for_inference_limit_heuristics # limit only if user token match
368-
parent_method2 isa Method || (parent_method2 = nothing) # Union{Method, Nothing}
369-
if parent.linfo.def === sv.linfo.def && sv_method2 === parent_method2
370-
topmost = infstate
371-
edgecycle = true
372-
break
373-
end
374-
end
375-
let parent = infstate.parent
376-
# then check the parent link
377-
if topmost === nothing && parent !== nothing
378-
parent = parent::InferenceState
379-
parent_method2 = parent.src.method_for_inference_limit_heuristics # limit only if user token match
380-
parent_method2 isa Method || (parent_method2 = nothing) # Union{Method, Nothing}
381-
if (parent.cached || parent.parent !== nothing) && parent.linfo.def === sv.linfo.def && sv_method2 === parent_method2
382-
topmost = infstate
383-
edgecycle = true
384-
end
385-
end
386-
end
387-
end
392+
topmost === nothing || continue
393+
if edge_matches_sv(infstate)
394+
topmost = infstate
395+
edgecycle = true
388396
end
389397
end
390-
# iterate through the cycle before walking to the parent
391-
if cyclei < length(infstate.callers_in_cycle)
392-
cyclei += 1
393-
infstate = infstate.callers_in_cycle[cyclei]
394-
else
395-
cyclei = 0
396-
infstate = infstate.parent
397-
end
398398
end
399399

400-
if !(topmost === nothing)
401-
topmost = topmost::InferenceState
400+
if topmost !== nothing
402401
sigtuple = unwrap_unionall(sig)::DataType
403402
msig = unwrap_unionall(method.sig)::DataType
404403
spec_len = length(msig.parameters) + 1
405404
ls = length(sigtuple.parameters)
405+
406406
if method === sv.linfo.def
407407
# Under direct self-recursion, permit much greater use of reducers.
408408
# here we assume that complexity(specTypes) :>= complexity(sig)
@@ -412,6 +412,13 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp
412412
else
413413
comparison = method.sig
414414
end
415+
416+
if isdefined(method, :recursion_relation)
417+
# We don't recquire the recursion_relation to be transitive, so
418+
# apply a hard limit
419+
hardlimit = true
420+
end
421+
415422
# see if the type is actually too big (relative to the caller), and limit it if required
416423
newsig = limit_type_size(sig, comparison, hardlimit ? comparison : sv.linfo.specTypes, InferenceParams(interp).TUPLE_COMPLEXITY_LIMIT_DEPTH, spec_len)
417424

@@ -427,6 +434,7 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp
427434
# (non-typically, this means that we lose the ability to detect a guaranteed StackOverflow in some cases)
428435
return Any, true, nothing
429436
end
437+
add_remark!(interp, sv, RECURSION_MSG)
430438
topmost = topmost::InferenceState
431439
parentframe = topmost.parent
432440
poison_callstack(sv, parentframe === nothing ? topmost : parentframe)
@@ -478,24 +486,13 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter, @nosp
478486
inf_cache = get_inference_cache(interp)
479487
inf_result = cache_lookup(mi, argtypes, inf_cache)
480488
if inf_result === nothing
481-
if edgecycle
482-
# if there might be a cycle, check to make sure we don't end up
483-
# calling ourselves here.
484-
infstate = sv
485-
cyclei = 0
486-
while !(infstate === nothing)
487-
if match.method === infstate.linfo.def && any(infstate.result.overridden_by_const)
488-
add_remark!(interp, sv, "[constprop] Edge cycle encountered")
489-
return Any, nothing
490-
end
491-
if cyclei < length(infstate.callers_in_cycle)
492-
cyclei += 1
493-
infstate = infstate.callers_in_cycle[cyclei]
494-
else
495-
cyclei = 0
496-
infstate = infstate.parent
497-
end
489+
# if there might be a cycle, check to make sure we don't end up
490+
# calling ourselves here.
491+
if edgecycle && _any(InfStackUnwind(sv)) do infstate
492+
return match.method === infstate.linfo.def && any(infstate.result.overridden_by_const)
498493
end
494+
add_remark!(interp, sv, "[constprop] Edge cycle encountered")
495+
return Any, nothing
499496
end
500497
inf_result = InferenceResult(mi, argtypes, va_override)
501498
frame = InferenceState(inf_result, #=cache=#false, interp)

base/compiler/inferencestate.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,29 @@ mutable struct InferenceState
124124
end
125125
end
126126

127+
"""
128+
Iterate through all callers of the given InferenceState in the abstract
129+
interpretation stack (including the given InferenceState itself), vising
130+
children before their parents (i.e. ascending the tree from the given
131+
InferenceState). Note that cycles may be visited in any order.
132+
"""
133+
struct InfStackUnwind
134+
inf::InferenceState
135+
end
136+
iterate(unw::InfStackUnwind) = (unw.inf, (unw.inf, 0))
137+
function iterate(unw::InfStackUnwind, (infstate, cyclei)::Tuple{InferenceState, Int})
138+
# iterate through the cycle before walking to the parent
139+
if cyclei < length(infstate.callers_in_cycle)
140+
cyclei += 1
141+
infstate = infstate.callers_in_cycle[cyclei]
142+
else
143+
cyclei = 0
144+
infstate = infstate.parent
145+
end
146+
infstate === nothing && return nothing
147+
(infstate::InferenceState, (infstate, cyclei))
148+
end
149+
127150
method_table(interp::AbstractInterpreter, sv::InferenceState) = sv.method_table
128151

129152
function InferenceState(result::InferenceResult, cached::Bool, interp::AbstractInterpreter)

src/dump.c

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -660,6 +660,7 @@ static void jl_serialize_value_(jl_serializer_state *s, jl_value_t *v, int as_li
660660
jl_serialize_value(s, (jl_value_t*)m->unspecialized);
661661
jl_serialize_value(s, (jl_value_t*)m->generator);
662662
jl_serialize_value(s, (jl_value_t*)m->invokes);
663+
jl_serialize_value(s, (jl_value_t*)m->recursion_relation);
663664
}
664665
else if (jl_is_method_instance(v)) {
665666
jl_method_instance_t *mi = (jl_method_instance_t*)v;
@@ -1503,6 +1504,9 @@ static jl_value_t *jl_deserialize_value_method(jl_serializer_state *s, jl_value_
15031504
jl_gc_wb(m, m->generator);
15041505
m->invokes = jl_deserialize_value(s, (jl_value_t**)&m->invokes);
15051506
jl_gc_wb(m, m->invokes);
1507+
m->recursion_relation = jl_deserialize_value(s, (jl_value_t**)&m->recursion_relation);
1508+
if (m->recursion_relation)
1509+
jl_gc_wb(m, m->recursion_relation);
15061510
JL_MUTEX_INIT(&m->writelock);
15071511
return (jl_value_t*)m;
15081512
}

src/jltypes.c

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2200,7 +2200,7 @@ void jl_init_types(void) JL_GC_DISABLED
22002200
jl_method_type =
22012201
jl_new_datatype(jl_symbol("Method"), core,
22022202
jl_any_type, jl_emptysvec,
2203-
jl_perm_symsvec(24,
2203+
jl_perm_symsvec(25,
22042204
"name",
22052205
"module",
22062206
"file",
@@ -2217,6 +2217,7 @@ void jl_init_types(void) JL_GC_DISABLED
22172217
"roots",
22182218
"ccallable",
22192219
"invokes",
2220+
"recursion_relation",
22202221
"nargs",
22212222
"called",
22222223
"nospecialize",
@@ -2225,7 +2226,7 @@ void jl_init_types(void) JL_GC_DISABLED
22252226
"pure",
22262227
"is_for_opaque_closure",
22272228
"aggressive_constprop"),
2228-
jl_svec(24,
2229+
jl_svec(25,
22292230
jl_symbol_type,
22302231
jl_module_type,
22312232
jl_symbol_type,
@@ -2242,6 +2243,7 @@ void jl_init_types(void) JL_GC_DISABLED
22422243
jl_array_any_type,
22432244
jl_simplevector_type,
22442245
jl_any_type,
2246+
jl_any_type,
22452247
jl_int32_type,
22462248
jl_int32_type,
22472249
jl_int32_type,

src/julia.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,12 @@ typedef struct _jl_method_t {
322322
// the most specific for the argument types.
323323
jl_typemap_t *invokes;
324324

325+
// A function that compares two specializations of this method, returning
326+
// `true` if the first signature is to be considered "smaller" than the
327+
// second for purposes of recursion analysis. Set to NULL to use
328+
// the default recusion relation.
329+
jl_value_t *recursion_relation;
330+
325331
int32_t nargs;
326332
int32_t called; // bit flags: whether each of the first 8 arguments is called
327333
int32_t nospecialize; // bit flags: which arguments should not be specialized

src/method.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -662,6 +662,7 @@ JL_DLLEXPORT jl_method_t *jl_new_method_uninit(jl_module_t *module)
662662
m->nospecialize = module->nospecialize;
663663
m->nkw = 0;
664664
m->invokes = NULL;
665+
m->recursion_relation = NULL;
665666
m->isva = 0;
666667
m->nargs = 0;
667668
m->primary_world = 1;

stdlib/Serialization/src/Serialization.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ const TAGS = Any[
7979

8080
@assert length(TAGS) == 255
8181

82-
const ser_version = 14 # do not make changes without bumping the version #!
82+
const ser_version = 15 # do not make changes without bumping the version #!
8383

8484
format_version(::AbstractSerializer) = ser_version
8585
format_version(s::Serializer) = s.version
@@ -429,6 +429,11 @@ function serialize(s::AbstractSerializer, meth::Method)
429429
else
430430
serialize(s, nothing)
431431
end
432+
if isdefined(meth, :recursion_relation)
433+
serialize(s, method.recursion_relation)
434+
else
435+
serialize(s, nothing)
436+
end
432437
nothing
433438
end
434439

@@ -1007,6 +1012,10 @@ function deserialize(s::AbstractSerializer, ::Type{Method})
10071012
template = template_or_is_opaque
10081013
end
10091014
generator = deserialize(s)
1015+
recursion_relation = nothing
1016+
if format_version(s) >= 15
1017+
recursion_relation = deserialize(s)
1018+
end
10101019
if makenew
10111020
meth.module = mod
10121021
meth.name = name
@@ -1033,6 +1042,9 @@ function deserialize(s::AbstractSerializer, ::Type{Method})
10331042
linfo.def = meth
10341043
meth.generator = linfo
10351044
end
1045+
if recursion_relation !== nothing
1046+
meth.recursion_relation = recursion_relation
1047+
end
10361048
if !is_for_opaque_closure
10371049
mt = ccall(:jl_method_table_for, Any, (Any,), sig)
10381050
if mt !== nothing && nothing === ccall(:jl_methtable_lookup, Any, (Any, Any, UInt), mt, sig, typemax(UInt))

0 commit comments

Comments
 (0)