Skip to content

Commit 4b4c347

Browse files
committed
fix a type instability
1 parent 7dc0e85 commit 4b4c347

File tree

1 file changed

+20
-7
lines changed

1 file changed

+20
-7
lines changed

src/rulesets/Base/mapreduce.jl

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -247,17 +247,17 @@ end
247247
function rrule(::typeof(cumprod), x::AbstractArray{<:Real}; dims)
248248
y = cumprod(x; dims=dims)
249249
@assert dims isa Integer
250-
# vald = Val(dims)
250+
vald = Val(Int(dims)) # else ∇cumprod_dim! will be type unstable
251251
function cumprod_pullback_2(dy)
252252
dx_thunk = InplaceableThunk(
253253
@thunk if dims <= ndims(x)
254-
∇cumprod_dim(dims, x, dy, y)
254+
∇cumprod_dim(vald, x, dy, y)
255255
else
256256
dy
257257
end
258258
,
259259
dx -> if dims <= ndims(x)
260-
∇cumprod_dim!(dx, dims, x, dy, y)
260+
∇cumprod_dim!(dx, vald, x, dy, y)
261261
else
262262
dx .+= dy
263263
end
@@ -267,21 +267,34 @@ function rrule(::typeof(cumprod), x::AbstractArray{<:Real}; dims)
267267
return y, cumprod_pullback_2
268268
end
269269

270-
function ∇cumprod_dim(dim::Integer, x::AbstractArray, dy=fill!(zero(x),1), y=cumprod(x; dims=dim))
270+
function ∇cumprod_dim(vald::Val{dim}, x::AbstractArray, dy=fill!(zero(x),1), y=cumprod(x; dims=dim)) where {dim}
271271
T = promote_type(eltype(x), eltype(dy))
272272
dx = fill!(similar(x, T, axes(x)), zero(T))
273-
∇cumprod_dim!(dx, dim, x, dy, y)
273+
∇cumprod_dim!(dx, vald, x, dy, y)
274274
return dx
275275
end
276276

277-
function ∇cumprod_dim!(dx::AbstractArray, dim::Integer, x::AbstractArray, dy, y)
278-
iters = ntuple(k -> k==dim ? Ref(:) : axes(x,k), ndims(x)) # type instability!
277+
function ∇cumprod_dim!(dx::AbstractArray, ::Val{dim}, x::AbstractArray, dy, y) where {dim}
278+
iters = ntuple(k -> k==dim ? Ref(:) : axes(x,k), ndims(x))
279279
for ind in Iterators.product(iters...)
280280
@views ∇cumprod!(dx[ind...], x[ind...], dy[ind...], y[ind...])
281281
end
282282
return dx
283283
end
284284

285+
#=
286+
287+
julia> @btime gradient(x -> sum(cumprod(x, dims=1)), $(rand(10,100)))[1]
288+
86.333 μs (2007 allocations: 67.54 KiB)
289+
290+
julia> @btime gradient(x -> sum(cumprod(x, dims=1)), $(rand(10,100)))[1] # with 1 hard-coded
291+
5.417 μs (6 allocations: 15.95 KiB)
292+
293+
julia> @btime gradient(x -> sum(cumprod(x, dims=1)), $(rand(10,100)))[1] # with Val(dim)
294+
5.423 μs (6 allocations: 15.95 KiB)
295+
296+
=#
297+
285298
function ∇cumprod(x::AbstractVector, dy=one(x), y=cumprod(x))
286299
T = promote_type(eltype(x), eltype(dy))
287300
dx = fill!(similar(x, T, axes(x)), zero(T)) # axes(x) makes MArray on StaticArrays, Array for structured matrices

0 commit comments

Comments
 (0)