@@ -186,25 +186,28 @@ end
186
186
# ####
187
187
188
188
function frule ((_, xdot), :: typeof (cumsum), x:: AbstractArray ; dims:: Integer )
189
- return cumsum (x; dims= dims ), cumsum (xdot; dims = dims)
189
+ return cumsum (x; dims), cumsum (xdot; dims)
190
190
end
191
191
frule (tang, :: typeof (cumsum), x:: AbstractVector ) = frule (tang, cumsum, x; dims= 1 )
192
192
193
193
function frule ((_, ydot, xdot), :: typeof (cumsum!), y:: AbstractArray , x:: AbstractArray ; dims:: Integer )
194
- return cumsum! (y, x; dims= dims ), cumsum! (ydot, xdot; dims = dims)
194
+ return cumsum! (y, x; dims), cumsum! (ydot, xdot; dims)
195
195
end
196
196
frule (t, :: typeof (cumsum!), y:: AbstractVector , x:: AbstractVector ) = frule (t, cumsum!, y, x; dims= 1 )
197
197
198
- function rrule (:: typeof (cumsum), x:: AbstractArray ; dims:: Integer )
198
+ function rrule (:: typeof (cumsum), x:: AbstractArray{T,N} ; dims:: Integer ) where {T,N}
199
199
project = ProjectTo (x)
200
200
function cumsum_pullback (dy)
201
+ if dims > N # trivial case, for which reverse fails
202
+ return (NoTangent (), project (unthunk (dy)))
203
+ end
201
204
step1 = reverse (unthunk (dy); dims= dims)
202
- if ChainRulesCore. is_inplaceable_destination (step1) && VERSION >= v " 1.6 "
203
- step2 = cumsum! (step1, step1; dims= dims )
204
- step3 = reverse! (step2; dims= dims )
205
+ if ChainRulesCore. is_inplaceable_destination (step1)
206
+ step2 = cumsum! (step1, step1; dims)
207
+ step3 = reverse! (step2; dims)
205
208
else
206
- step2 = cumsum (step1; dims= dims )
207
- step3 = reverse (step2; dims= dims )
209
+ step2 = cumsum (step1; dims)
210
+ step3 = reverse (step2; dims)
208
211
end
209
212
return (NoTangent (), project (step3))
210
213
end
0 commit comments