Skip to content

Commit 7356cba

Browse files
authored
forward_demand: Implement bundle truncation (#101)
Also make the canonical 1-bundle use the TaylorBundle type. The types are isomorphic, but since `∂xⁿ{1}()` returns a TaylorBundle, this helps type stability.
1 parent b293bbd commit 7356cba

File tree

4 files changed

+52
-19
lines changed

4 files changed

+52
-19
lines changed

src/codegen/forward_demand.jl

Lines changed: 43 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using Core.Compiler: IRInterpretationState, construct_postdomtree, PiNode,
2-
is_known_call, argextype, postdominates
2+
is_known_call, argextype, postdominates, userefs
33

44
#=
55
function forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, to_diff::Vector{Pair{SSAValue, Int}}; custom_diff! = (args...)->nothing, diff_cache=Dict{SSAValue, SSAValue}())
@@ -93,12 +93,6 @@ function forward_diff_uncached!(ir::IRCode, interp, irsv::IRInterpretationState,
9393
return Δtangent
9494
else # general frule handling
9595
info = inst[:info]
96-
if !isa(info, FRuleCallInfo)
97-
@show info
98-
@show inst[:inst]
99-
display(ir)
100-
error()
101-
end
10296
if isexpr(stmt, :invoke)
10397
args = stmt.args[2:end]
10498
else
@@ -196,22 +190,50 @@ function forward_diff_no_inf!(ir::IRCode, interp, mi::MethodInstance, world, to_
196190
forward_visit!(ir, ssa, order, ssa_orders, visit_custom!)
197191
end
198192

193+
truncation_map = Dict{Pair{SSAValue, Int}, SSAValue}()
194+
199195
# Step 2: Transform
200196
function maparg(arg, ssa, order)
201-
if isa(arg, Argument)
197+
if isa(arg, SSAValue)
198+
if arg.id > length(ssa_orders)
199+
# This is possible if the custom transform touched another statement.
200+
# In that case just pass this through and assume the `transform!` did
201+
# it correctly.
202+
return arg
203+
end
204+
(argorder, _) = ssa_orders[arg.id]
205+
if argorder != order
206+
@assert order < argorder
207+
return get!(truncation_map, arg=>order) do
208+
# TODO: Other orders
209+
@assert order == 0
210+
insert_node!(ir, arg, NewInstruction(Expr(:call, primal, arg), Any), #=attach_after=#true)
211+
end
212+
end
213+
return arg
214+
elseif order == 0
215+
return arg
216+
elseif isa(arg, Argument)
202217
# TODO: Should we remember whether the callbacks wanted the arg?
203218
return transform!(ir, arg, order)
204-
elseif isa(arg, SSAValue)
205-
# TODO: Bundle truncation if necessary
206-
return arg
219+
elseif isa(arg, GlobalRef)
220+
return insert_node!(ir, ssa, NewInstruction(Expr(:call, ZeroBundle{order}, arg), Any))
221+
elseif isa(arg, QuoteNode)
222+
return ZeroBundle{order}(arg.value)
207223
end
208224
@assert !isa(arg, Expr)
209-
return insert_node!(ir, ssa, NewInstruction(Expr(:call, ZeroBundle{order}, arg), Any))
225+
return ZeroBundle{order}(arg)
210226
end
211227

212228
for (ssa, (order, custom)) in enumerate(ssa_orders)
213229
if order == 0
214-
# TODO: Bundle truncation?
230+
inst = ir[SSAValue(ssa)]
231+
stmt = inst[:inst]
232+
urs = userefs(stmt)
233+
for ur in urs
234+
ur[] = maparg(ur[], SSAValue(ssa), order)
235+
end
236+
inst[:inst] = urs[]
215237
continue
216238
end
217239
if custom
@@ -222,12 +244,16 @@ function forward_diff_no_inf!(ir::IRCode, interp, mi::MethodInstance, world, to_
222244
if isexpr(stmt, :invoke)
223245
inst[:inst] = Expr(:call, ∂☆{order}(), map(arg->maparg(arg, SSAValue(ssa), order), stmt.args[2:end])...)
224246
inst[:type] = Any
225-
elseif !isa(stmt, Expr)
226-
inst[:inst] = maparg(stmt, ssa, order)
247+
elseif isexpr(stmt, :call)
248+
inst[:inst] = Expr(:call, ∂☆{order}(), map(arg->maparg(arg, SSAValue(ssa), order), stmt.args)...)
227249
inst[:type] = Any
228250
else
229-
@show stmt
230-
error()
251+
urs = userefs(stmt)
252+
for ur in urs
253+
ur[] = maparg(ur[], SSAValue(ssa), order)
254+
end
255+
inst[:inst] = urs[]
256+
inst[:type] = Any
231257
end
232258
end
233259
end

src/interface.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ raise_pd(f::PrimeDerivativeFwd{N,T}) where {N,T} = PrimeDerivativeFwd{plus1(N),T
177177

178178
function (f::PrimeDerivativeFwd{1})(x)
179179
z = ∂☆¹(ZeroBundle{1}(getfield(f, :f)), ∂x(x))
180-
z.tangent.partials[1]
180+
z[TaylorTangentIndex(1)]
181181
end
182182

183183
function (f::PrimeDerivativeFwd{N})(x) where N

src/stage1/forward.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ end
106106
struct ∂☆internal{N}; end
107107
struct ∂☆shuffle{N}; end
108108

109-
shuffle_base(r) = ExplicitTangentBundle{1}(r[1], (r[2],))
109+
shuffle_base(r) = TaylorBundle{1}(r[1], (r[2],))
110110

111111
function (::∂☆internal{1})(args::AbstractTangentBundle{1}...)
112112
r = my_frule(args...)

src/tangent.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,13 @@ function TaylorBundle{N}(primal, coeffs) where {N}
211211
TaylorBundle{N, Core.Typeof(primal)}(primal, coeffs)
212212
end
213213

214+
function Base.show(io::IO, x::TaylorBundle{1})
215+
print(io, x.primal)
216+
print(io, " + ")
217+
x = x.tangent
218+
print(io, x.coeffs[1], " ∂₁")
219+
end
220+
214221
Base.getindex(tb::TaylorBundle, tti::TaylorTangentIndex) = tb.tangent.coeffs[tti.i]
215222
function Base.getindex(tb::TaylorBundle, tti::CanonicalTangentIndex)
216223
tb.tangent.coeffs[count_ones(tti.i)]

0 commit comments

Comments
 (0)