Skip to content

Commit 1cbde03

Browse files
authored
Adjust forward stage2 to Core.Compiler changes (#295)
Only what is necessary for Cedar right now. Ordinary stage 2 reverse mode will need similar changes at a later point.
1 parent 778af00 commit 1cbde03

File tree

5 files changed

+80
-6
lines changed

5 files changed

+80
-6
lines changed

src/analysis/forward.jl

Lines changed: 66 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,71 @@
11
using Core.Compiler: StmtInfo, ArgInfo, CallMeta, AbsIntState
22

3+
if VERSION >= v"1.12.0-DEV.1268"
4+
5+
using Core.Compiler: Future
6+
7+
function fwd_abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
8+
arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState, primal_call::Future{CallMeta})
9+
if f === ChainRulesCore.frule
10+
# TODO: Currently, we don't have any termination analysis for the non-stratified
11+
# forward analysis, so bail out here.
12+
return primal_call
13+
end
14+
15+
nargs = length(arginfo.argtypes)-1
16+
frule_preargtypes = Any[Const(ChainRulesCore.frule), Tuple{Nothing,Vararg{Any,nargs}}]
17+
frule_argtypes = append!(frule_preargtypes, arginfo.argtypes)
18+
local frule_atype::Any = CC.argtypes_to_type(frule_argtypes)
19+
20+
local frule_call::Future{CallMeta}
21+
local result::Future{CallMeta} = Future{CallMeta}()
22+
function make_progress(_, sv)
23+
if isa(primal_call[].info, UnionSplitApplyCallInfo)
24+
result[] = primal_call[]
25+
return true
26+
end
27+
28+
ready = false
29+
if !@isdefined(frule_call)
30+
# Here we simply check for the frule existance - we don't want to do a full
31+
# inference with specialized argtypes and everything since the problem is
32+
# likely sparse and we only need to do a full inference on a few calls.
33+
# Thus, here we pick `Any` for the tangent types rather than trying to
34+
# discover what they are. frules should be written in such a way that
35+
# whether or not they return `nothing`, only depends on the non-tangent arguments
36+
frule_arginfo = ArgInfo(nothing, frule_argtypes)
37+
frule_si = StmtInfo(true)
38+
# turn off frule analysis in the frule to avoid cycling
39+
interp′ = disable_forward(interp)
40+
frule_call = CC.abstract_call_gf_by_type(interp′,
41+
ChainRulesCore.frule, frule_arginfo, frule_si, frule_atype, sv, #=max_methods=#-1)::Future
42+
isready(frule_call) || return false
43+
end
44+
45+
frc = frule_call[]
46+
pc = primal_call[]
47+
48+
if frc.rt !== Const(nothing)
49+
result[] = CallMeta(pc.rt, pc.exct, pc.effects, FRuleCallInfo(pc.info, frc))
50+
else
51+
result[] = pc
52+
CC.add_mt_backedge!(sv, frule_mt, frule_atype)
53+
end
54+
55+
return true
56+
end
57+
(!isready(primal_call) || !make_progress(interp, sv)) && push!(sv.tasks, make_progress)
58+
return result
59+
end
60+
61+
else
62+
363
function fwd_abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
464
arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState, primal_call::CallMeta)
565
if f === ChainRulesCore.frule
666
# TODO: Currently, we don't have any termination analysis for the non-stratified
767
# forward analysis, so bail out here.
8-
return nothing
68+
return primal_call
969
end
1070

1171
nargs = length(arginfo.argtypes)-1
@@ -35,7 +95,11 @@ function fwd_abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize
3595
CC.add_mt_backedge!(sv, frule_mt, frule_atype)
3696
end
3797

38-
return nothing
98+
return primal_call
99+
end
100+
101+
102+
39103
end
40104

41105
const frule_mt = methods(ChainRulesCore.frule).mt

src/stage1/compiler_utils.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,12 @@ if VERSION < v"1.11.0-DEV.258"
1111
Base.getindex(ir::IRCode, ssa::SSAValue) = CC.getindex(ir, ssa)
1212
end
1313

14+
if isdefined(CC, :Future)
15+
Base.isready(future::CC.Future) = CC.isready(future)
16+
Base.getindex(future::CC.Future) = CC.getindex(future)
17+
Base.setindex!(future::CC.Future, value) = CC.setindex!(future, value)
18+
end
19+
1420
Base.copy(ir::IRCode) = CC.copy(ir)
1521

1622
CC.NewInstruction(@nospecialize node) =

src/stage1/recurse_fwd.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,10 @@ function fwd_transform!(ci::CodeInfo, mi::MethodInstance, nargs::Int, N::Int, E)
209209
ci.ssaflags = UInt8[0 for i=1:length(new_code)]
210210
ci.method_for_inference_limit_heuristics = meth
211211
ci.edges = MethodInstance[mi]
212+
if hasfield(CodeInfo, :nargs)
213+
ci.nargs = 2
214+
ci.isva = true
215+
end
212216

213217
if isdefined(Base, :__has_internal_change) && Base.__has_internal_change(v"1.12-alpha", :codeinfonargs)
214218
ci.nargs = 2

src/stage2/abstractinterpret.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,7 @@ function CC.abstract_call_gf_by_type(interp::ADInterpreter, @nospecialize(f),
7474
arginfo::ArgInfo, si::StmtInfo, atype::Any, sv::InferenceState, max_methods::Int)
7575

7676
if interp.forward
77-
r = fwd_abstract_call_gf_by_type(interp, f, arginfo, si, sv, ret)
78-
if r !== nothing
79-
return r
80-
end
77+
return fwd_abstract_call_gf_by_type(interp, f, arginfo, si, sv, ret)
8178
end
8279

8380
return ret

src/stage2/lattice.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,9 @@ end
7171
CC.nsplit_impl(info::FRuleCallInfo) = CC.nsplit(info.info)
7272
CC.getsplit_impl(info::FRuleCallInfo, idx::Int) = CC.getsplit(info.info, idx)
7373
CC.getresult_impl(info::FRuleCallInfo, idx::Int) = CC.getresult(info.info, idx)
74+
if isdefined(CC, :add_uncovered_edges_impl)
75+
CC.add_uncovered_edges_impl(edges::Vector{Any}, info::FRuleCallInfo, @nospecialize(atype)) = CC.add_uncovered_edges!(edges, info.info, atype)
76+
end
7477

7578
function Base.show(io::IO, info::FRuleCallInfo)
7679
print(io, "FRuleCallInfo(", typeof(info.info), ", ", typeof(info.frule_call.info), ")")

0 commit comments

Comments
 (0)