Skip to content

Commit 50383b6

Browse files
committed
Port over shuffle_up
1 parent f4f996d commit 50383b6

File tree

1 file changed

+23
-19
lines changed

1 file changed

+23
-19
lines changed

src/stage1/forward.jl

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -47,35 +47,39 @@ function shuffle_up(r::TaylorBundle{1, Tuple{B1,B2}}) where {B1,B2}
4747
end
4848
end
4949

50-
#==
5150
function taylor_compatible(a::ATB{N}, b::ATB{N}) where {N}
5251
primal(b) === a[TaylorTangentIndex(1)] || return false
5352
return all(1:(N-1)) do i
5453
b[TaylorTangentIndex(i)] === a[TaylorTangentIndex(i+1)]
5554
end
5655
end
5756

58-
# Check whether the tangent bundle element is taylor-like
59-
isswifty(::TaylorBundle) = true
60-
isswifty(::UniformBundle) = true
61-
isswifty(b::CompositeBundle) = all(isswifty, b.tup)
62-
isswifty(::Any) = false
63-
64-
#TODO: port this to TaylorTangent over composite structures
65-
function shuffle_up(r::CompositeBundle{N}) where {N}
66-
a, b = r.tup
67-
if isswifty(a) && isswifty(b) && taylor_compatible(a, b)
68-
return TaylorBundle{N+1}(primal(a),
69-
ntuple(i->i == N+1 ?
70-
b[TaylorTangentIndex(i-1)] : a[TaylorTangentIndex(i)],
71-
N+1))
57+
function taylor_compatible(r::TaylorBundle{N, Tuple{B1,B2}}) where {N, B1,B2}
58+
partial(r, 1)[1] = primal(r)[2] || return false
59+
return all(1:N-1) do ii
60+
partial(r, i+1)[1] == partial(r, i)[2]
61+
end
62+
end
63+
function shuffle_up(r::TaylorBundle{N, Tuple{B1,B2}}) where {N, B1,B2}
64+
the_primal = primal(r)[1]
65+
if taylor_compatible(r)
66+
the_partials = ntuple(N+1) do i
67+
if ii <= N
68+
partial(r, i)[1] # == `partial(r,i-1)[2]` (except first which is primal(r)[2])
69+
else # ii = N+1
70+
partial(r, i-1)[2]
71+
end
72+
end
73+
return TaylorBundle{N+1}(the_primal, the_partials)
7274
else
73-
return TangentBundle{N+1}(r.tup[1].primal,
74-
(r.tup[1].tangent.partials..., primal(b),
75-
ntuple(i->partial(b,i), 1<<(N+1)-1)...))
75+
#XXX: am dubious of the correctness of this
76+
a_partials = ntuple(i->partial(r, ii)[1], N)
77+
b_partials = ntuple(i->partial(r, ii)[2], N)
78+
the_partials = (a_partials..., primal_b, b_partials...)
79+
return TangentBundle{N+1}(the_primal, the_partials)
7680
end
7781
end
78-
==#
82+
7983

8084
function shuffle_up(r::UniformBundle{N, B, U}) where {N, B, U}
8185
(a, b) = primal(r)

0 commit comments

Comments
 (0)