@@ -276,13 +276,28 @@ function ∇cumprod_dim(vald::Val{dim}, x::AbstractArray, dy=fill!(zero(x),1), y
276
276
end
277
277
278
278
@inline function ∇cumprod_dim! (dx:: AbstractArray , :: Val{dim} , x:: AbstractArray , dy, y) where {dim}
279
- iters = ntuple (k -> k== dim ? Ref (:) : axes (x,k), ndims (x))
280
- for ind in Iterators. product (iters... )
281
- @views ∇cumprod! (dx[ind... ], x[ind... ], dy[ind... ], y[ind... ])
279
+ if any (iszero, x)
280
+ iters = ntuple (k -> k== dim ? Ref (:) : axes (x,k), ndims (x))
281
+ for ind in Iterators. product (iters... )
282
+ @views ∇cumprod! (dx[ind... ], x[ind... ], dy[ind... ], y[ind... ])
283
+ end
284
+ else
285
+ step1 = y .* dy # _rscale!!(y, dy) # is it safe to mutate y?
286
+ step2 = _reverse!! (_cumsum!! (_reverse!! (step1, dim), dim), dim)
287
+ dx .+ = step2 ./ x
282
288
end
283
289
return dx
284
290
end
285
291
292
+ # _rscale!!(A, β) = A .* β
293
+ # _rscale!!(A::StridedArray, β) = A .*= β
294
+
295
+ _reverse!! (x, dims= 1 ) = reverse (x; dims= dims)
296
+ _reverse!! (x:: StridedArray , dims= 1 ) = reverse! (x; dims= dims)
297
+
298
+ _cumsum!! (x, dims= 1 ) = cumsum (x; dims= dims)
299
+ _cumsum!! (x:: StridedArray , dims= 1 ) = cumsum! (x, x; dims= dims)
300
+
286
301
function ∇cumprod (x:: AbstractVector , dy= one (x), y= cumprod (x))
287
302
T = promote_type (eltype (x), eltype (dy))
288
303
dx = fill! (similar (x, T, axes (x)), zero (T)) # axes(x) makes MArray on StaticArrays, Array for structured matrices
0 commit comments