@@ -247,17 +247,17 @@ end
247
247
function rrule (:: typeof (cumprod), x:: AbstractArray{<:Real} ; dims)
248
248
y = cumprod (x; dims= dims)
249
249
@assert dims isa Integer
250
- # vald = Val(dims)
250
+ vald = Val (Int ( dims)) # else ∇cumprod_dim! will be type unstable
251
251
function cumprod_pullback_2 (dy)
252
252
dx_thunk = InplaceableThunk (
253
253
@thunk if dims <= ndims (x)
254
- ∇cumprod_dim (dims , x, dy, y)
254
+ ∇cumprod_dim (vald , x, dy, y)
255
255
else
256
256
dy
257
257
end
258
258
,
259
259
dx -> if dims <= ndims (x)
260
- ∇cumprod_dim! (dx, dims , x, dy, y)
260
+ ∇cumprod_dim! (dx, vald , x, dy, y)
261
261
else
262
262
dx .+ = dy
263
263
end
@@ -267,21 +267,34 @@ function rrule(::typeof(cumprod), x::AbstractArray{<:Real}; dims)
267
267
return y, cumprod_pullback_2
268
268
end
269
269
270
- function ∇cumprod_dim (dim :: Integer , x:: AbstractArray , dy= fill! (zero (x),1 ), y= cumprod (x; dims= dim))
270
+ function ∇cumprod_dim (vald :: Val{dim} , x:: AbstractArray , dy= fill! (zero (x),1 ), y= cumprod (x; dims= dim)) where {dim}
271
271
T = promote_type (eltype (x), eltype (dy))
272
272
dx = fill! (similar (x, T, axes (x)), zero (T))
273
- ∇cumprod_dim! (dx, dim , x, dy, y)
273
+ ∇cumprod_dim! (dx, vald , x, dy, y)
274
274
return dx
275
275
end
276
276
277
- function ∇cumprod_dim! (dx:: AbstractArray , dim :: Integer , x:: AbstractArray , dy, y)
278
- iters = ntuple (k -> k== dim ? Ref (:) : axes (x,k), ndims (x)) # type instability!
277
+ function ∇cumprod_dim! (dx:: AbstractArray , :: Val{dim} , x:: AbstractArray , dy, y) where {dim}
278
+ iters = ntuple (k -> k== dim ? Ref (:) : axes (x,k), ndims (x))
279
279
for ind in Iterators. product (iters... )
280
280
@views ∇cumprod! (dx[ind... ], x[ind... ], dy[ind... ], y[ind... ])
281
281
end
282
282
return dx
283
283
end
284
284
285
+ #=
286
+
287
+ julia> @btime gradient(x -> sum(cumprod(x, dims=1)), $(rand(10,100)))[1]
288
+ 86.333 μs (2007 allocations: 67.54 KiB)
289
+
290
+ julia> @btime gradient(x -> sum(cumprod(x, dims=1)), $(rand(10,100)))[1] # with 1 hard-coded
291
+ 5.417 μs (6 allocations: 15.95 KiB)
292
+
293
+ julia> @btime gradient(x -> sum(cumprod(x, dims=1)), $(rand(10,100)))[1] # with Val(dim)
294
+ 5.423 μs (6 allocations: 15.95 KiB)
295
+
296
+ =#
297
+
285
298
function ∇cumprod (x:: AbstractVector , dy= one (x), y= cumprod (x))
286
299
T = promote_type (eltype (x), eltype (dy))
287
300
dx = fill! (similar (x, T, axes (x)), zero (T)) # axes(x) makes MArray on StaticArrays, Array for structured matrices
0 commit comments