Skip to content

Commit 33c3615

Browse files
committed
Handle first order constructors
1 parent 72c1bef commit 33c3615

File tree

5 files changed

+11
-6
lines changed

5 files changed

+11
-6
lines changed

src/codegen/forward.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ function fwd_transform!(ci, mi, nargs, N)
6969
if isa(stmt, Expr)
7070
error("Unexprected statement encountered. This is a bug in Diffractor. stmt=$stmt")
7171
end
72-
return Expr(:call, zero_bundle{order}(), stmt)
72+
return Expr(:call, zero_bundle{N}(), stmt)
7373
end
7474
end
7575

src/codegen/forward_demand.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ function forward_diff_uncached!(ir::IRCode, interp::AbstractInterpreter, irsv::I
105105
Δbacking = insert_node!(ir, ssa, NewInstruction(Expr(:splatnew, widenconst(tup_typ), Δbacking), tup_typ_typ.val))
106106
end
107107
tangentT = Core.Compiler.apply_type_tfunc(Const(ChainRulesCore.Tangent), newT, tup_typ_typ).val
108+
# TODO do we need to make sure this inserts right
108109
Δtangent = insert_node!(ir, ssa, NewInstruction(Expr(:new, tangentT, Δbacking), tangentT))
109110
return Δtangent
110111
else # general frule handling

src/stage1/generated.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,7 @@ end
390390
lifted_getfield(x::ZeroTangent, s) = ZeroTangent()
391391
lifted_getfield(x::NoTangent, s) = NoTangent()
392392

393-
lifted_getfield(x::Tangent, s) = getproperty(x, s)
393+
lifted_getfield(x::StructuralTangent, s) = getproperty(x, s)
394394

395395
function lifted_getfield(x::Tangent{<:Tangent{T}}, s) where T
396396
bb = getfield(x.backing, 1)

src/stage1/recurse_fwd.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,28 +15,29 @@ struct ∂☆new{N}; end
1515
function (::∂☆new{1})(B::Type, xs::AbstractTangentBundle{1}...)
1616
primal_args = map(primal, xs)
1717
the_primal = _construct(B, primal_args)
18-
18+
@info "∂☆new{1}"
1919
tangent_tup = map(first_partial, xs)
2020
the_partial = if B<:Tuple
2121
Tangent{B, typeof(tangent_tup)}(tangent_tup)
2222
else
2323
names = fieldnames(B)
2424
tangent_nt = NamedTuple{names}(tangent_tup)
25-
Tangent{B, typeof(tangent_nt)}(tangent_nt)
25+
StructuralTangent{B}(tangent_nt)
2626
end
2727
return TaylorBundle{1, B}(the_primal, (the_partial,))
2828
end
2929

3030
function (::∂☆new{N})(B::Type, xs::AbstractTangentBundle{N}...) where {N}
3131
primal_args = map(primal, xs)
3232
the_primal = _construct(B, primal_args)
33-
33+
@info "∂☆new{N}"
3434
the_partials = ntuple(Val{N}()) do ii
3535
iith_order_type = ii==1 ? B : Any # the type of the higher order tangents isn't worth tracking
3636
tangent_tup = map(x->partial(x, ii), xs)
3737
tangent = if B<:Tuple
3838
Tangent{iith_order_type, typeof(tangent_tup)}(tangent_tup)
3939
else
40+
# TODO support mutation
4041
names = fieldnames(B)
4142
tangent_nt = NamedTuple{names}(tangent_tup)
4243
Tangent{iith_order_type, typeof(tangent_nt)}(tangent_nt)

test/forward_mutation.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
using Diffractor
2-
using Diffractor: ∂☆
2+
using Diffractor: ∂☆, ZeroBundle, TaylorBundle
33
using Diffractor: bundle
44

55
mutable struct MDemo1
66
x::Float64
77
end
8+
9+
∂☆{1}()(ZeroBundle{1}(MDemo1), TaylorBundle{1}(1.5, (1.0,)))
10+
811
function double!(val::MDemo1)
912
val.x *= 2.0
1013
return val

0 commit comments

Comments
 (0)