@@ -4,14 +4,6 @@ partial(x::TaylorTangent, i) = getfield(getfield(x, :coeffs), i)
4
4
partial (x:: UniformTangent , i) = getfield (x, :val )
5
5
partial (x:: ProductTangent , i) = ProductTangent (map (x-> partial (x, i), getfield (x, :factors )))
6
6
partial (x:: AbstractZero , i) = x
7
- partial (x:: CompositeBundle{N, B} , i) where {N, B<: Tuple } = Tangent {B} (map (x-> partial (x, i), getfield (x, :tup ))... )
8
- function partial (x:: CompositeBundle{N, B} , i) where {N, B}
9
- # This is tangent for a struct, but fields partials are each stored in a plain tuple
10
- # so we add the names back using the primal `B`
11
- # TODO : If required this can be done as a `@generated` function so it is type-stable
12
- backing = NamedTuple {fieldnames(B)} (map (x-> partial (x, i), getfield (x, :tup )))
13
- return Tangent {B, typeof(backing)} (backing)
14
- end
15
7
16
8
17
9
primal (x:: AbstractTangentBundle ) = x. primal
@@ -42,14 +34,6 @@ function shuffle_down(b::TaylorBundle{N, B}) where {N, B}
42
34
ntuple (_sdown, N- 1 ))
43
35
end
44
36
45
- function shuffle_down (b:: CompositeBundle{N, B} ) where {N, B}
46
- z = CompositeBundle {N-1, CompositeBundle{1, B}} (
47
- (CompositeBundle {N-1, Tuple} (
48
- map (shuffle_down, b. tup)
49
- ),)
50
- )
51
- z
52
- end
53
37
54
38
function shuffle_up (r:: TaylorBundle{1, Tuple{B1,B2}} ) where {B1,B2}
55
39
z₀ = primal (r)[1 ]
@@ -63,18 +47,7 @@ function shuffle_up(r::TaylorBundle{1, Tuple{B1,B2}}) where {B1,B2}
63
47
end
64
48
end
65
49
66
- function shuffle_up (r:: CompositeBundle{1} )
67
- z₀ = primal (r. tup[1 ])
68
- z₁ = partial (r. tup[1 ], 1 )
69
- z₂ = primal (r. tup[2 ])
70
- z₁₂ = partial (r. tup[2 ], 1 )
71
- if z₁ == z₂
72
- return TaylorBundle {2} (z₀, (z₁, z₁₂))
73
- else
74
- return ExplicitTangentBundle {2} (z₀, (z₁, z₂, z₁₂))
75
- end
76
- end
77
-
50
+ #= =
78
51
function taylor_compatible(a::ATB{N}, b::ATB{N}) where {N}
79
52
primal(b) === a[TaylorTangentIndex(1)] || return false
80
53
return all(1:(N-1)) do i
@@ -88,7 +61,7 @@ isswifty(::UniformBundle) = true
88
61
isswifty(b::CompositeBundle) = all(isswifty, b.tup)
89
62
isswifty(::Any) = false
90
63
91
- # TODO : port this to TaylorTangent:
64
+ #TODO : port this to TaylorTangent over composite structures
92
65
function shuffle_up(r::CompositeBundle{N}) where {N}
93
66
a, b = r.tup
94
67
if isswifty(a) && isswifty(b) && taylor_compatible(a, b)
@@ -102,6 +75,7 @@ function shuffle_up(r::CompositeBundle{N}) where {N}
102
75
ntuple(i->partial(b,i), 1<<(N+1)-1)...))
103
76
end
104
77
end
78
+ ==#
105
79
106
80
function shuffle_up (r:: UniformBundle{N, B, U} ) where {N, B, U}
107
81
(a, b) = primal (r)
198
172
map (y-> lifted_getfield (y, s), x. tangent. coeffs))
199
173
end
200
174
201
- @Base . constprop :aggressive function (:: ∂☆{N})(:: ATB{N, typeof(getfield)} , x:: CompositeBundle{N} , s:: AbstractTangentBundle{N, Int} ) where {N}
202
- x. tup[primal (s)]
203
- end
204
-
205
- @Base . constprop :aggressive function (:: ∂☆{N})(:: ATB{N, typeof(getfield)} , x:: CompositeBundle{N, B} , s:: AbstractTangentBundle{N, Symbol} ) where {N, B}
206
- x. tup[Base. fieldindex (B, primal (s))]
207
- end
208
175
209
176
@Base . constprop :aggressive function (:: ∂☆{N})(f:: ATB{N, typeof(getfield)} , x:: UniformBundle{N, <:Any, U} , s:: AbstractTangentBundle{N} ) where {N, U}
210
177
UniformBundle {N,<:Any,U} (getfield (primal (x), primal (s)), x. tangent. val)
@@ -223,9 +190,12 @@ struct FwdMap{N, T<:AbstractTangentBundle{N}}
223
190
end
224
191
(f:: FwdMap{N} )(args:: AbstractTangentBundle{N} ...) where {N} = ∂☆ {N} ()(f. f, args... )
225
192
193
+ #= =
194
+ # TODO port this to TaylorBundle over composite structure
226
195
function (::∂☆{N})(::ZeroBundle{N, typeof(map)}, f::ATB{N}, tup::CompositeBundle{N, <:Tuple}) where {N}
227
196
∂vararg{N}()(map(FwdMap(f), tup.tup)...)
228
197
end
198
+ ==#
229
199
230
200
function (:: ∂☆{N})(:: ZeroBundle{N, typeof(map)} , f:: ATB{N} , args:: ATB{N, <:AbstractArray} ...) where {N}
231
201
# TODO : This could do an inplace map! to avoid the extra rebundling
@@ -267,23 +237,28 @@ function (this::∂☆{N})(::ZeroBundle{N, typeof(Core._apply_iterate)}, iterate
267
237
Core. _apply_iterate (FwdIterate (iterate), this, (f,), args... )
268
238
end
269
239
240
+ #= =
241
+ #TODO : port this to TaylorTangent over composite structures
270
242
function (this::∂☆{N})(::ZeroBundle{N, typeof(iterate)}, t::CompositeBundle{N, <:Tuple}) where {N}
271
243
r = iterate(t.tup)
272
244
r === nothing && return ZeroBundle{N}(nothing)
273
245
∂vararg{N}()(r[1], ZeroBundle{N}(r[2]))
274
246
end
275
247
248
+ #TODO : port this to TaylorTangent over composite structures
276
249
function (this::∂☆{N})(::ZeroBundle{N, typeof(iterate)}, t::CompositeBundle{N, <:Tuple}, a::ATB{N}, args::ATB{N}...) where {N}
277
250
r = iterate(t.tup, primal(a), map(primal, args)...)
278
251
r === nothing && return ZeroBundle{N}(nothing)
279
252
∂vararg{N}()(r[1], ZeroBundle{N}(r[2]))
280
253
end
281
254
255
+ #TODO : port this to TaylorTangent over composite structures
282
256
function (this::∂☆{N})(::ZeroBundle{N, typeof(Base.indexed_iterate)}, t::CompositeBundle{N, <:Tuple}, i::ATB{N}) where {N}
283
257
r = Base.indexed_iterate(t.tup, primal(i))
284
258
∂vararg{N}()(r[1], ZeroBundle{N}(r[2]))
285
259
end
286
260
261
+ #TODO : port this to TaylorTangent over composite structures
287
262
function (this::∂☆{N})(::ZeroBundle{N, typeof(Base.indexed_iterate)}, t::CompositeBundle{N, <:Tuple}, i::ATB{N}, st1::ATB{N}, st::ATB{N}...) where {N}
288
263
r = Base.indexed_iterate(t.tup, primal(i), primal(st1), map(primal, st)...)
289
264
∂vararg{N}()(r[1], ZeroBundle{N}(r[2]))
@@ -293,10 +268,11 @@ function (this::∂☆{N})(::ZeroBundle{N, typeof(Base.indexed_iterate)}, t::Tan
293
268
∂vararg{N}()(this(ZeroBundle{N}(getfield), t, i), ZeroBundle{N}(primal(i) + 1))
294
269
end
295
270
296
-
271
+ # TODO : port this to TaylorTangent over composite structures
297
272
function (this::∂☆{N})(::ZeroBundle{N, typeof(getindex)}, t::CompositeBundle{N, <:Tuple}, i::ZeroBundle) where {N}
298
273
t.tup[primal(i)]
299
274
end
275
+ ==#
300
276
301
277
function (this: :∂☆ {N})(:: ZeroBundle{N, typeof(typeof)} , x:: ATB{N} ) where {N}
302
278
DNEBundle {N} (typeof (primal (x)))
0 commit comments