Commit 73622ed

inference: Add a hook for users to be able to specify custom recursion relations
When inferring recursive functions, we try to detect cases where the recursion can be shown to terminate, and in such cases we allow significantly more inference among the recursive functions than in cases where we cannot prove termination. The rationale is to allow inference of functions that recurse over structures, while not spending useless (or infinite) compile time chasing down recursions that build up infinitely large types (those get expensive *fast*). Our built-in recursion relation allows recursion to proceed if argument types are syntactic subsets, if they are decreasing integers, tuples of decreasing length, and a few other cases.

It is worth noting that similar considerations come up in significantly more static systems, where non-termination is (statically) disallowed. E.g. in Coq, for non-syntactic recursion, a proof that a recursive function does indeed terminate must be provided before one is allowed to define it [1]. More modern languages like Dafny allow the user to specify a predicate over the incoming values that is shown to decrease over subsequent function calls [2].

My motivation comes from Diffractor, where we get call chains like:

    ∂⃖{1}(sin'', ...) -> ∂⃖{2}(sin', ...) -> ∂⃖{3}(sin, ...) -> ∂⃖{2}(rrule, sin, ...) -> ∂⃖{1}(rrule, sin)

In this example, the first two calls are both the same method, as are the last two. Unfortunately, particularly for the first two calls, there isn't really a good way to express the recursion rule in a way that is generic. Thus, to address these cases, this PR adds a per-method field, similar to a generator, that allows packages to provide arbitrary recursion relations that take advantage of the (known) special semantics of those methods to expand the allowed set of recursions.

Originally I had hoped to use this hook in place of the existing `type_more_complex` check.
However, our code currently requires transitivity of the `type_more_complex` check for soundness of the termination analysis. This runs into problems in the specified use case, because we may have interleaved chains of calls that are both the same method, but are not actually part of a cycle as such because their ultimate underlying methods are different (in particular this happens when chaining two Cassette-like generated functions). We do not currently express enough about the semantics of these Cassette-like methods for inference to reasonably compute whether two instances are part of the same cycle or not (we have `method_for_inference_limit_heuristics` of course, which takes care of one level of this, but does not take care of the nested case). By having this hook, but not requiring transitivity, it is legal for the hook to compute whether the ultimate underlying method is the same (by using its knowledge of what the methods actually do) and answer accordingly. In the long run, I would like to bring these Cassette-like capabilities more closely into the compiler, at which point inference itself may have enough information to compute the cycles and we'd be able to get away with requiring transitivity.

All that said, this mechanism is quite simple and achieves its goal. I don't think it is particularly pretty and it should definitely be considered unstable. I'm not providing any user-facing APIs for this, so those in the know will have to manually poke the methods. I do think a more general language-level framework for proving termination could be useful, particularly as part of more rigorous definitions of when various constant propagation happens, but this is not that, yet.

I've tried this in Diffractor and, with appropriate definitions of the recursion relation for the relevant functions, Diffractor becomes nicely inferable:

```
julia> using Diffractor: var"'", ∂⃖

julia> Base.return_types(sin''', Tuple{Float64})
1-element Vector{Any}:
 Float64

julia> Base.return_types(sin'''', Tuple{Float64})
1-element Vector{Any}:
 Float64

julia> Base.return_types(sin''''', Tuple{Float64})
1-element Vector{Any}:
 Float64

julia> Base.return_types(sin'''''', Tuple{Float64})
1-element Vector{Any}:
 Float64
```

Diffractor's Phase 1 design goal was to infer fine at 3rd and 4th order, which this meets. The fact that it also infers at higher orders is nice, but inference times also increase to impractical levels for real-world functions, e.g. 5th order above takes a few seconds to infer just `sin` and 6th order takes about 20s or so. Of course that is still better than Zygote, even at second order, but fixing this properly will be part of Diffractor Phase 2.

With this (plus some additional tweaks to constprop heuristics for OpaqueClosure that I'll be putting up separately), we do also generate very nice code:

```
julia> @code_typed sin'''(1.0)
CodeInfo(
1 ─ %1  = invoke ChainRules.sincos(_2::Float64)::Tuple{Float64, Float64}
│   %2  = Base.getfield(%1, 1)::Float64
│   %3  = Base.getfield(%1, 2)::Float64
│   %4  = Diffractor.getfield(%1, 1)::Float64
│   %5  = Diffractor.getfield(%1, 2)::Float64
│   %6  = Diffractor.getfield(%1, 2)::Float64
│   %7  = Base.mul_float(%6, 1.0)::Float64
│   %8  = Base.mul_float(0.0, %7)::Float64
│   %9  = Base.neg_float(%4)::Float64
│   %10 = Base.mul_float(%9, 1.0)::Float64
│   %11 = Base.mul_float(1.0, %8)::Float64
│   %12 = Base.mul_float(%5, 1.0)::Float64
│   %13 = Base.mul_float(%12, %7)::Float64
│   %14 = Base.mul_float(0.0, %12)::Float64
│   %15 = Base.mul_float(0.0, %13)::Float64
│   %16 = Base.mul_float(%14, 1.0)::Float64
│   %17 = Base.mul_float(%6, %14)::Float64
│   %18 = Base.mul_float(%10, 1.0)::Float64
│   %19 = Base.mul_float(1.0, %10)::Float64
│   %20 = Base.add_float(%17, %18)::Float64
│   %21 = Base.mul_float(0.0, %20)::Float64
│   %22 = Base.mul_float(%21, 1.0)::Float64
│   %23 = Base.mul_float(%6, %21)::Float64
│   %24 = Base.add_float(%22, %16)::Float64
│   %25 = Base.add_float(%23, %19)::Float64
│   %26 = Base.mul_float(0.0, %25)::Float64
│   %27 = Base.add_float(%15, %26)::Float64
│   %28 = Base.add_float(%24, %11)::Float64
│   %29 = Base.add_float(%27, -1.0)::Float64
│   %30 = Base.neg_float(%2)::Float64
│   %31 = Base.mul_float(%3, %29)::Float64
│   %32 = Base.muladd_float(%30, %28, %31)::Float64
└──       return %32
) => Float64
```

[1] http://adam.chlipala.net/cpdt/html/GeneralRec.html
[2] https://rise4fun.com/Dafny/tutorial/Termination
1 parent 998951e commit 73622ed

6 files changed: 73 additions and 27 deletions

base/compiler/abstractinterpretation.jl

Lines changed: 45 additions & 24 deletions
@@ -337,6 +337,41 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp
         return parent.linfo.def === sv.linfo.def && sv_method2 === parent_method2
     end

+    function edge_matches_sv(frame::InferenceState)
+        inf_method2 = frame.src.method_for_inference_limit_heuristics # limit only if user token match
+        inf_method2 isa Method || (inf_method2 = nothing) # Union{Method, Nothing}
+        if callee_method2 !== inf_method2
+            return false
+        end
+        if !hardlimit
+            # if this is a soft limit,
+            # also inspect the parent of this edge,
+            # to see if they are the same Method as sv
+            # in which case we'll need to ensure it is convergent
+            # otherwise, we don't
+
+            # check in the cycle list first
+            # all items in here are mutual parents of all others
+            if !_any(matches_sv, frame.callers_in_cycle)
+                let parent = frame.parent
+                    parent !== nothing || return false
+                    parent = parent::InferenceState
+                    (parent.cached || parent.parent !== nothing) || return false
+                    matches_sv(parent) || return false
+                end
+            end
+
+            # If the method defines a recursion relation, give it a chance
+            # to tell us that this recursion is actually ok.
+            if isdefined(method, :recursion_relation)
+                if Core._apply_pure(method.recursion_relation, Any[method, callee_method2, sig, frame.linfo.specTypes])
+                    return false
+                end
+            end
+        end
+        return true
+    end
+
     for infstate in InfStackUnwind(sv)
         if method === infstate.linfo.def
             if infstate.linfo.specTypes == sig
@@ -355,40 +390,19 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp
                 break
             end
             topmost === nothing || continue
-            inf_method2 = infstate.src.method_for_inference_limit_heuristics # limit only if user token match
-            inf_method2 isa Method || (inf_method2 = nothing) # Union{Method, Nothing}
-            if callee_method2 === inf_method2
-                if !hardlimit
-                    # if this is a soft limit,
-                    # also inspect the parent of this edge,
-                    # to see if they are the same Method as sv
-                    # in which case we'll need to ensure it is convergent
-                    # otherwise, we don't
-
-                    # check in the cycle list first
-                    # all items in here are mutual parents of all others
-                    if !_any(matches_sv, infstate.callers_in_cycle)
-                        let parent = infstate.parent
-                            parent !== nothing || continue
-                            parent = parent::InferenceState
-                            (parent.cached || parent.parent !== nothing) || continue
-                            matches_sv(parent) || continue
-                        end
-                    end
-                end
-
+            if edge_matches_sv(infstate)
                 topmost = infstate
                 edgecycle = true
             end
         end
     end

-    if !(topmost === nothing)
-        topmost = topmost::InferenceState
+    if topmost !== nothing
         sigtuple = unwrap_unionall(sig)::DataType
         msig = unwrap_unionall(method.sig)::DataType
         spec_len = length(msig.parameters) + 1
         ls = length(sigtuple.parameters)
+
         if method === sv.linfo.def
             # Under direct self-recursion, permit much greater use of reducers.
             # here we assume that complexity(specTypes) :>= complexity(sig)
@@ -398,6 +412,13 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp
         else
             comparison = method.sig
         end
+
+        if isdefined(method, :recursion_relation)
+            # We don't require the recursion_relation to be transitive, so
+            # apply a hard limit
+            hardlimit = true
+        end
+
         # see if the type is actually too big (relative to the caller), and limit it if required
         newsig = limit_type_size(sig, comparison, hardlimit ? comparison : sv.linfo.specTypes, InferenceParams(interp).TUPLE_COMPLEXITY_LIMIT_DEPTH, spec_len)

src/dump.c

Lines changed: 4 additions & 0 deletions
@@ -660,6 +660,7 @@ static void jl_serialize_value_(jl_serializer_state *s, jl_value_t *v, int as_li
         jl_serialize_value(s, (jl_value_t*)m->unspecialized);
         jl_serialize_value(s, (jl_value_t*)m->generator);
         jl_serialize_value(s, (jl_value_t*)m->invokes);
+        jl_serialize_value(s, (jl_value_t*)m->recursion_relation);
     }
     else if (jl_is_method_instance(v)) {
         jl_method_instance_t *mi = (jl_method_instance_t*)v;
@@ -1503,6 +1504,9 @@ static jl_value_t *jl_deserialize_value_method(jl_serializer_state *s, jl_value_
     jl_gc_wb(m, m->generator);
     m->invokes = jl_deserialize_value(s, (jl_value_t**)&m->invokes);
     jl_gc_wb(m, m->invokes);
+    m->recursion_relation = jl_deserialize_value(s, (jl_value_t**)&m->recursion_relation);
+    if (m->recursion_relation)
+        jl_gc_wb(m, m->recursion_relation);
     JL_MUTEX_INIT(&m->writelock);
     return (jl_value_t*)m;
 }

src/jltypes.c

Lines changed: 4 additions & 2 deletions
@@ -2200,7 +2200,7 @@ void jl_init_types(void) JL_GC_DISABLED
     jl_method_type =
         jl_new_datatype(jl_symbol("Method"), core,
                         jl_any_type, jl_emptysvec,
-                        jl_perm_symsvec(24,
+                        jl_perm_symsvec(25,
                             "name",
                             "module",
                             "file",
@@ -2217,6 +2217,7 @@ void jl_init_types(void) JL_GC_DISABLED
                             "roots",
                             "ccallable",
                             "invokes",
+                            "recursion_relation",
                             "nargs",
                             "called",
                             "nospecialize",
@@ -2225,7 +2226,7 @@ void jl_init_types(void) JL_GC_DISABLED
                             "pure",
                             "is_for_opaque_closure",
                             "aggressive_constprop"),
-                        jl_svec(24,
+                        jl_svec(25,
                             jl_symbol_type,
                             jl_module_type,
                             jl_symbol_type,
@@ -2242,6 +2243,7 @@ void jl_init_types(void) JL_GC_DISABLED
                             jl_array_any_type,
                             jl_simplevector_type,
                             jl_any_type,
+                            jl_any_type,
                             jl_int32_type,
                             jl_int32_type,
                             jl_int32_type,

src/julia.h

Lines changed: 6 additions & 0 deletions
@@ -322,6 +322,12 @@ typedef struct _jl_method_t {
     // the most specific for the argument types.
     jl_typemap_t *invokes;

+    // A function that compares two specializations of this method, returning
+    // `true` if the first signature is to be considered "smaller" than the
+    // second for purposes of recursion analysis. Set to NULL to use
+    // the default recursion relation.
+    jl_value_t *recursion_relation;
+
     int32_t nargs;
     int32_t called; // bit flags: whether each of the first 8 arguments is called
     int32_t nospecialize; // bit flags: which arguments should not be specialized
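
The field comment above describes the relation as a function comparing two specializations of a method. What follows is a minimal sketch of what such a relation might look like, not code from this commit or from Diffractor. It assumes a hypothetical callable type `Order{N}` whose parameter is expected to decrease along a call chain; `Order`, `order_of`, and `order_recursion_relation` are illustrative names only. The four-argument order mirrors the `Core._apply_pure` call in `edge_matches_sv`: the method, the inference-heuristics method, the new signature, and the signature of the frame already on the stack.

```julia
# Hypothetical callable type: N is assumed to count down across recursive calls.
struct Order{N} end

# Extract N from a call signature such as Tuple{Order{3}, Float64}.
# Returns `nothing` whenever the first parameter is not a concrete Order{N}.
function order_of(@nospecialize(sig))
    sig isa DataType || return nothing
    params = sig.parameters
    isempty(params) && return nothing
    T = params[1]
    (T isa DataType && T <: Order) || return nothing
    isempty(T.parameters) && return nothing
    N = T.parameters[1]
    return N isa Int ? N : nothing
end

# Illustrative relation: report the recursion as acceptable (`true`) only when
# the new signature has a strictly smaller N than the signature on the stack.
# Anything it cannot analyze is answered conservatively with `false`.
function order_recursion_relation(@nospecialize(method), @nospecialize(method2),
                                  @nospecialize(new_sig), @nospecialize(parent_sig))
    n_new = order_of(new_sig)
    n_parent = order_of(parent_sig)
    (n_new === nothing || n_parent === nothing) && return false
    return n_new < n_parent
end
```

Installing such a function would mean assigning it to the `recursion_relation` field of the relevant `Method` object; the commit deliberately provides no user-facing API for that step.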

src/method.c

Lines changed: 1 addition & 0 deletions
@@ -662,6 +662,7 @@ JL_DLLEXPORT jl_method_t *jl_new_method_uninit(jl_module_t *module)
     m->nospecialize = module->nospecialize;
     m->nkw = 0;
     m->invokes = NULL;
+    m->recursion_relation = NULL;
     m->isva = 0;
     m->nargs = 0;
     m->primary_world = 1;

stdlib/Serialization/src/Serialization.jl

Lines changed: 13 additions & 1 deletion
@@ -79,7 +79,7 @@ const TAGS = Any[

 @assert length(TAGS) == 255

-const ser_version = 14 # do not make changes without bumping the version #!
+const ser_version = 15 # do not make changes without bumping the version #!

 format_version(::AbstractSerializer) = ser_version
 format_version(s::Serializer) = s.version
@@ -429,6 +429,11 @@ function serialize(s::AbstractSerializer, meth::Method)
     else
         serialize(s, nothing)
     end
+    if isdefined(meth, :recursion_relation)
+        serialize(s, meth.recursion_relation)
+    else
+        serialize(s, nothing)
+    end
     nothing
 end

@@ -1007,6 +1012,10 @@ function deserialize(s::AbstractSerializer, ::Type{Method})
         template = template_or_is_opaque
     end
     generator = deserialize(s)
+    recursion_relation = nothing
+    if format_version(s) >= 15
+        recursion_relation = deserialize(s)
+    end
     if makenew
         meth.module = mod
         meth.name = name
@@ -1033,6 +1042,9 @@ function deserialize(s::AbstractSerializer, ::Type{Method})
             linfo.def = meth
             meth.generator = linfo
         end
+        if recursion_relation !== nothing
+            meth.recursion_relation = recursion_relation
+        end
         if !is_for_opaque_closure
             mt = ccall(:jl_method_table_for, Any, (Any,), sig)
             if mt !== nothing && nothing === ccall(:jl_methtable_lookup, Any, (Any, Any, UInt), mt, sig, typemax(UInt))
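
The Serialization change gates the new trailing field on the format version, so streams written by older versions remain readable. Below is a toy sketch of that pattern using only the public `serialize`/`deserialize` functions from the Serialization stdlib. The record layout and the names `write_toy` and `read_toy` are hypothetical and are not how `Method` objects are actually encoded.

```julia
using Serialization

# Toy writer: records at version 15 or later carry one extra trailing field.
function write_toy(io::IO, version::Integer, generator, relation)
    serialize(io, Int(version))
    serialize(io, generator)
    if version >= 15
        serialize(io, relation)   # field introduced with version 15
    end
    return nothing
end

# Toy reader: only consume the extra field when the stream's version has it,
# and fall back to `nothing` for older streams.
function read_toy(io::IO)
    version = deserialize(io)::Int
    generator = deserialize(io)
    relation = nothing
    if version >= 15
        relation = deserialize(io)
    end
    return (; version, generator, relation)
end
```

Reading a version-14 record through `read_toy` leaves `relation` as `nothing`, which corresponds to the case where the real deserializer skips assigning `meth.recursion_relation`.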
