Skip to content

Commit c205e34

Browse files
committed
Remove most uses of CompositeTangent
1 parent be091aa commit c205e34

File tree

5 files changed

+25
-87
lines changed

5 files changed

+25
-87
lines changed

src/AbstractDifferentiation.jl

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,7 @@ This is more or less the Diffractor equivelent of ForwardDiff.jl's `Dual` type.
1010
"""
1111
function bundle end
1212
bundle(x, dx::ChainRulesCore.AbstractZero) = UniformBundle{1, typeof(x), typeof(dx)}(x, dx)
13-
bundle(x::Number, dx::Number) = TaylorBundle{1}(x, (dx,))
14-
bundle(x::AbstractArray{<:Number}, dx::AbstractArray{<:Number}) = TaylorBundle{1}(x, (dx,))
15-
bundle(x::P, dx::Tangent{P}) where P = _bundle(x, ChainRulesCore.canonicalize(dx))
16-
17-
"helper that assumes tangent is in canonical form"
18-
function _bundle(x::P, dx::Tangent{P}) where P
19-
# SoA to AoS flip (hate this, hate it even more cos we just undo it later when we hit chainrules)
20-
the_bundle = ntuple(Val{fieldcount(P)}()) do ii
21-
bundle(getfield(x, ii), getproperty(dx, ii))
22-
end
23-
return CompositeBundle{1, P}(the_bundle)
24-
end
25-
13+
bundle(x, dx) = TaylorBundle{1}(x, (dx,))
2614

2715
AD.@primitive function pushforward_function(b::DiffractorForwardBackend, f, args...)
2816
return function pushforward(vs)

src/stage1/forward.jl

Lines changed: 13 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,6 @@ partial(x::TaylorTangent, i) = getfield(getfield(x, :coeffs), i)
44
partial(x::UniformTangent, i) = getfield(x, :val)
55
partial(x::ProductTangent, i) = ProductTangent(map(x->partial(x, i), getfield(x, :factors)))
66
partial(x::AbstractZero, i) = x
7-
partial(x::CompositeBundle{N, B}, i) where {N, B<:Tuple} = Tangent{B}(map(x->partial(x, i), getfield(x, :tup))...)
8-
function partial(x::CompositeBundle{N, B}, i) where {N, B}
9-
# This is tangent for a struct, but fields partials are each stored in a plain tuple
10-
# so we add the names back using the primal `B`
11-
# TODO: If required this can be done as a `@generated` function so it is type-stable
12-
backing = NamedTuple{fieldnames(B)}(map(x->partial(x, i), getfield(x, :tup)))
13-
return Tangent{B, typeof(backing)}(backing)
14-
end
157

168

179
primal(x::AbstractTangentBundle) = x.primal
@@ -42,14 +34,6 @@ function shuffle_down(b::TaylorBundle{N, B}) where {N, B}
4234
ntuple(_sdown, N-1))
4335
end
4436

45-
function shuffle_down(b::CompositeBundle{N, B}) where {N, B}
46-
z = CompositeBundle{N-1, CompositeBundle{1, B}}(
47-
(CompositeBundle{N-1, Tuple}(
48-
map(shuffle_down, b.tup)
49-
),)
50-
)
51-
z
52-
end
5337

5438
function shuffle_up(r::TaylorBundle{1, Tuple{B1,B2}}) where {B1,B2}
5539
z₀ = primal(r)[1]
@@ -63,18 +47,7 @@ function shuffle_up(r::TaylorBundle{1, Tuple{B1,B2}}) where {B1,B2}
6347
end
6448
end
6549

66-
function shuffle_up(r::CompositeBundle{1})
67-
z₀ = primal(r.tup[1])
68-
z₁ = partial(r.tup[1], 1)
69-
z₂ = primal(r.tup[2])
70-
z₁₂ = partial(r.tup[2], 1)
71-
if z₁ == z₂
72-
return TaylorBundle{2}(z₀, (z₁, z₁₂))
73-
else
74-
return ExplicitTangentBundle{2}(z₀, (z₁, z₂, z₁₂))
75-
end
76-
end
77-
50+
#==
7851
function taylor_compatible(a::ATB{N}, b::ATB{N}) where {N}
7952
primal(b) === a[TaylorTangentIndex(1)] || return false
8053
return all(1:(N-1)) do i
@@ -88,7 +61,7 @@ isswifty(::UniformBundle) = true
8861
isswifty(b::CompositeBundle) = all(isswifty, b.tup)
8962
isswifty(::Any) = false
9063
91-
#TODO: port this to TaylorTangent:
64+
#TODO: port this to TaylorTangent over composite structures
9265
function shuffle_up(r::CompositeBundle{N}) where {N}
9366
a, b = r.tup
9467
if isswifty(a) && isswifty(b) && taylor_compatible(a, b)
@@ -102,6 +75,7 @@ function shuffle_up(r::CompositeBundle{N}) where {N}
10275
ntuple(i->partial(b,i), 1<<(N+1)-1)...))
10376
end
10477
end
78+
==#
10579

10680
function shuffle_up(r::UniformBundle{N, B, U}) where {N, B, U}
10781
(a, b) = primal(r)
@@ -198,13 +172,6 @@ end
198172
map(y->lifted_getfield(y, s), x.tangent.coeffs))
199173
end
200174

201-
@Base.constprop :aggressive function (::∂☆{N})(::ATB{N, typeof(getfield)}, x::CompositeBundle{N}, s::AbstractTangentBundle{N, Int}) where {N}
202-
x.tup[primal(s)]
203-
end
204-
205-
@Base.constprop :aggressive function (::∂☆{N})(::ATB{N, typeof(getfield)}, x::CompositeBundle{N, B}, s::AbstractTangentBundle{N, Symbol}) where {N, B}
206-
x.tup[Base.fieldindex(B, primal(s))]
207-
end
208175

209176
@Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::UniformBundle{N, <:Any, U}, s::AbstractTangentBundle{N}) where {N, U}
210177
UniformBundle{N,<:Any,U}(getfield(primal(x), primal(s)), x.tangent.val)
@@ -223,9 +190,12 @@ struct FwdMap{N, T<:AbstractTangentBundle{N}}
223190
end
224191
(f::FwdMap{N})(args::AbstractTangentBundle{N}...) where {N} = ∂☆{N}()(f.f, args...)
225192

193+
#==
194+
# TODO port this to TaylorBundle over composite structure
226195
function (::∂☆{N})(::ZeroBundle{N, typeof(map)}, f::ATB{N}, tup::CompositeBundle{N, <:Tuple}) where {N}
227196
∂vararg{N}()(map(FwdMap(f), tup.tup)...)
228197
end
198+
==#
229199

230200
function (::∂☆{N})(::ZeroBundle{N, typeof(map)}, f::ATB{N}, args::ATB{N, <:AbstractArray}...) where {N}
231201
# TODO: This could do an inplace map! to avoid the extra rebundling
@@ -267,23 +237,28 @@ function (this::∂☆{N})(::ZeroBundle{N, typeof(Core._apply_iterate)}, iterate
267237
Core._apply_iterate(FwdIterate(iterate), this, (f,), args...)
268238
end
269239

240+
#==
241+
#TODO: port this to TaylorTangent over composite structures
270242
function (this::∂☆{N})(::ZeroBundle{N, typeof(iterate)}, t::CompositeBundle{N, <:Tuple}) where {N}
271243
r = iterate(t.tup)
272244
r === nothing && return ZeroBundle{N}(nothing)
273245
∂vararg{N}()(r[1], ZeroBundle{N}(r[2]))
274246
end
275247
248+
#TODO: port this to TaylorTangent over composite structures
276249
function (this::∂☆{N})(::ZeroBundle{N, typeof(iterate)}, t::CompositeBundle{N, <:Tuple}, a::ATB{N}, args::ATB{N}...) where {N}
277250
r = iterate(t.tup, primal(a), map(primal, args)...)
278251
r === nothing && return ZeroBundle{N}(nothing)
279252
∂vararg{N}()(r[1], ZeroBundle{N}(r[2]))
280253
end
281254
255+
#TODO: port this to TaylorTangent over composite structures
282256
function (this::∂☆{N})(::ZeroBundle{N, typeof(Base.indexed_iterate)}, t::CompositeBundle{N, <:Tuple}, i::ATB{N}) where {N}
283257
r = Base.indexed_iterate(t.tup, primal(i))
284258
∂vararg{N}()(r[1], ZeroBundle{N}(r[2]))
285259
end
286260
261+
#TODO: port this to TaylorTangent over composite structures
287262
function (this::∂☆{N})(::ZeroBundle{N, typeof(Base.indexed_iterate)}, t::CompositeBundle{N, <:Tuple}, i::ATB{N}, st1::ATB{N}, st::ATB{N}...) where {N}
288263
r = Base.indexed_iterate(t.tup, primal(i), primal(st1), map(primal, st)...)
289264
∂vararg{N}()(r[1], ZeroBundle{N}(r[2]))
@@ -293,10 +268,11 @@ function (this::∂☆{N})(::ZeroBundle{N, typeof(Base.indexed_iterate)}, t::Tan
293268
∂vararg{N}()(this(ZeroBundle{N}(getfield), t, i), ZeroBundle{N}(primal(i) + 1))
294269
end
295270
296-
271+
#TODO: port this to TaylorTangent over composite structures
297272
function (this::∂☆{N})(::ZeroBundle{N, typeof(getindex)}, t::CompositeBundle{N, <:Tuple}, i::ZeroBundle) where {N}
298273
t.tup[primal(i)]
299274
end
275+
==#
300276

301277
function (this::∂☆{N})(::ZeroBundle{N, typeof(typeof)}, x::ATB{N}) where {N}
302278
DNEBundle{N}(typeof(primal(x)))

src/tangent.jl

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -290,33 +290,6 @@ end
290290

291291
Base.getindex(u::UniformBundle, ::TaylorTangentIndex) = u.tangent.val
292292

293-
"""
294-
CompositeBundle{N, B, B <: Tuple}
295-
296-
Represents the tagent bundle where the base space is some tuple or struct type.
297-
Mathematically, this tangent bundle is the product bundle of the individual
298-
element bundles.
299-
"""
300-
struct CompositeBundle{N, B, T<:Tuple{Vararg{AbstractTangentBundle{N}}}} <: AbstractTangentBundle{N, B}
301-
tup::T
302-
end
303-
CompositeBundle{N, B}(tup::T) where {N, B, T} = CompositeBundle{N, B, T}(tup)
304-
305-
function Base.getindex(tb::CompositeBundle{N, B} where N, tti::TaylorTangentIndex) where {B}
306-
B <: SArray && error()
307-
return partial(tb, tti.i)
308-
end
309-
310-
primal(b::CompositeBundle{N, <:Tuple} where N) = map(primal, b.tup)
311-
function primal(b::CompositeBundle{N, T} where N) where T<:CompositeBundle
312-
T(map(primal, b.tup)...)
313-
end
314-
@generated primal(b::CompositeBundle{N, B} where N) where {B} =
315-
quote
316-
x = map(primal, b.tup)
317-
$(Expr(:splatnew, B, :x))
318-
end
319-
320293
expand_singleton_to_array(asize, a::AbstractZero) = fill(a, asize...)
321294
expand_singleton_to_array(asize, a::AbstractArray) = a
322295

test/AbstractDifferentiationTests.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,14 @@ backend = Diffractor.DiffractorForwardBackend()
77

88
@test bundle(1.0, 2.0) isa Diffractor.TaylorBundle{1}
99
@test bundle([1.0, 2.0], [2.0, 3.0]) isa Diffractor.TaylorBundle{1}
10-
@test bundle(1.5=>2.5, Tangent{Pair{Float64, Float64}}(first=1.0, second=2.0)) isa Diffractor.CompositeBundle{1}
10+
@test bundle(1.5=>2.5, Tangent{Pair{Float64, Float64}}(first=1.0, second=2.0)) isa Diffractor.TaylorBundle{1}
1111
@test bundle(1.1, ChainRulesCore.ZeroTangent()) isa Diffractor.ZeroBundle{1}
12-
@test bundle(1.5=>2.5=>3.5, Tangent{Pair{Float64, Pair{Float64, Float64}}}(first=1.0, second=Tangent{Pair{Float64, Float64}}(first=1.0, second=2.0))) isa Diffractor.CompositeBundle{1}
12+
@test bundle(1.5=>2.5=>3.5, Tangent{Pair{Float64, Pair{Float64, Float64}}}(first=1.0, second=Tangent{Pair{Float64, Float64}}(first=1.0, second=2.0))) isa Diffractor.TaylorBundle{1}
1313

1414
# noncanonical structural tangent
1515
b = bundle(1.5=>2.5=>3.5, Tangent{Pair{Float64, Pair{Float64, Float64}}}(second=Tangent{Pair{Float64, Float64}}(second=2.0, first=1.0)))
1616
t = Diffractor.first_partial(b)
17-
@test b isa Diffractor.CompositeBundle{1}
17+
@test b isa Diffractor.TaylorBundle{1}
1818
@test iszero(t.first)
1919
@test t.second.first == 1.0
2020
@test t.second.second == 2.0

test/tangent.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module tagent
22
using Diffractor
33
using Diffractor: AbstractZeroBundle, ZeroBundle, DNEBundle
4-
using Diffractor: TaylorBundle, TaylorTangentIndex, CompositeBundle
4+
using Diffractor: TaylorBundle, TaylorTangentIndex
55
using Diffractor: ExplicitTangent, TaylorTangent, truncate
66
using ChainRulesCore
77
using Test
@@ -28,21 +28,22 @@ using Test
2828
end
2929

3030
@testset "AD through constructor" begin
31-
#https://github.com/JuliaDiff/Diffractor.jl/issues/152
32-
# hits `getindex(::CompositeBundle{Foo152}, ::TaylorTangentIndex)`
31+
# https://github.com/JuliaDiff/Diffractor.jl/issues/152
32+
# Though we have now removed the underlying cause, we keep this as a regression test just in case
3333
struct Foo152
3434
x::Float64
3535
end
3636

3737
# Unit Test
38-
cb = CompositeBundle{1, Foo152}((TaylorBundle{1, Float64}(23.5, (1.0,)),))
38+
cb = TaylorBundle{1, Foo152}(Foo152(23.5), (Tangent{Foo152}(;x=1.0),))
3939
tti = TaylorTangentIndex(1,)
4040
@test cb[tti] == Tangent{Foo152}(; x=1.0)
4141

4242
# Integration Test
43-
var"'" = Diffractor.PrimeDerivativeFwd
44-
f(x) = Foo152(x)
45-
@test f'(23.5) == Tangent{Foo152}(; x=1.0)
43+
let var"'" = Diffractor.PrimeDerivativeFwd
44+
f(x) = Foo152(x)
45+
@test f'(23.5) == Tangent{Foo152}(; x=1.0)
46+
end
4647
end
4748

4849
@testset "truncate" begin

0 commit comments

Comments
 (0)