Skip to content

Commit da566cd

Browse files
committed
borrow fast path from Zygote 294
1 parent fac481c commit da566cd

File tree

2 files changed

+22
-4
lines changed

2 files changed

+22
-4
lines changed

src/rulesets/Base/mapreduce.jl

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -276,13 +276,28 @@ function ∇cumprod_dim(vald::Val{dim}, x::AbstractArray, dy=fill!(zero(x),1), y
276276
end
277277

278278
@inline function ∇cumprod_dim!(dx::AbstractArray, ::Val{dim}, x::AbstractArray, dy, y) where {dim}
279-
iters = ntuple(k -> k==dim ? Ref(:) : axes(x,k), ndims(x))
280-
for ind in Iterators.product(iters...)
281-
@views ∇cumprod!(dx[ind...], x[ind...], dy[ind...], y[ind...])
279+
if any(iszero, x)
280+
iters = ntuple(k -> k==dim ? Ref(:) : axes(x,k), ndims(x))
281+
for ind in Iterators.product(iters...)
282+
@views ∇cumprod!(dx[ind...], x[ind...], dy[ind...], y[ind...])
283+
end
284+
else
285+
step1 = y .* dy # _rscale!!(y, dy) # is it safe to mutate y?
286+
step2 = _reverse!!(_cumsum!!(_reverse!!(step1, dim), dim), dim)
287+
dx .+= step2 ./ x
282288
end
283289
return dx
284290
end
285291

292+
# _rscale!!(A, β) = A .* β
293+
# _rscale!!(A::StridedArray, β) = A .*= β
294+
295+
_reverse!!(x, dims=1) = reverse(x; dims=dims)
296+
_reverse!!(x::StridedArray, dims=1) = reverse!(x; dims=dims)
297+
298+
_cumsum!!(x, dims=1) = cumsum(x; dims=dims)
299+
_cumsum!!(x::StridedArray, dims=1) = cumsum!(x, x; dims=dims)
300+
286301
function ∇cumprod(x::AbstractVector, dy=one(x), y=cumprod(x))
287302
T = promote_type(eltype(x), eltype(dy))
288303
dx = fill!(similar(x, T, axes(x)), zero(T)) # axes(x) makes MArray on StaticArrays, Array for structured matrices

test/rulesets/Base/mapreduce.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,13 +165,16 @@ end
165165

166166
@testset "higher dimensions, dims=$dims" for dims in (1,2,3)
167167
m = round.(10 .* randn(4,5), sigdigits=3)
168-
test_rrule(cumprod, m; fkwargs=(;dims=dims))
168+
test_rrule(cumprod, m; fkwargs=(;dims=dims), atol=0.1)
169169
m[2,2] = 0
170170
m[2,4] = 0
171171
test_rrule(cumprod, m; fkwargs=(;dims=dims))
172172

173173
t = round.(10 .* randn(3,3,3), sigdigits=3)
174174
test_rrule(cumprod, t; fkwargs=(;dims=dims))
175+
t[2,2,2] = 0
176+
t[2,3,3] = 0
177+
test_rrule(cumprod, t; fkwargs=(;dims=dims))
175178
end
176179

177180
@testset "types" begin

0 commit comments

Comments
 (0)