@@ -4,14 +4,6 @@ partial(x::TaylorTangent, i) = getfield(getfield(x, :coeffs), i)
4
4
partial (x:: UniformTangent , i) = getfield (x, :val )
5
5
partial (x:: ProductTangent , i) = ProductTangent (map (x-> partial (x, i), getfield (x, :factors )))
6
6
partial (x:: AbstractZero , i) = x
7
- partial (x:: CompositeBundle{N, B} , i) where {N, B<: Tuple } = Tangent {B} (map (x-> partial (x, i), getfield (x, :tup ))... )
8
- function partial (x:: CompositeBundle{N, B} , i) where {N, B}
9
- # This is tangent for a struct, but fields partials are each stored in a plain tuple
10
- # so we add the names back using the primal `B`
11
- # TODO : If required this can be done as a `@generated` function so it is type-stable
12
- backing = NamedTuple {fieldnames(B)} (map (x-> partial (x, i), getfield (x, :tup )))
13
- return Tangent {B, typeof(backing)} (backing)
14
- end
15
7
16
8
17
9
primal (x:: AbstractTangentBundle ) = x. primal
@@ -42,20 +34,12 @@ function shuffle_down(b::TaylorBundle{N, B}) where {N, B}
42
34
ntuple (_sdown, N- 1 ))
43
35
end
44
36
45
- function shuffle_down (b:: CompositeBundle{N, B} ) where {N, B}
46
- z = CompositeBundle {N-1, CompositeBundle{1, B}} (
47
- (CompositeBundle {N-1, Tuple} (
48
- map (shuffle_down, b. tup)
49
- ),)
50
- )
51
- z
52
- end
53
37
54
- function shuffle_up (r:: CompositeBundle{1} )
55
- z₀ = primal (r. tup [1 ])
56
- z₁ = partial (r. tup[ 1 ] , 1 )
57
- z₂ = primal (r. tup [2 ])
58
- z₁₂ = partial (r. tup[ 2 ] , 1 )
38
+ function shuffle_up (r:: TaylorBundle{1, Tuple{B1,B2}} ) where {B1,B2}
39
+ z₀ = primal (r) [1 ]
40
+ z₁ = partial (r, 1 )[ 1 ]
41
+ z₂ = primal (r) [2 ]
42
+ z₁₂ = partial (r, 1 )[ 2 ]
59
43
if z₁ == z₂
60
44
return TaylorBundle {2} (z₀, (z₁, z₁₂))
61
45
else
@@ -70,26 +54,33 @@ function taylor_compatible(a::ATB{N}, b::ATB{N}) where {N}
70
54
end
71
55
end
72
56
73
- # Check whether the tangent bundle element is taylor-like
74
- isswifty (:: TaylorBundle ) = true
75
- isswifty (:: UniformBundle ) = true
76
- isswifty (b:: CompositeBundle ) = all (isswifty, b. tup)
77
- isswifty (:: Any ) = false
78
-
79
- function shuffle_up (r:: CompositeBundle{N} ) where {N}
80
- a, b = r. tup
81
- if isswifty (a) && isswifty (b) && taylor_compatible (a, b)
82
- return TaylorBundle {N+1} (primal (a),
83
- ntuple (i-> i == N+ 1 ?
84
- b[TaylorTangentIndex (i- 1 )] : a[TaylorTangentIndex (i)],
85
- N+ 1 ))
57
+ function taylor_compatible (r:: TaylorBundle{N, Tuple{B1,B2}} ) where {N, B1,B2}
58
+ partial (r, 1 )[1 ] == primal (r)[2 ] || return false
59
+ return all (1 : N- 1 ) do i
60
+ partial (r, i+ 1 )[1 ] == partial (r, i)[2 ]
61
+ end
62
+ end
63
+ function shuffle_up (r:: TaylorBundle{N, Tuple{B1,B2}} ) where {N, B1,B2}
64
+ the_primal = primal (r)[1 ]
65
+ if taylor_compatible (r)
66
+ the_partials = ntuple (N+ 1 ) do i
67
+ if i <= N
68
+ partial (r, i)[1 ] # == `partial(r,i-1)[2]` (except first which is primal(r)[2])
69
+ else # ii = N+1
70
+ partial (r, i- 1 )[2 ]
71
+ end
72
+ end
73
+ return TaylorBundle {N+1} (the_primal, the_partials)
86
74
else
87
- return TangentBundle {N+1} (r. tup[1 ]. primal,
88
- (r. tup[1 ]. tangent. partials... , primal (b),
89
- ntuple (i-> partial (b,i), 1 << (N+ 1 )- 1 )... ))
75
+ # XXX : am dubious of the correctness of this
76
+ a_partials = ntuple (i-> partial (r, i)[1 ], N)
77
+ b_partials = ntuple (i-> partial (r, i)[2 ], N)
78
+ the_partials = (a_partials... , primal_b, b_partials... )
79
+ return TangentBundle {N+1} (the_primal, the_partials)
90
80
end
91
81
end
92
82
83
+
93
84
function shuffle_up (r:: UniformBundle{N, B, U} ) where {N, B, U}
94
85
(a, b) = primal (r)
95
86
if r. tangent. val === b
185
176
map (y-> lifted_getfield (y, s), x. tangent. coeffs))
186
177
end
187
178
188
- @Base . constprop :aggressive function (:: ∂☆{N})(:: ATB{N, typeof(getfield)} , x:: CompositeBundle{N} , s:: AbstractTangentBundle{N, Int} ) where {N}
189
- x. tup[primal (s)]
190
- end
191
-
192
- @Base . constprop :aggressive function (:: ∂☆{N})(:: ATB{N, typeof(getfield)} , x:: CompositeBundle{N, B} , s:: AbstractTangentBundle{N, Symbol} ) where {N, B}
193
- x. tup[Base. fieldindex (B, primal (s))]
194
- end
195
179
196
180
@Base . constprop :aggressive function (:: ∂☆{N})(f:: ATB{N, typeof(getfield)} , x:: UniformBundle{N, <:Any, U} , s:: AbstractTangentBundle{N} ) where {N, U}
197
181
UniformBundle {N,<:Any,U} (getfield (primal (x), primal (s)), x. tangent. val)
@@ -210,8 +194,8 @@ struct FwdMap{N, T<:AbstractTangentBundle{N}}
210
194
end
211
195
(f:: FwdMap{N} )(args:: AbstractTangentBundle{N} ...) where {N} = ∂☆ {N} ()(f. f, args... )
212
196
213
- function (:: ∂☆{N})(:: ZeroBundle{N, typeof(map)} , f:: ATB{N} , tup:: CompositeBundle {N, <:Tuple} ) where {N}
214
- ∂vararg {N} ()(map (FwdMap (f), tup. tup )... )
197
+ function (:: ∂☆{N})(:: ZeroBundle{N, typeof(map)} , f:: ATB{N} , tup:: TaylorBundle {N, <:Tuple} ) where {N}
198
+ ∂vararg {N} ()(map (FwdMap (f), destructure ( tup) )... )
215
199
end
216
200
217
201
function (:: ∂☆{N})(:: ZeroBundle{N, typeof(map)} , f:: ATB{N} , args:: ATB{N, <:AbstractArray} ...) where {N}
@@ -254,35 +238,37 @@ function (this::∂☆{N})(::ZeroBundle{N, typeof(Core._apply_iterate)}, iterate
254
238
Core. _apply_iterate (FwdIterate (iterate), this, (f,), args... )
255
239
end
256
240
257
- function (this: :∂☆ {N})(:: ZeroBundle{N, typeof(iterate)} , t:: CompositeBundle{N, <:Tuple} ) where {N}
258
- r = iterate (t. tup)
241
+
242
+ function (this: :∂☆ {N})(:: ZeroBundle{N, typeof(iterate)} , t:: TaylorBundle{N, <:Tuple} ) where {N}
243
+ r = iterate (destructure (t))
259
244
r === nothing && return ZeroBundle {N} (nothing )
260
245
∂vararg {N} ()(r[1 ], ZeroBundle {N} (r[2 ]))
261
246
end
262
247
263
- function (this: :∂☆ {N})(:: ZeroBundle{N, typeof(iterate)} , t:: CompositeBundle {N, <:Tuple} , a:: ATB{N} , args:: ATB{N} ...) where {N}
264
- r = iterate (t . tup , primal (a), map (primal, args)... )
248
+ function (this: :∂☆ {N})(:: ZeroBundle{N, typeof(iterate)} , t:: TaylorBundle {N, <:Tuple} , a:: ATB{N} , args:: ATB{N} ...) where {N}
249
+ r = iterate (destructure (t) , primal (a), map (primal, args)... )
265
250
r === nothing && return ZeroBundle {N} (nothing )
266
251
∂vararg {N} ()(r[1 ], ZeroBundle {N} (r[2 ]))
267
252
end
268
253
269
- function (this: :∂☆ {N})(:: ZeroBundle{N, typeof(Base.indexed_iterate)} , t:: CompositeBundle {N, <:Tuple} , i:: ATB{N} ) where {N}
270
- r = Base. indexed_iterate (t . tup , primal (i))
254
+ function (this: :∂☆ {N})(:: ZeroBundle{N, typeof(Base.indexed_iterate)} , t:: TaylorBundle {N, <:Tuple} , i:: ATB{N} ) where {N}
255
+ r = Base. indexed_iterate (destructure (t) , primal (i))
271
256
∂vararg {N} ()(r[1 ], ZeroBundle {N} (r[2 ]))
272
257
end
273
258
274
- function (this: :∂☆ {N})(:: ZeroBundle{N, typeof(Base.indexed_iterate)} , t:: CompositeBundle {N, <:Tuple} , i:: ATB{N} , st1:: ATB{N} , st:: ATB{N} ...) where {N}
275
- r = Base. indexed_iterate (t . tup , primal (i), primal (st1), map (primal, st)... )
259
+ function (this: :∂☆ {N})(:: ZeroBundle{N, typeof(Base.indexed_iterate)} , t:: TaylorBundle {N, <:Tuple} , i:: ATB{N} , st1:: ATB{N} , st:: ATB{N} ...) where {N}
260
+ r = Base. indexed_iterate (destructure (t) , primal (i), primal (st1), map (primal, st)... )
276
261
∂vararg {N} ()(r[1 ], ZeroBundle {N} (r[2 ]))
277
262
end
278
263
279
264
function (this: :∂☆ {N})(:: ZeroBundle{N, typeof(Base.indexed_iterate)} , t:: TangentBundle{N, <:Tuple} , i:: ATB{N} , st:: ATB{N} ...) where {N}
280
265
∂vararg {N} ()(this (ZeroBundle {N} (getfield), t, i), ZeroBundle {N} (primal (i) + 1 ))
281
266
end
282
267
283
-
284
- function (this: :∂☆ {N})(:: ZeroBundle{N, typeof(getindex)} , t:: CompositeBundle{N, <:Tuple} , i:: ZeroBundle ) where {N}
285
- t. tup[primal (i)]
268
+ function (this: :∂☆ {N})(:: ZeroBundle{N, typeof(getindex)} , t:: TaylorBundle{N, <:Tuple} , i:: ZeroBundle ) where {N}
269
+ field_ind = primal (i)
270
+ the_partials = ntuple (order_ind-> partial (t, order_ind)[field_ind], N)
271
+ TaylorBundle {N} (primal (t)[field_ind], the_partials)
286
272
end
287
273
288
274
function (this: :∂☆ {N})(:: ZeroBundle{N, typeof(typeof)} , x:: ATB{N} ) where {N}
0 commit comments