Skip to content

Commit eb7f7f6

Browse files
oscardssmithstaticfloat
authored andcommitted
update forward.jl
1 parent 5871a57 commit eb7f7f6

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

src/stage1/forward.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ primal(z::ZeroTangent) = ZeroTangent()
2020
first_partial(x) = partial(x, 1)
2121

2222
shuffle_down(b::UniformBundle{N, B, U}) where {N, B, U} =
23-
UniformBundle{minus1(N), <:Any}(UniformBundle{1, B}(b.primal, b.tangent.val),
23+
UniformBundle{N-1, <:Any}(UniformBundle{1, B}(b.primal, b.tangent.val),
2424
UniformBundle{1, U}(b.tangent.val, b.tangent.val))
2525

2626
function shuffle_down(b::ExplicitTangentBundle{N, B}) where {N, B}
@@ -30,7 +30,7 @@ function shuffle_down(b::ExplicitTangentBundle{N, B}) where {N, B}
3030
end
3131
ExplicitTangentBundle{N-1}(
3232
ExplicitTangentBundle{1}(b.primal, (partial(b, 1),)),
33-
ntuple(_sdown, 2^(N-1)-1))
33+
ntuple(_sdown, 1<<(N-1)-1))
3434
end
3535

3636
function shuffle_down(b::TaylorBundle{N, B}) where {N, B}
@@ -86,7 +86,7 @@ function shuffle_up(r::CompositeBundle{N}) where {N}
8686
else
8787
return TangentBundle{N+1}(r.tup[1].primal,
8888
(r.tup[1].tangent.partials..., primal(b),
89-
ntuple(i->partial(b,i), 2^(N+1)-1)...))
89+
ntuple(i->partial(b,i), 1<<(N+1)-1)...))
9090
end
9191
end
9292

@@ -131,10 +131,10 @@ function ChainRulesCore.frule_via_ad(::DiffractorRuleConfig, partials, args...)
131131
end
132132

133133
function (::∂☆shuffle{N})(args::AbstractTangentBundle{N}...) where {N}
134-
∂☆p = ∂☆{minus1(N)}()
134+
∂☆p = ∂☆{N-1}()
135135
downargs = map(shuffle_down, args)
136-
tupargs = ∂vararg{minus1(N)}()(map(first_partial, downargs)...)
137-
∂☆p(ZeroBundle{minus1(N)}(frule), #= ZeroBundle{minus1(N)}(DiffractorRuleConfig()), =# tupargs, map(primal, downargs)...)
136+
tupargs = ∂vararg{N-1}()(map(first_partial, downargs)...)
137+
∂☆p(ZeroBundle{N-1}(frule), #= ZeroBundle{N-1}(DiffractorRuleConfig()), =# tupargs, map(primal, downargs)...)
138138
end
139139

140140
function (::∂☆internal{N})(args::AbstractTangentBundle{N}...) where {N}

0 commit comments

Comments
 (0)