Skip to content

Commit b293bbd

Browse files
authored
Hookup demand-driven forward mode to the Diffractor runtime (#99)
* Hookup demand-driven forward mode to the Diffractor runtime Tests currently depend on JuliaLang/julia#48045 and JuliaLang/julia#48059, so we should either get those merged first, or mark them here as broken. * Mark test as broken
1 parent b53d71b commit b293bbd

File tree

7 files changed

+244
-54
lines changed

7 files changed

+244
-54
lines changed

src/codegen/forward_demand.jl

Lines changed: 119 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
using Core.Compiler: IRInterpretationState, construct_postdomtree, PiNode,
22
is_known_call, argextype, postdominates
33

4-
function forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, pantelides::Vector{SSAValue}; custom_diff! = (args...)->nothing, diff_cache=Dict{SSAValue, SSAValue}())
4+
#=
5+
function forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, to_diff::Vector{Pair{SSAValue, Int}}; custom_diff! = (args...)->nothing, diff_cache=Dict{SSAValue, SSAValue}())
56
Δs = SSAValue[]
67
rets = findall(@nospecialize(x)->isa(x, ReturnNode) && isdefined(x, :val), ir.stmts.inst)
78
postdomtree = construct_postdomtree(ir.cfg.blocks)
8-
for ssa in pantelides
9-
Δssa = forward_diff!(ir, interp, irsv, ssa; custom_diff!, diff_cache)
9+
for (ssa, order) in to_diff
10+
Δssa = forward_diff!(ir, interp, irsv, ssa, order; custom_diff!, diff_cache)
1011
Δblock = block_for_inst(ir, Δssa.id)
1112
for idx in rets
1213
retblock = block_for_inst(ir, idx)
@@ -18,31 +19,24 @@ function forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, pantelid
1819
end
1920
return (ir, Δs)
2021
end
22+
=#
2123

22-
function diff_unassigned_variable!(ir, ssa)
23-
return insert_node!(ir, ssa, NewInstruction(
24-
Expr(:call, GlobalRef(Intrinsics, :state_ddt), ssa), Float64), #=attach_after=#true)
25-
end
26-
27-
function forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, ssa::SSAValue; custom_diff!, diff_cache)
24+
function forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, ssa::SSAValue, order::Int; custom_diff!, diff_cache)
2825
if haskey(diff_cache, ssa)
2926
return diff_cache[ssa]
3027
end
3128
inst = ir[ssa]
3229
stmt = inst[:inst]
33-
if isa(stmt, SSAValue)
34-
return forward_diff!(ir, interp, irsv, stmt; custom_diff!, diff_cache)
35-
end
36-
Δssa = forward_diff_uncached!(ir, interp, irsv, ssa, inst; custom_diff!, diff_cache)
30+
Δssa = forward_diff_uncached!(ir, interp, irsv, ssa, inst, order::Int; custom_diff!, diff_cache)
3731
@assert Δssa !== nothing
3832
if isa(Δssa, SSAValue)
3933
diff_cache[ssa] = Δssa
4034
end
4135
return Δssa
4236
end
43-
forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, val::Union{Integer, AbstractFloat}; custom_diff!, diff_cache) = zero(val)
44-
forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, @nospecialize(arg); custom_diff!, diff_cache) = ChainRulesCore.NoTangent()
45-
function forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, arg::Argument; custom_diff!, diff_cache)
37+
forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, val::Union{Integer, AbstractFloat}, order::Int; custom_diff!, diff_cache) = zero(val)
38+
forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, @nospecialize(arg), order::Int; custom_diff!, diff_cache) = ChainRulesCore.NoTangent()
39+
function forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, arg::Argument, order::Int; custom_diff!, diff_cache)
4640
recurse(x) = forward_diff!(ir, interp, irsv, x; custom_diff!, diff_cache)
4741
val = custom_diff!(ir, SSAValue(0), arg, recurse)
4842
if val !== nothing
@@ -51,13 +45,15 @@ function forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, arg::Arg
5145
return ChainRulesCore.NoTangent()
5246
end
5347

54-
function forward_diff_uncached!(ir::IRCode, interp, irsv::IRInterpretationState, ssa::SSAValue, inst::Core.Compiler.Instruction; custom_diff!, diff_cache)
48+
function forward_diff_uncached!(ir::IRCode, interp, irsv::IRInterpretationState, ssa::SSAValue, inst::Core.Compiler.Instruction, order::Int; custom_diff!, diff_cache)
5549
stmt = inst[:inst]
56-
recurse(x) = forward_diff!(ir, interp, irsv, x; custom_diff!, diff_cache)
50+
recurse(x) = forward_diff!(ir, interp, irsv, x, order; custom_diff!, diff_cache)
5751
if (val = custom_diff!(ir, ssa, stmt, recurse)) !== nothing
5852
return val
5953
elseif isa(stmt, PiNode)
6054
return recurse(stmt.val)
55+
elseif isa(stmt, SSAValue)
56+
return recurse(stmt)
6157
elseif isa(stmt, PhiNode)
6258
Δphi = PhiNode(copy(stmt.edges), similar(stmt.values))
6359
T = Union{}
@@ -152,3 +148,108 @@ function forward_diff_uncached!(ir::IRCode, interp, irsv::IRInterpretationState,
152148
return Δssa
153149
end
154150
end
151+
152+
function forward_visit!(ir::IRCode, ssa::SSAValue, order::Int, ssa_orders::Vector{Pair{Int, Bool}}, visit_custom!)
153+
if ssa_orders[ssa.id][1] >= order
154+
return
155+
end
156+
ssa_orders[ssa.id] = order => ssa_orders[ssa.id][2]
157+
inst = ir[ssa]
158+
stmt = inst[:inst]
159+
recurse(@nospecialize(val)) = forward_visit!(ir, val, order, ssa_orders, visit_custom!)
160+
if visit_custom!(ir, stmt, order, recurse)
161+
ssa_orders[ssa.id] = order => true
162+
return
163+
elseif isa(stmt, PiNode)
164+
return recurse(stmt.val)
165+
elseif isa(stmt, PhiNode)
166+
for i = 1:length(stmt.values)
167+
isassigned(stmt.values, i) || continue
168+
recurse(stmt.values[i])
169+
end
170+
return
171+
elseif isexpr(stmt, :new) || isexpr(stmt, :invoke)
172+
foreach(recurse, stmt.args[2:end])
173+
elseif isexpr(stmt, :call)
174+
foreach(recurse, stmt.args)
175+
elseif isa(stmt, SSAValue)
176+
recurse(stmt)
177+
elseif !isa(stmt, Expr)
178+
return
179+
else
180+
@show stmt
181+
error()
182+
end
183+
end
184+
forward_visit!(ir::IRCode, _, order::Int, ssa_orders::Vector{Pair{Int, Bool}}, visit_custom!) = nothing
185+
function forward_visit!(ir::IRCode, a::Argument, order::Int, ssa_orders::Vector{Pair{Int, Bool}}, visit_custom!)
186+
recurse(@nospecialize(val)) = forward_visit!(ir, val, order, ssa_orders, visit_custom!)
187+
return visit_custom!(ir, a, order, recurse)
188+
end
189+
190+
191+
function forward_diff_no_inf!(ir::IRCode, interp, mi::MethodInstance, world, to_diff::Vector{Pair{SSAValue, Int}};
192+
visit_custom! = (args...)->false, transform! = (args...)->error())
193+
# Step 1: For each SSAValue in the IR, keep track of the differentiation order needed
194+
ssa_orders = [0=>false for i = 1:length(ir.stmts)]
195+
for (ssa, order) in to_diff
196+
forward_visit!(ir, ssa, order, ssa_orders, visit_custom!)
197+
end
198+
199+
# Step 2: Transform
200+
function maparg(arg, ssa, order)
201+
if isa(arg, Argument)
202+
# TODO: Should we remember whether the callbacks wanted the arg?
203+
return transform!(ir, arg, order)
204+
elseif isa(arg, SSAValue)
205+
# TODO: Bundle truncation if necessary
206+
return arg
207+
end
208+
@assert !isa(arg, Expr)
209+
return insert_node!(ir, ssa, NewInstruction(Expr(:call, ZeroBundle{order}, arg), Any))
210+
end
211+
212+
for (ssa, (order, custom)) in enumerate(ssa_orders)
213+
if order == 0
214+
# TODO: Bundle truncation?
215+
continue
216+
end
217+
if custom
218+
transform!(ir, SSAValue(ssa), order)
219+
else
220+
inst = ir[SSAValue(ssa)]
221+
stmt = inst[:inst]
222+
if isexpr(stmt, :invoke)
223+
inst[:inst] = Expr(:call, ∂☆{order}(), map(arg->maparg(arg, SSAValue(ssa), order), stmt.args[2:end])...)
224+
inst[:type] = Any
225+
elseif !isa(stmt, Expr)
226+
inst[:inst] = maparg(stmt, ssa, order)
227+
inst[:type] = Any
228+
else
229+
@show stmt
230+
error()
231+
end
232+
end
233+
end
234+
235+
end
236+
237+
function forward_diff!(ir::IRCode, interp, mi::MethodInstance, world, to_diff::Vector{Pair{SSAValue, Int}}; kwargs...)
238+
forward_diff_no_inf!(ir, interp, mi, world, to_diff; kwargs...)
239+
240+
# Step 3: Re-inference
241+
ir = compact!(ir)
242+
243+
extra_reprocess = CC.BitSet()
244+
for i = 1:length(ir.stmts)
245+
if ir[SSAValue(i)][:type] == Any
246+
CC.push!(extra_reprocess, i)
247+
end
248+
end
249+
250+
interp′ = enable_reinference(interp)
251+
irsv = IRInterpretationState(interp′, ir, mi, world, ir.argtypes[1:mi.def.nargs])
252+
rt = CC._ir_abstract_constant_propagation(interp′, irsv; extra_reprocess)
253+
254+
return ir
255+
end

src/higher_fwd_rules.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,23 @@ for f in (sin, cos, exp)
2828
end
2929
end
3030

31+
# TODO: It's a bit embarassing that we need to write these out, but currently the
32+
# compiler is not strong enough to automatically lift the frule. Let's hope we
33+
# can delete these in the near future.
34+
function (∂☆ₙ::∂☆{N})(fb::ZeroBundle{N, typeof(+)}, a::TaylorBundle{N}, b::TaylorBundle{N}) where {N}
35+
TaylorBundle{N}(primal(a) + primal(b),
36+
map(+, a.tangent.coeffs, b.tangent.coeffs))
37+
end
38+
39+
function (∂☆ₙ::∂☆{N})(fb::ZeroBundle{N, typeof(+)}, a::TaylorBundle{N}, b::ZeroBundle{N}) where {N}
40+
TaylorBundle{N}(primal(a) + primal(b), a.tangent.coeffs)
41+
end
42+
43+
function (∂☆ₙ::∂☆{N})(fb::ZeroBundle{N, typeof(-)}, a::TaylorBundle{N}, b::TaylorBundle{N}) where {N}
44+
TaylorBundle{N}(primal(a) - primal(b),
45+
map(-, a.tangent.coeffs, b.tangent.coeffs))
46+
end
47+
3148
function (::Diffractor.∂☆new{N})(B::ATB{N, Type{T}}, args::ATB{N}...) where {N, T<:SArray}
3249
error("Should have intercepted the constructor")
3350
end

src/stage1/forward.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -205,13 +205,15 @@ struct FwdIterate{N, T<:AbstractTangentBundle{N}}
205205
end
206206
function (f::FwdIterate)(arg::ATB{N}) where {N}
207207
r = ∂☆{N}()(f.f, arg)
208-
primal(r) === nothing && return nothing
208+
# `primal(r) === nothing` would work, but doesn't create `Conditional` in inference
209+
isa(r, ATB{N, Nothing}) && return nothing
209210
(∂☆{N}()(ZeroBundle{N}(getindex), r, ZeroBundle{N}(1)),
210211
primal(∂☆{N}()(ZeroBundle{N}(getindex), r, ZeroBundle{N}(2))))
211212
end
212-
function (f::FwdIterate)(arg::ATB{N}, st) where {N}
213+
@Base.constprop :aggressive function (f::FwdIterate)(arg::ATB{N}, st) where {N}
213214
r = ∂☆{N}()(f.f, arg, ZeroBundle{N}(st))
214-
primal(r) === nothing && return nothing
215+
# `primal(r) === nothing` would work, but doesn't create `Conditional` in inference
216+
isa(r, ATB{N, Nothing}) && return nothing
215217
(∂☆{N}()(ZeroBundle{N}(getindex), r, ZeroBundle{N}(1)),
216218
primal(∂☆{N}()(ZeroBundle{N}(getindex), r, ZeroBundle{N}(2))))
217219
end

src/stage2/forward.jl

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,34 +2,50 @@ using .CC: compact!
22

33
# Engineering entry point for the 2nd-order forward AD functionality. This is
44
# unlikely to be the actual interface. For now, it is used for testing.
5-
function dontuse_nth_order_forward_stage2(tt::Type)
5+
function dontuse_nth_order_forward_stage2(tt::Type, order::Int=1)
66
interp = ADInterpreter(; forward=true, backward=false)
77
match = Base._which(tt)
88
frame = Core.Compiler.typeinf_frame(interp, match.method, match.spec_types, match.sparams, #=run_optimizer=#true)
99

1010
ir = copy((interp.opt[0][frame.linfo].inferred).ir::IRCode)
1111

1212
# Find all Return Nodes
13-
vals = SSAValue[]
13+
vals = Pair{SSAValue, Int}[]
1414
for i = 1:length(ir.stmts)
1515
if isa(ir[SSAValue(i)][:inst], ReturnNode)
16-
push!(vals, SSAValue(i))
16+
push!(vals, SSAValue(i)=>order)
1717
end
1818
end
1919

20-
function custom_diff!(ir, ssa, stmt, recurse)
20+
function visit_custom!(ir::IRCode, @nospecialize(stmt), order, recurse)
2121
if isa(stmt, ReturnNode)
22-
r = recurse(stmt.val)
23-
ir[ssa][:inst] = ReturnNode(r)
24-
return ssa
22+
recurse(stmt.val)
23+
return true
2524
elseif isa(stmt, Argument)
26-
return 1.0
25+
return true
26+
else
27+
return false
2728
end
28-
return nothing
2929
end
3030

31+
function transform!(ir::IRCode, ssa::SSAValue, _)
32+
inst = ir[ssa]
33+
stmt = inst[:inst]
34+
if isa(stmt, ReturnNode)
35+
nr = insert_node!(ir, ssa, NewInstruction(Expr(:call, getindex, stmt.val, TaylorTangentIndex(order)), Any))
36+
inst[:inst] = ReturnNode(nr)
37+
else
38+
error()
39+
end
40+
end
41+
42+
function transform!(ir::IRCode, arg::Argument, _)
43+
return insert_node!(ir, SSAValue(1), NewInstruction(Expr(:call, ∂xⁿ{order}(), arg), typeof(∂xⁿ{order}()(1.0))))
44+
end
45+
46+
3147
irsv = CC.IRInterpretationState(interp, ir, frame.linfo, CC.get_world_counter(interp), ir.argtypes[1:frame.linfo.def.nargs])
32-
forward_diff!(ir, interp, irsv, vals; custom_diff!)
48+
ir = forward_diff!(ir, interp, frame.linfo, CC.get_world_counter(interp), vals; visit_custom!, transform!)
3349

3450
ir = compact!(ir)
3551
return OpaqueClosure(ir)

0 commit comments

Comments
 (0)