Skip to content

Commit ffd8840

Browse files
authored
Merge pull request #216 from JuliaDiff/ox/nocompo
Remove Composite Bundle
2 parents f63ad15 + 40da34b commit ffd8840

File tree

8 files changed

+196
-124
lines changed

8 files changed

+196
-124
lines changed

src/AbstractDifferentiation.jl

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,7 @@ This is more or less the Diffractor equivalent of ForwardDiff.jl's `Dual` type.
1010
"""
1111
function bundle end
1212
bundle(x, dx::ChainRulesCore.AbstractZero) = UniformBundle{1, typeof(x), typeof(dx)}(x, dx)
13-
bundle(x::Number, dx::Number) = TaylorBundle{1}(x, (dx,))
14-
bundle(x::AbstractArray{<:Number}, dx::AbstractArray{<:Number}) = TaylorBundle{1}(x, (dx,))
15-
bundle(x::P, dx::Tangent{P}) where P = _bundle(x, ChainRulesCore.canonicalize(dx))
16-
17-
"helper that assumes tangent is in canonical form"
18-
function _bundle(x::P, dx::Tangent{P}) where P
19-
# SoA to AoS flip (hate this, hate it even more cos we just undo it later when we hit chainrules)
20-
the_bundle = ntuple(Val{fieldcount(P)}()) do ii
21-
bundle(getfield(x, ii), getproperty(dx, ii))
22-
end
23-
return CompositeBundle{1, P}(the_bundle)
24-
end
25-
13+
bundle(x, dx) = TaylorBundle{1}(x, (dx,))
2614

2715
AD.@primitive function pushforward_function(b::DiffractorForwardBackend, f, args...)
2816
return function pushforward(vs)

src/stage1/compiler_utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ function Base.push!(cfg::CFG, bb::BasicBlock)
77
push!(cfg.index, bb.stmts.start)
88
end
99

10-
if VERSION <= v"1.11.0-DEV.116"
10+
if VERSION < v"1.11.0-DEV.258"
1111
Base.getindex(ir::IRCode, ssa::SSAValue) = Core.Compiler.getindex(ir, ssa)
1212
end
1313

src/stage1/forward.jl

Lines changed: 43 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,6 @@ partial(x::TaylorTangent, i) = getfield(getfield(x, :coeffs), i)
44
partial(x::UniformTangent, i) = getfield(x, :val)
55
partial(x::ProductTangent, i) = ProductTangent(map(x->partial(x, i), getfield(x, :factors)))
66
partial(x::AbstractZero, i) = x
7-
partial(x::CompositeBundle{N, B}, i) where {N, B<:Tuple} = Tangent{B}(map(x->partial(x, i), getfield(x, :tup))...)
8-
function partial(x::CompositeBundle{N, B}, i) where {N, B}
9-
# This is the tangent for a struct, but the fields' partials are each stored in a plain tuple
10-
# so we add the names back using the primal `B`
11-
# TODO: If required this can be done as a `@generated` function so it is type-stable
12-
backing = NamedTuple{fieldnames(B)}(map(x->partial(x, i), getfield(x, :tup)))
13-
return Tangent{B, typeof(backing)}(backing)
14-
end
157

168

179
primal(x::AbstractTangentBundle) = x.primal
@@ -42,20 +34,12 @@ function shuffle_down(b::TaylorBundle{N, B}) where {N, B}
4234
ntuple(_sdown, N-1))
4335
end
4436

45-
function shuffle_down(b::CompositeBundle{N, B}) where {N, B}
46-
z = CompositeBundle{N-1, CompositeBundle{1, B}}(
47-
(CompositeBundle{N-1, Tuple}(
48-
map(shuffle_down, b.tup)
49-
),)
50-
)
51-
z
52-
end
5337

54-
function shuffle_up(r::CompositeBundle{1})
55-
z₀ = primal(r.tup[1])
56-
z₁ = partial(r.tup[1], 1)
57-
z₂ = primal(r.tup[2])
58-
z₁₂ = partial(r.tup[2], 1)
38+
function shuffle_up(r::TaylorBundle{1, Tuple{B1,B2}}) where {B1,B2}
39+
z₀ = primal(r)[1]
40+
z₁ = partial(r, 1)[1]
41+
z₂ = primal(r)[2]
42+
z₁₂ = partial(r, 1)[2]
5943
if z₁ == z₂
6044
return TaylorBundle{2}(z₀, (z₁, z₁₂))
6145
else
@@ -70,26 +54,33 @@ function taylor_compatible(a::ATB{N}, b::ATB{N}) where {N}
7054
end
7155
end
7256

73-
# Check whether the tangent bundle element is taylor-like
74-
isswifty(::TaylorBundle) = true
75-
isswifty(::UniformBundle) = true
76-
isswifty(b::CompositeBundle) = all(isswifty, b.tup)
77-
isswifty(::Any) = false
78-
79-
function shuffle_up(r::CompositeBundle{N}) where {N}
80-
a, b = r.tup
81-
if isswifty(a) && isswifty(b) && taylor_compatible(a, b)
82-
return TaylorBundle{N+1}(primal(a),
83-
ntuple(i->i == N+1 ?
84-
b[TaylorTangentIndex(i-1)] : a[TaylorTangentIndex(i)],
85-
N+1))
57+
function taylor_compatible(r::TaylorBundle{N, Tuple{B1,B2}}) where {N, B1,B2}
58+
partial(r, 1)[1] == primal(r)[2] || return false
59+
return all(1:N-1) do i
60+
partial(r, i+1)[1] == partial(r, i)[2]
61+
end
62+
end
63+
function shuffle_up(r::TaylorBundle{N, Tuple{B1,B2}}) where {N, B1,B2}
64+
the_primal = primal(r)[1]
65+
if taylor_compatible(r)
66+
the_partials = ntuple(N+1) do i
67+
if i <= N
68+
partial(r, i)[1] # == `partial(r,i-1)[2]` (except first which is primal(r)[2])
69+
else # ii = N+1
70+
partial(r, i-1)[2]
71+
end
72+
end
73+
return TaylorBundle{N+1}(the_primal, the_partials)
8674
else
87-
return TangentBundle{N+1}(r.tup[1].primal,
88-
(r.tup[1].tangent.partials..., primal(b),
89-
ntuple(i->partial(b,i), 1<<(N+1)-1)...))
75+
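#XXX: am dubious of the correctness of this. NOTE(review): `primal_b` below is not defined anywhere in this method; presumably it should be `primal(r)[2]` -- confirm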
76+
a_partials = ntuple(i->partial(r, i)[1], N)
77+
b_partials = ntuple(i->partial(r, i)[2], N)
78+
the_partials = (a_partials..., primal_b, b_partials...)
79+
return TangentBundle{N+1}(the_primal, the_partials)
9080
end
9181
end
9282

83+
9384
function shuffle_up(r::UniformBundle{N, B, U}) where {N, B, U}
9485
(a, b) = primal(r)
9586
if r.tangent.val === b
@@ -185,13 +176,6 @@ end
185176
map(y->lifted_getfield(y, s), x.tangent.coeffs))
186177
end
187178

188-
@Base.constprop :aggressive function (::∂☆{N})(::ATB{N, typeof(getfield)}, x::CompositeBundle{N}, s::AbstractTangentBundle{N, Int}) where {N}
189-
x.tup[primal(s)]
190-
end
191-
192-
@Base.constprop :aggressive function (::∂☆{N})(::ATB{N, typeof(getfield)}, x::CompositeBundle{N, B}, s::AbstractTangentBundle{N, Symbol}) where {N, B}
193-
x.tup[Base.fieldindex(B, primal(s))]
194-
end
195179

196180
@Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::UniformBundle{N, <:Any, U}, s::AbstractTangentBundle{N}) where {N, U}
197181
UniformBundle{N,<:Any,U}(getfield(primal(x), primal(s)), x.tangent.val)
@@ -210,8 +194,8 @@ struct FwdMap{N, T<:AbstractTangentBundle{N}}
210194
end
211195
(f::FwdMap{N})(args::AbstractTangentBundle{N}...) where {N} = ∂☆{N}()(f.f, args...)
212196

213-
function (::∂☆{N})(::ZeroBundle{N, typeof(map)}, f::ATB{N}, tup::CompositeBundle{N, <:Tuple}) where {N}
214-
∂vararg{N}()(map(FwdMap(f), tup.tup)...)
197+
function (::∂☆{N})(::ZeroBundle{N, typeof(map)}, f::ATB{N}, tup::TaylorBundle{N, <:Tuple}) where {N}
198+
∂vararg{N}()(map(FwdMap(f), destructure(tup))...)
215199
end
216200

217201
function (::∂☆{N})(::ZeroBundle{N, typeof(map)}, f::ATB{N}, args::ATB{N, <:AbstractArray}...) where {N}
@@ -254,35 +238,37 @@ function (this::∂☆{N})(::ZeroBundle{N, typeof(Core._apply_iterate)}, iterate
254238
Core._apply_iterate(FwdIterate(iterate), this, (f,), args...)
255239
end
256240

257-
function (this::∂☆{N})(::ZeroBundle{N, typeof(iterate)}, t::CompositeBundle{N, <:Tuple}) where {N}
258-
r = iterate(t.tup)
241+
242+
function (this::∂☆{N})(::ZeroBundle{N, typeof(iterate)}, t::TaylorBundle{N, <:Tuple}) where {N}
243+
r = iterate(destructure(t))
259244
r === nothing && return ZeroBundle{N}(nothing)
260245
∂vararg{N}()(r[1], ZeroBundle{N}(r[2]))
261246
end
262247

263-
function (this::∂☆{N})(::ZeroBundle{N, typeof(iterate)}, t::CompositeBundle{N, <:Tuple}, a::ATB{N}, args::ATB{N}...) where {N}
264-
r = iterate(t.tup, primal(a), map(primal, args)...)
248+
function (this::∂☆{N})(::ZeroBundle{N, typeof(iterate)}, t::TaylorBundle{N, <:Tuple}, a::ATB{N}, args::ATB{N}...) where {N}
249+
r = iterate(destructure(t), primal(a), map(primal, args)...)
265250
r === nothing && return ZeroBundle{N}(nothing)
266251
∂vararg{N}()(r[1], ZeroBundle{N}(r[2]))
267252
end
268253

269-
function (this::∂☆{N})(::ZeroBundle{N, typeof(Base.indexed_iterate)}, t::CompositeBundle{N, <:Tuple}, i::ATB{N}) where {N}
270-
r = Base.indexed_iterate(t.tup, primal(i))
254+
function (this::∂☆{N})(::ZeroBundle{N, typeof(Base.indexed_iterate)}, t::TaylorBundle{N, <:Tuple}, i::ATB{N}) where {N}
255+
r = Base.indexed_iterate(destructure(t), primal(i))
271256
∂vararg{N}()(r[1], ZeroBundle{N}(r[2]))
272257
end
273258

274-
function (this::∂☆{N})(::ZeroBundle{N, typeof(Base.indexed_iterate)}, t::CompositeBundle{N, <:Tuple}, i::ATB{N}, st1::ATB{N}, st::ATB{N}...) where {N}
275-
r = Base.indexed_iterate(t.tup, primal(i), primal(st1), map(primal, st)...)
259+
function (this::∂☆{N})(::ZeroBundle{N, typeof(Base.indexed_iterate)}, t::TaylorBundle{N, <:Tuple}, i::ATB{N}, st1::ATB{N}, st::ATB{N}...) where {N}
260+
r = Base.indexed_iterate(destructure(t), primal(i), primal(st1), map(primal, st)...)
276261
∂vararg{N}()(r[1], ZeroBundle{N}(r[2]))
277262
end
278263

279264
function (this::∂☆{N})(::ZeroBundle{N, typeof(Base.indexed_iterate)}, t::TangentBundle{N, <:Tuple}, i::ATB{N}, st::ATB{N}...) where {N}
280265
∂vararg{N}()(this(ZeroBundle{N}(getfield), t, i), ZeroBundle{N}(primal(i) + 1))
281266
end
282267

283-
284-
function (this::∂☆{N})(::ZeroBundle{N, typeof(getindex)}, t::CompositeBundle{N, <:Tuple}, i::ZeroBundle) where {N}
285-
t.tup[primal(i)]
268+
function (this::∂☆{N})(::ZeroBundle{N, typeof(getindex)}, t::TaylorBundle{N, <:Tuple}, i::ZeroBundle) where {N}
269+
field_ind = primal(i)
270+
the_partials = ntuple(order_ind->partial(t, order_ind)[field_ind], N)
271+
TaylorBundle{N}(primal(t)[field_ind], the_partials)
286272
end
287273

288274
function (this::∂☆{N})(::ZeroBundle{N, typeof(typeof)}, x::ATB{N}) where {N}

src/stage1/recurse_fwd.jl

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,51 @@ struct ∂vararg{N}; end
44

55
(::∂vararg{N})() where {N} = ZeroBundle{N}(())
66
function (::∂vararg{N})(a::AbstractTangentBundle{N}...) where N
7-
CompositeBundle{N, Tuple{map(x->basespace(typeof(x)), a)...}}(a)
7+
B = Tuple{map(x->basespace(Core.Typeof(x)), a)...}
8+
return (∂☆new{N}())(B, a...)
89
end
910

1011
struct ∂☆new{N}; end
1112

12-
(::∂☆new{N})(B::Type, a::AbstractTangentBundle{N}...) where {N} =
13-
CompositeBundle{N, B}(a)
13+
# we split out the 1st order derivative as a special case for performance
14+
# but the nth order case does also work for this
15+
function (::∂☆new{1})(B::Type, xs::AbstractTangentBundle{1}...)
16+
primal_args = map(primal, xs)
17+
the_primal = _construct(B, primal_args)
18+
19+
tangent_tup = map(first_partial, xs)
20+
the_partial = if B<:Tuple
21+
Tangent{B, typeof(tangent_tup)}(tangent_tup)
22+
else
23+
names = fieldnames(B)
24+
tangent_nt = NamedTuple{names}(tangent_tup)
25+
Tangent{B, typeof(tangent_nt)}(tangent_nt)
26+
end
27+
return TaylorBundle{1, B}(the_primal, (the_partial,))
28+
end
29+
30+
function (::∂☆new{N})(B::Type, xs::AbstractTangentBundle{N}...) where {N}
31+
primal_args = map(primal, xs)
32+
the_primal = _construct(B, primal_args)
33+
34+
the_partials = ntuple(Val{N}()) do ii
35+
iith_order_type = ii==1 ? B : Any # the type of the higher order tangents isn't worth tracking
36+
tangent_tup = map(x->partial(x, ii), xs)
37+
tangent = if B<:Tuple
38+
Tangent{iith_order_type, typeof(tangent_tup)}(tangent_tup)
39+
else
40+
names = fieldnames(B)
41+
tangent_nt = NamedTuple{names}(tangent_tup)
42+
Tangent{iith_order_type, typeof(tangent_nt)}(tangent_nt)
43+
end
44+
return tangent
45+
end
46+
return TaylorBundle{N, B}(the_primal, the_partials)
47+
end
48+
49+
_construct(::Type{B}, args) where B<:Tuple = B(args)
50+
# Hack for making things that do not have public constructors constructable:
51+
@generated _construct(B::Type, args) = Expr(:splatnew, :B, :args)
1452

1553
@generated (::∂☆new{N})(B::Type) where {N} = return :(ZeroBundle{$N}($(Expr(:new, :B))))
1654

src/tangent.jl

Lines changed: 13 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -208,9 +208,7 @@ end
208208

209209
function check_taylor_invariants(coeffs, primal, N)
210210
@assert length(coeffs) == N
211-
if isa(primal, TangentBundle)
212-
@assert isa(coeffs[1], TangentBundle)
213-
end
211+
214212
end
215213
@ChainRulesCore.non_differentiable check_taylor_invariants(coeffs, primal, N)
216214

@@ -230,6 +228,18 @@ function Base.getindex(tb::TaylorBundle, tti::CanonicalTangentIndex)
230228
tb.tangent.coeffs[count_ones(tti.i)]
231229
end
232230

231+
"for a TaylorBundle{N, <:Tuple} this breaks it up into 1 TaylorBundle{N} for each element of the primal tuple"
232+
function destructure(r::TaylorBundle{N, B}) where {N, B<:Tuple}
233+
return ntuple(fieldcount(B)) do field_ii
234+
the_primal = primal(r)[field_ii]
235+
the_partials = ntuple(N) do order_ii
236+
partial(r, order_ii)[field_ii]
237+
end
238+
return TaylorBundle{N}(the_primal, the_partials)
239+
end
240+
end
241+
242+
233243
function truncate(tt::TaylorTangent, order::Val{N}) where {N}
234244
TaylorTangent(tt.coeffs[1:N])
235245
end
@@ -290,33 +300,6 @@ end
290300

291301
Base.getindex(u::UniformBundle, ::TaylorTangentIndex) = u.tangent.val
292302

293-
"""
294-
CompositeBundle{N, B, B <: Tuple}
295-
296-
Represents the tangent bundle where the base space is some tuple or struct type.
297-
Mathematically, this tangent bundle is the product bundle of the individual
298-
element bundles.
299-
"""
300-
struct CompositeBundle{N, B, T<:Tuple{Vararg{AbstractTangentBundle{N}}}} <: AbstractTangentBundle{N, B}
301-
tup::T
302-
end
303-
CompositeBundle{N, B}(tup::T) where {N, B, T} = CompositeBundle{N, B, T}(tup)
304-
305-
function Base.getindex(tb::CompositeBundle{N, B} where N, tti::TaylorTangentIndex) where {B}
306-
B <: SArray && error()
307-
return partial(tb, tti.i)
308-
end
309-
310-
primal(b::CompositeBundle{N, <:Tuple} where N) = map(primal, b.tup)
311-
function primal(b::CompositeBundle{N, T} where N) where T<:CompositeBundle
312-
T(map(primal, b.tup)...)
313-
end
314-
@generated primal(b::CompositeBundle{N, B} where N) where {B} =
315-
quote
316-
x = map(primal, b.tup)
317-
$(Expr(:splatnew, B, :x))
318-
end
319-
320303
expand_singleton_to_array(asize, a::AbstractZero) = fill(a, asize...)
321304
expand_singleton_to_array(asize, a::AbstractArray) = a
322305

test/AbstractDifferentiationTests.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,14 @@ backend = Diffractor.DiffractorForwardBackend()
77

88
@test bundle(1.0, 2.0) isa Diffractor.TaylorBundle{1}
99
@test bundle([1.0, 2.0], [2.0, 3.0]) isa Diffractor.TaylorBundle{1}
10-
@test bundle(1.5=>2.5, Tangent{Pair{Float64, Float64}}(first=1.0, second=2.0)) isa Diffractor.CompositeBundle{1}
10+
@test bundle(1.5=>2.5, Tangent{Pair{Float64, Float64}}(first=1.0, second=2.0)) isa Diffractor.TaylorBundle{1}
1111
@test bundle(1.1, ChainRulesCore.ZeroTangent()) isa Diffractor.ZeroBundle{1}
12-
@test bundle(1.5=>2.5=>3.5, Tangent{Pair{Float64, Pair{Float64, Float64}}}(first=1.0, second=Tangent{Pair{Float64, Float64}}(first=1.0, second=2.0))) isa Diffractor.CompositeBundle{1}
12+
@test bundle(1.5=>2.5=>3.5, Tangent{Pair{Float64, Pair{Float64, Float64}}}(first=1.0, second=Tangent{Pair{Float64, Float64}}(first=1.0, second=2.0))) isa Diffractor.TaylorBundle{1}
1313

1414
# noncanonical structural tangent
1515
b = bundle(1.5=>2.5=>3.5, Tangent{Pair{Float64, Pair{Float64, Float64}}}(second=Tangent{Pair{Float64, Float64}}(second=2.0, first=1.0)))
1616
t = Diffractor.first_partial(b)
17-
@test b isa Diffractor.CompositeBundle{1}
17+
@test b isa Diffractor.TaylorBundle{1}
1818
@test iszero(t.first)
1919
@test t.second.first == 1.0
2020
@test t.second.second == 2.0
@@ -29,7 +29,7 @@ end
2929

3030
# standard tests from AbstractDifferentiation.test_utils
3131
include(joinpath(pathof(AbstractDifferentiation), "..", "..", "test", "test_utils.jl"))
32-
@testset "ForwardDiffBackend" begin
32+
@testset "Standard AbstractDifferentiation.test_utils tests" begin
3333
backends = [
3434
@inferred(Diffractor.DiffractorForwardBackend())
3535
]

0 commit comments

Comments
 (0)