@@ -20,7 +20,7 @@ primal(z::ZeroTangent) = ZeroTangent()
20
20
first_partial (x) = partial (x, 1 )
21
21
22
22
shuffle_down (b:: UniformBundle{N, B, U} ) where {N, B, U} =
23
- UniformBundle {minus1(N) , <:Any} (UniformBundle {1, B} (b. primal, b. tangent. val),
23
+ UniformBundle {N-1 , <:Any} (UniformBundle {1, B} (b. primal, b. tangent. val),
24
24
UniformBundle {1, U} (b. tangent. val, b. tangent. val))
25
25
26
26
function shuffle_down (b:: ExplicitTangentBundle{N, B} ) where {N, B}
@@ -30,7 +30,7 @@ function shuffle_down(b::ExplicitTangentBundle{N, B}) where {N, B}
30
30
end
31
31
ExplicitTangentBundle {N-1} (
32
32
ExplicitTangentBundle {1} (b. primal, (partial (b, 1 ),)),
33
- ntuple (_sdown, 2 ^ (N- 1 )- 1 ))
33
+ ntuple (_sdown, 1 << (N- 1 )- 1 ))
34
34
end
35
35
36
36
function shuffle_down (b:: TaylorBundle{N, B} ) where {N, B}
@@ -86,7 +86,7 @@ function shuffle_up(r::CompositeBundle{N}) where {N}
86
86
else
87
87
return TangentBundle {N+1} (r. tup[1 ]. primal,
88
88
(r. tup[1 ]. tangent. partials... , primal (b),
89
- ntuple (i-> partial (b,i), 2 ^ (N+ 1 )- 1 )... ))
89
+ ntuple (i-> partial (b,i), 1 << (N+ 1 )- 1 )... ))
90
90
end
91
91
end
92
92
@@ -131,10 +131,10 @@ function ChainRulesCore.frule_via_ad(::DiffractorRuleConfig, partials, args...)
131
131
end
132
132
133
133
function (:: ∂☆shuffle{N})(args:: AbstractTangentBundle{N} ...) where {N}
134
- ∂☆p = ∂☆ {minus1(N) } ()
134
+ ∂☆p = ∂☆ {N-1 } ()
135
135
downargs = map (shuffle_down, args)
136
- tupargs = ∂vararg {minus1(N) } ()(map (first_partial, downargs)... )
137
- ∂☆p (ZeroBundle {minus1(N) } (frule), #= ZeroBundle{minus1(N) }(DiffractorRuleConfig()), =# tupargs, map (primal, downargs)... )
136
+ tupargs = ∂vararg {N-1 } ()(map (first_partial, downargs)... )
137
+ ∂☆p (ZeroBundle {N-1 } (frule), #= ZeroBundle{N-1 }(DiffractorRuleConfig()), =# tupargs, map (primal, downargs)... )
138
138
end
139
139
140
140
function (:: ∂☆internal{N})(args:: AbstractTangentBundle{N} ...) where {N}
0 commit comments