Skip to content

Commit 895d442

Browse files
committed
...after which, the fast path isn't faster anymore, so delete it.
1 parent 2f34b50 commit 895d442

File tree

1 file changed

+4
-23
lines changed

1 file changed

+4
-23
lines changed

src/rulesets/Base/mapreduce.jl

Lines changed: 4 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -276,34 +276,15 @@ function ∇cumprod_dim(vald::Val{dim}, x::AbstractArray, dy=fill!(zero(x),1), y
276276
end
277277

278278
@inline function ∇cumprod_dim!(dx::AbstractArray, ::Val{dim}, x::AbstractArray, dy, y) where {dim}
279-
if any(iszero, x)
280-
iters = ntuple(k -> k==dim ? Ref(:) : axes(x,k), ndims(x))
281-
for ind in Iterators.product(iters...)
282-
@views ∇cumprod!(dx[ind...], x[ind...], dy[ind...], y[ind...])
283-
end
284-
else
285-
step1 = y .* dy # _rscale!!(y, dy) # is it safe to mutate y?
286-
step2 = _reverse!!(_cumsum!!(_reverse!!(step1, dim), dim), dim)
287-
dx .+= step2 ./ x
279+
iters = ntuple(k -> k==dim ? Ref(:) : axes(x,k), ndims(x))
280+
for ind in Iterators.product(iters...)
281+
@views ∇cumprod!(dx[ind...], x[ind...], dy[ind...], y[ind...])
288282
end
289283
return dx
290284
end
291285

292-
# _rscale!!(A, β) = A .* β
293-
# _rscale!!(A::StridedArray, β) = A .*= β
294-
295-
_reverse!!(x, dims=1) = reverse(x; dims=dims)
296-
if VERSION >= v"1.6"
297-
_reverse!!(x::StridedArray, dims=1) = reverse!(x; dims=dims)
298-
else
299-
_reverse!!(x::StridedVector, dims=1) = dims==1 ? reverse!(x) : x
300-
end
301-
302-
_cumsum!!(x, dims=1) = cumsum(x; dims=dims)
303-
_cumsum!!(x::StridedArray, dims=1) = cumsum!(x, x; dims=dims)
304-
305286
function ∇cumprod(x::AbstractVector, dy=one(x), y=cumprod(x))
306-
T = promote_type(eltype(x), eltype(dy))
287+
T = promote_type(eltype(x), eltype(dy)) # really needs to allow dy * y / x
307288
dx = fill!(similar(x, T, axes(x)), zero(T)) # axes(x) makes MArray on StaticArrays, Array for structured matrices
308289
∇cumprod!(dx, x, dy, y)
309290
return dx

0 commit comments

Comments
 (0)