Skip to content

Commit 72c1bef

Browse files
committed
WIP: begin switching forward mode over to zero_bundles for mutation support
1 parent cbcc0f3 commit 72c1bef

File tree

7 files changed

+54
-7
lines changed

7 files changed

+54
-7
lines changed

src/codegen/forward.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ function fwd_transform!(ci, mi, nargs, N)
6969
if isa(stmt, Expr)
7070
error("Unexprected statement encountered. This is a bug in Diffractor. stmt=$stmt")
7171
end
72-
return Expr(:call, ZeroBundle{N}, stmt)
72+
return Expr(:call, zero_bundle{order}(), stmt)
7373
end
7474
end
7575

src/codegen/forward_demand.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -264,12 +264,12 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue,Int}};
264264
return transform!(ir, arg, order, maparg)
265265
elseif isa(arg, GlobalRef)
266266
@assert isconst(arg)
267-
return ZeroBundle{order}(getfield(arg.mod, arg.name))
267+
return zero_bundle{order}()(getfield(arg.mod, arg.name))
268268
elseif isa(arg, QuoteNode)
269-
return ZeroBundle{order}(arg.value)
269+
return zero_bundle{order}(){order}(arg.value)
270270
end
271271
@assert !isa(arg, Expr)
272-
return ZeroBundle{order}(arg)
272+
return zero_bundle{order}()(arg)
273273
end
274274

275275
for (ssa, (order, custom)) in enumerate(ssa_orders)
@@ -309,7 +309,7 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue,Int}};
309309
stmt = insert_node!(ir, ssa, NewInstruction(inst))
310310
end
311311

312-
replace_call!(ir, SSAValue(ssa), Expr(:call, ZeroBundle{order}, stmt))
312+
replace_call!(ir, SSAValue(ssa), Expr(:call, zero_bundle{order}(), stmt))
313313
elseif isa(stmt, SSAValue) || isa(stmt, QuoteNode)
314314
inst[:inst] = maparg(stmt, SSAValue(ssa), order)
315315
inst[:type] = Any
@@ -329,7 +329,7 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue,Int}};
329329
inst[:type] = Any
330330
inst[:flag] |= CC.IR_FLAG_REFINED
331331
else
332-
val = ZeroBundle{order}(inst[:inst])
332+
val = zero_bundle{order}()(inst[:inst])
333333
inst[:inst] = val
334334
inst[:type] = Const(val)
335335
end

src/stage1/recurse_fwd.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ _construct(::Type{B}, args) where B<:Tuple = B(args)
5050
# Hack for making things that do not have public constructors constructable:
5151
@generated _construct(B::Type, args) = Expr(:splatnew, :B, :args)
5252

53-
@generated (::∂☆new{N})(B::Type) where {N} = return :(ZeroBundle{$N}($(Expr(:new, :B))))
53+
@generated (::∂☆new{N})(B::Type) where {N} = return :(zero_bundle{order}()($(Expr(:new, :B))))
5454

5555
# Sometimes we don't know whether or not we need to the ZeroBundle when doing
5656
# the transform, so this can happen - allow it for now.

src/tangent.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,3 +418,18 @@ end
418418
function ChainRulesCore.rrule(::typeof(rebundle), atb)
419419
rebundle(atb), Δ->throw(Δ)
420420
end
421+
422+
423+
"""
424+
(::zero_bundle{N})(primal)
425+
426+
Creates a bundle with a zero tangent.
427+
"""
428+
struct zero_bundle{N} end
429+
function (::zero_bundle{N})(primal) where N
430+
if zero_tangent(primal) isa ZeroTangent
431+
return ZeroBundle{N}(primal)
432+
else
433+
return TaylorBundle{N}(primal, ntuple(_->zero_tangent(primal), N))
434+
end
435+
end

test/forward_mutation.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
using Diffractor
2+
using Diffractor: ∂☆
3+
using Diffractor: bundle
4+
5+
mutable struct MDemo1
6+
x::Float64
7+
end
8+
function double!(val::MDemo1)
9+
val.x *= 2.0
10+
return val
11+
end
12+
function wrap_and_double(x)
13+
val = MDemo1(x)
14+
double!(val)
15+
end
16+
∂☆{1}()(ZeroBundle{1}(wrap_and_double), TaylorBundle{1}(1.5, (1.0,)))

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ const bwd = Diffractor.PrimeDerivativeBack
1818
"tangent.jl",
1919
"forward_diff_no_inf.jl",
2020
"forward.jl",
21+
"forward_mutation.jl",
2122
"reverse.jl",
2223
"regression.jl",
2324
"AbstractDifferentiationTests.jl"

test/tangent.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,4 +122,19 @@ end
122122
@test truncate(et, Val(1)) == TaylorTangent((1.0,))
123123
end
124124

125+
@testset "zero_bundle" begin
126+
zero_bundle = Diffractor.zero_bundle
127+
128+
tup_zb = zero_bundle{2}()((1, 0))
129+
@test tup_zb isa ZeroBundle{2}
130+
@test iszero(tup_zb[TaylorTangentIndex(1)])
131+
@test iszero(tup_zb[TaylorTangentIndex(2)])
132+
133+
134+
ref_zb = zero_bundle{2}()(Ref(1.5))
135+
@test ref_zb isa TaylorBundle{2}
136+
@test iszero(ref_zb[TaylorTangentIndex(1)])
137+
@test iszero(ref_zb[TaylorTangentIndex(2)])
138+
end
139+
125140
end # module

0 commit comments

Comments
 (0)