Skip to content

Commit 121ec20

Browse files
committed
port remaining ∂☆ overloads to TaylorTangent
1 parent bfd99c2 commit 121ec20

File tree

1 file changed

+13
-17
lines changed

1 file changed

+13
-17
lines changed

src/stage1/forward.jl

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -238,42 +238,38 @@ function (this::∂☆{N})(::ZeroBundle{N, typeof(Core._apply_iterate)}, iterate
238238
Core._apply_iterate(FwdIterate(iterate), this, (f,), args...)
239239
end
240240

241-
#==
242-
#TODO: port this to TaylorTangent over composite structures
243-
function (this::∂☆{N})(::ZeroBundle{N, typeof(iterate)}, t::CompositeBundle{N, <:Tuple}) where {N}
244-
r = iterate(t.tup)
241+
242+
function (this::∂☆{N})(::ZeroBundle{N, typeof(iterate)}, t::TaylorBundle{N, <:Tuple}) where {N}
243+
r = iterate(destructure(t))
245244
r === nothing && return ZeroBundle{N}(nothing)
246245
∂vararg{N}()(r[1], ZeroBundle{N}(r[2]))
247246
end
248247

249-
#TODO: port this to TaylorTangent over composite structures
250-
function (this::∂☆{N})(::ZeroBundle{N, typeof(iterate)}, t::CompositeBundle{N, <:Tuple}, a::ATB{N}, args::ATB{N}...) where {N}
251-
r = iterate(t.tup, primal(a), map(primal, args)...)
248+
function (this::∂☆{N})(::ZeroBundle{N, typeof(iterate)}, t::TaylorBundle{N, <:Tuple}, a::ATB{N}, args::ATB{N}...) where {N}
249+
r = iterate(destructure(t), primal(a), map(primal, args)...)
252250
r === nothing && return ZeroBundle{N}(nothing)
253251
∂vararg{N}()(r[1], ZeroBundle{N}(r[2]))
254252
end
255253

256-
#TODO: port this to TaylorTangent over composite structures
257-
function (this::∂☆{N})(::ZeroBundle{N, typeof(Base.indexed_iterate)}, t::CompositeBundle{N, <:Tuple}, i::ATB{N}) where {N}
258-
r = Base.indexed_iterate(t.tup, primal(i))
254+
function (this::∂☆{N})(::ZeroBundle{N, typeof(Base.indexed_iterate)}, t::TaylorBundle{N, <:Tuple}, i::ATB{N}) where {N}
255+
r = Base.indexed_iterate(destructure(t), primal(i))
259256
∂vararg{N}()(r[1], ZeroBundle{N}(r[2]))
260257
end
261258

262-
#TODO: port this to TaylorTangent over composite structures
263-
function (this::∂☆{N})(::ZeroBundle{N, typeof(Base.indexed_iterate)}, t::CompositeBundle{N, <:Tuple}, i::ATB{N}, st1::ATB{N}, st::ATB{N}...) where {N}
264-
r = Base.indexed_iterate(t.tup, primal(i), primal(st1), map(primal, st)...)
259+
function (this::∂☆{N})(::ZeroBundle{N, typeof(Base.indexed_iterate)}, t::TaylorBundle{N, <:Tuple}, i::ATB{N}, st1::ATB{N}, st::ATB{N}...) where {N}
260+
r = Base.indexed_iterate(destructure(t), primal(i), primal(st1), map(primal, st)...)
265261
∂vararg{N}()(r[1], ZeroBundle{N}(r[2]))
266262
end
267263

268264
function (this::∂☆{N})(::ZeroBundle{N, typeof(Base.indexed_iterate)}, t::TangentBundle{N, <:Tuple}, i::ATB{N}, st::ATB{N}...) where {N}
269265
∂vararg{N}()(this(ZeroBundle{N}(getfield), t, i), ZeroBundle{N}(primal(i) + 1))
270266
end
271267

272-
#TODO: port this to TaylorTangent over composite structures
273-
function (this::∂☆{N})(::ZeroBundle{N, typeof(getindex)}, t::CompositeBundle{N, <:Tuple}, i::ZeroBundle) where {N}
274-
t.tup[primal(i)]
268+
function (this::∂☆{N})(::ZeroBundle{N, typeof(getindex)}, t::TaylorBundle{N, <:Tuple}, i::ZeroBundle) where {N}
269+
field_ind = primal(i)
270+
the_partials = ntuple(order_ind->partial(t, order_ind)[field_ind], N)
271+
TaylorBundle{N}(primal(t)[field_ind], the_partials)
275272
end
276-
==#
277273

278274
function (this::∂☆{N})(::ZeroBundle{N, typeof(typeof)}, x::ATB{N}) where {N}
279275
DNEBundle{N}(typeof(primal(x)))

0 commit comments

Comments
 (0)