|
1 | 1 | using Core.Compiler: StmtInfo, ArgInfo, CallMeta, AbsIntState
|
2 | 2 |
|
| 3 | +if VERSION >= v"1.12.0-DEV.1268" |
| 4 | + |
| 5 | +using Core.Compiler: Future |
| 6 | + |
| 7 | +function fwd_abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), |
| 8 | + arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState, primal_call::Future{CallMeta}) |
| 9 | + if f === ChainRulesCore.frule |
| 10 | + # TODO: Currently, we don't have any termination analysis for the non-stratified |
| 11 | + # forward analysis, so bail out here. |
| 12 | + return primal_call |
| 13 | + end |
| 14 | + |
| 15 | + nargs = length(arginfo.argtypes)-1 |
| 16 | + frule_preargtypes = Any[Const(ChainRulesCore.frule), Tuple{Nothing,Vararg{Any,nargs}}] |
| 17 | + frule_argtypes = append!(frule_preargtypes, arginfo.argtypes) |
| 18 | + local frule_atype::Any = CC.argtypes_to_type(frule_argtypes) |
| 19 | + |
| 20 | + local frule_call::Future{CallMeta} |
| 21 | + local result::Future{CallMeta} = Future{CallMeta}() |
| 22 | + function make_progress(_, sv) |
| 23 | + if isa(primal_call[].info, UnionSplitApplyCallInfo) |
| 24 | + result[] = primal_call[] |
| 25 | + return true |
| 26 | + end |
| 27 | + |
| 28 | + ready = false |
| 29 | + if !@isdefined(frule_call) |
| 30 | + # Here we simply check for the frule existance - we don't want to do a full |
| 31 | + # inference with specialized argtypes and everything since the problem is |
| 32 | + # likely sparse and we only need to do a full inference on a few calls. |
| 33 | + # Thus, here we pick `Any` for the tangent types rather than trying to |
| 34 | + # discover what they are. frules should be written in such a way that |
| 35 | + # whether or not they return `nothing`, only depends on the non-tangent arguments |
| 36 | + frule_arginfo = ArgInfo(nothing, frule_argtypes) |
| 37 | + frule_si = StmtInfo(true) |
| 38 | + # turn off frule analysis in the frule to avoid cycling |
| 39 | + interp′ = disable_forward(interp) |
| 40 | + frule_call = CC.abstract_call_gf_by_type(interp′, |
| 41 | + ChainRulesCore.frule, frule_arginfo, frule_si, frule_atype, sv, #=max_methods=#-1)::Future |
| 42 | + isready(frule_call) || return false |
| 43 | + end |
| 44 | + |
| 45 | + frc = frule_call[] |
| 46 | + pc = primal_call[] |
| 47 | + |
| 48 | + if frc.rt !== Const(nothing) |
| 49 | + result[] = CallMeta(pc.rt, pc.exct, pc.effects, FRuleCallInfo(pc.info, frc)) |
| 50 | + else |
| 51 | + result[] = pc |
| 52 | + CC.add_mt_backedge!(sv, frule_mt, frule_atype) |
| 53 | + end |
| 54 | + |
| 55 | + return true |
| 56 | + end |
| 57 | + (!isready(primal_call) || !make_progress(interp, sv)) && push!(sv.tasks, make_progress) |
| 58 | + return result |
| 59 | +end |
| 60 | + |
| 61 | +else |
| 62 | + |
3 | 63 | function fwd_abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
|
4 | 64 | arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState, primal_call::CallMeta)
|
5 | 65 | if f === ChainRulesCore.frule
|
6 | 66 | # TODO: Currently, we don't have any termination analysis for the non-stratified
|
7 | 67 | # forward analysis, so bail out here.
|
8 |
| - return nothing |
| 68 | + return primal_call |
9 | 69 | end
|
10 | 70 |
|
11 | 71 | nargs = length(arginfo.argtypes)-1
|
@@ -35,7 +95,11 @@ function fwd_abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize
|
35 | 95 | CC.add_mt_backedge!(sv, frule_mt, frule_atype)
|
36 | 96 | end
|
37 | 97 |
|
38 |
| - return nothing |
| 98 | + return primal_call |
| 99 | +end |
| 100 | + |
| 101 | + |
| 102 | + |
39 | 103 | end
|
40 | 104 |
|
41 | 105 | const frule_mt = methods(ChainRulesCore.frule).mt
|
0 commit comments