@@ -276,34 +276,15 @@ function ∇cumprod_dim(vald::Val{dim}, x::AbstractArray, dy=fill!(zero(x),1), y
276
276
end
277
277
278
278
@inline function ∇cumprod_dim! (dx:: AbstractArray , :: Val{dim} , x:: AbstractArray , dy, y) where {dim}
279
- if any (iszero, x)
280
- iters = ntuple (k -> k== dim ? Ref (:) : axes (x,k), ndims (x))
281
- for ind in Iterators. product (iters... )
282
- @views ∇cumprod! (dx[ind... ], x[ind... ], dy[ind... ], y[ind... ])
283
- end
284
- else
285
- step1 = y .* dy # _rscale!!(y, dy) # is it safe to mutate y?
286
- step2 = _reverse!! (_cumsum!! (_reverse!! (step1, dim), dim), dim)
287
- dx .+ = step2 ./ x
279
+ iters = ntuple (k -> k== dim ? Ref (:) : axes (x,k), ndims (x))
280
+ for ind in Iterators. product (iters... )
281
+ @views ∇cumprod! (dx[ind... ], x[ind... ], dy[ind... ], y[ind... ])
288
282
end
289
283
return dx
290
284
end
291
285
292
- # _rscale!!(A, β) = A .* β
293
- # _rscale!!(A::StridedArray, β) = A .*= β
294
-
295
- _reverse!! (x, dims= 1 ) = reverse (x; dims= dims)
296
- if VERSION >= v " 1.6"
297
- _reverse!! (x:: StridedArray , dims= 1 ) = reverse! (x; dims= dims)
298
- else
299
- _reverse!! (x:: StridedVector , dims= 1 ) = dims== 1 ? reverse! (x) : x
300
- end
301
-
302
- _cumsum!! (x, dims= 1 ) = cumsum (x; dims= dims)
303
- _cumsum!! (x:: StridedArray , dims= 1 ) = cumsum! (x, x; dims= dims)
304
-
305
286
function ∇cumprod (x:: AbstractVector , dy= one (x), y= cumprod (x))
306
- T = promote_type (eltype (x), eltype (dy))
287
+ T = promote_type (eltype (x), eltype (dy)) # really needs to allow dy * y / x
307
288
dx = fill! (similar (x, T, axes (x)), zero (T)) # axes(x) makes MArray on StaticArrays, Array for structured matrices
308
289
∇cumprod! (dx, x, dy, y)
309
290
return dx
0 commit comments