Skip to content

Commit c444a7a

Browse files
committed
make 2nd order mutation not error
1 parent 33c3615 commit c444a7a

File tree

3 files changed

+54
-7
lines changed

3 files changed

+54
-7
lines changed

β€Žsrc/extra_rules.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,3 +262,18 @@ Base.real(z::NoTangent) = z # TODO should be in CRC, https://github.com/JuliaDi
262262

263263
# Avoid https://github.com/JuliaDiff/ChainRulesCore.jl/pull/495
264264
ChainRulesCore._backing_error(P::Type{<:Base.Pairs}, G::Type{<:NamedTuple}, E::Type{<:AbstractDict}) = nothing
265+
266+
# Needed for higher order so we don't see the `backing` field of StructuralTangents, just the contents
267+
# SHould these be in ChainRules/ChainRulesCore?
268+
# is this always the right behavour, or just because of how we do higher order
269+
function ChainRulesCore.frule((_, Ξ”, _, _), ::typeof(getproperty), strct::StructuralTangent, sym::Union{Int,Symbol}, inbounds)
270+
return (getproperty(strct, sym, inbounds), getproperty(Ξ”, sym))
271+
end
272+
273+
274+
function ChainRulesCore.frule((_, ȯbj, _, ẋ), ::typeof(setproperty!), obj::MutableTangent, field, x)
275+
ȯbj::MutableTangent
276+
y = setproperty!(obj, field, x)
277+
ẏ = setproperty!(ȯbj, field, ẋ)
278+
return y, ẏ
279+
end

β€Žsrc/stage1/recurse_fwd.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,15 @@ function (::βˆ‚β˜†new{N})(B::Type, xs::AbstractTangentBundle{N}...) where {N}
3232
the_primal = _construct(B, primal_args)
3333
@info "βˆ‚β˜†new{N}"
3434
the_partials = ntuple(Val{N}()) do ii
35-
iith_order_type = ii==1 ? B : Any # the type of the higher order tangents isn't worth tracking
3635
tangent_tup = map(x->partial(x, ii), xs)
3736
tangent = if B<:Tuple
38-
Tangent{iith_order_type, typeof(tangent_tup)}(tangent_tup)
37+
Tangent{B, typeof(tangent_tup)}(tangent_tup)
3938
else
40-
# TODO support mutation
39+
# It is a little dubious using StructuralTangent{B} for >1st order, but it is isomorphic.
40+
# Just watch out for order mixing bugs.
4141
names = fieldnames(B)
4242
tangent_nt = NamedTuple{names}(tangent_tup)
43-
Tangent{iith_order_type, typeof(tangent_nt)}(tangent_nt)
43+
StructuralTangent{B}(tangent_nt)
4444
end
4545
return tangent
4646
end

β€Žtest/forward_mutation.jl

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,26 @@
1+
# module forward_mutation
12
using Diffractor
23
using Diffractor: βˆ‚β˜†, ZeroBundle, TaylorBundle
3-
using Diffractor: bundle
4+
using Diffractor: bundle, first_partial, TaylorTangentIndex
5+
using ChainRulesCore
6+
using Test
7+
48

59
mutable struct MDemo1
610
x::Float64
711
end
812

9-
βˆ‚β˜†{1}()(ZeroBundle{1}(MDemo1), TaylorBundle{1}(1.5, (1.0,)))
13+
@testset "construction" begin
14+
🍞 = βˆ‚β˜†{1}()(ZeroBundle{1}(MDemo1), TaylorBundle{1}(1.5, (1.0,)))
15+
@test 🍞[TaylorTangentIndex(1)] isa MutableTangent{MDemo1}
16+
@test 🍞[TaylorTangentIndex(1)].x == 1.0
17+
18+
πŸ₯― = βˆ‚β˜†{2}()(ZeroBundle{2}(MDemo1), TaylorBundle{2}(1.5, (1.0, 10.0)))
19+
@test πŸ₯―[TaylorTangentIndex(1)] isa MutableTangent{MDemo1}
20+
@test πŸ₯―[TaylorTangentIndex(1)].x == 1.0
21+
@test πŸ₯―[TaylorTangentIndex(2)] isa MutableTangent
22+
@test πŸ₯―[TaylorTangentIndex(2)].x == 10.0
23+
end
1024

1125
function double!(val::MDemo1)
1226
val.x *= 2.0
@@ -16,4 +30,22 @@ function wrap_and_double(x)
1630
val = MDemo1(x)
1731
double!(val)
1832
end
19-
βˆ‚β˜†{1}()(ZeroBundle{1}(wrap_and_double), TaylorBundle{1}(1.5, (1.0,)))
33+
🐰 = βˆ‚β˜†{1}()(ZeroBundle{1}(wrap_and_double), TaylorBundle{1}(1.5, (1.0,)))
34+
@test first_partial(🐰) isa MutableTangent{MDemo1}
35+
@test first_partial(🐰).x == 2.0
36+
37+
# second derivative
38+
πŸ‡ = βˆ‚β˜†{2}()(ZeroBundle{2}(wrap_and_double), TaylorBundle{2}(1.5, (1.0, 10.0)))
39+
@test πŸ‡[TaylorTangentIndex(1)] isa MutableTangent{MDemo1}
40+
@test πŸ‡[TaylorTangentIndex(1)].x == 2.0
41+
@test πŸ‡[TaylorTangentIndex(2)] isa MutableTangent
42+
@test πŸ‡[TaylorTangentIndex(2)] == 0.0 # returns 20
43+
44+
45+
46+
foo(val) = val^2
47+
πŸ₯– = βˆ‚β˜†{2}()(ZeroBundle{2}(foo), TaylorBundle{2}(1.0, (0.0, 10.0)))
48+
πŸ₯–[TaylorTangentIndex(1)] # returns 0
49+
πŸ₯–[TaylorTangentIndex(2)] # returns 20
50+
51+
# end # module

0 commit comments

Comments
Β (0)