Skip to content

Commit 7dc0e85

Browse files
committed
cumprod, take 1
1 parent 649bfbb commit 7dc0e85

File tree

2 files changed

+125
-0
lines changed

2 files changed

+125
-0
lines changed

src/rulesets/Base/mapreduce.jl

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,3 +218,93 @@ function ∇prod_one_zero!(dx, x, dy::Number=1) # Assumes exactly one x is zero
218218
dx[i_zero] += p_rest * dy
219219
return
220220
end
221+
222+
#####
223+
##### `cumprod`
224+
#####
225+
226+
function rrule(::typeof(cumprod), x::AbstractVector{<:Real}; dims=1)
227+
y = cumprod(x; dims=dims) # does nothing unless dims == 1
228+
function cumprod_pullback_1(dy)
229+
dx_thunk = InplaceableThunk(
230+
@thunk if dims == 1
231+
∇cumprod(x, dy, y)
232+
else
233+
dy
234+
end
235+
,
236+
dx -> if dims == 1
237+
∇cumprod!(dx, x, dy, y)
238+
else
239+
dx .+= dy
240+
end
241+
)
242+
return (NO_FIELDS, dx_thunk)
243+
end
244+
return y, cumprod_pullback_1
245+
end
246+
247+
function rrule(::typeof(cumprod), x::AbstractArray{<:Real}; dims)
248+
y = cumprod(x; dims=dims)
249+
@assert dims isa Integer
250+
# vald = Val(dims)
251+
function cumprod_pullback_2(dy)
252+
dx_thunk = InplaceableThunk(
253+
@thunk if dims <= ndims(x)
254+
∇cumprod_dim(dims, x, dy, y)
255+
else
256+
dy
257+
end
258+
,
259+
dx -> if dims <= ndims(x)
260+
∇cumprod_dim!(dx, dims, x, dy, y)
261+
else
262+
dx .+= dy
263+
end
264+
)
265+
return (NO_FIELDS, dx_thunk)
266+
end
267+
return y, cumprod_pullback_2
268+
end
269+
270+
function ∇cumprod_dim(dim::Integer, x::AbstractArray, dy=fill!(zero(x),1), y=cumprod(x; dims=dim))
271+
T = promote_type(eltype(x), eltype(dy))
272+
dx = fill!(similar(x, T, axes(x)), zero(T))
273+
∇cumprod_dim!(dx, dim, x, dy, y)
274+
return dx
275+
end
276+
277+
function ∇cumprod_dim!(dx::AbstractArray, dim::Integer, x::AbstractArray, dy, y)
278+
iters = ntuple(k -> k==dim ? Ref(:) : axes(x,k), ndims(x)) # type instability!
279+
for ind in Iterators.product(iters...)
280+
@views ∇cumprod!(dx[ind...], x[ind...], dy[ind...], y[ind...])
281+
end
282+
return dx
283+
end
284+
285+
function ∇cumprod(x::AbstractVector, dy=one(x), y=cumprod(x))
286+
T = promote_type(eltype(x), eltype(dy))
287+
dx = fill!(similar(x, T, axes(x)), zero(T)) # axes(x) makes MArray on StaticArrays, Array for structured matrices
288+
∇cumprod!(dx, x, dy, y)
289+
return dx
290+
end
291+
292+
function ∇cumprod!(dx::AbstractVector, x::AbstractVector, dy, y)
293+
lo, hi = firstindex(x), lastindex(x)
294+
z = something(findfirst(iszero, x), hi+1)
295+
@inbounds for i in lo:z-1
296+
ixi = 1/x[i]
297+
for k in i:z-1
298+
dx[i] += y[k] * dy[k] * ixi
299+
end
300+
end
301+
@inbounds if z != hi+1
302+
yk = z==1 ? one(eltype(y)) : y[z-1] # will be prod(x[j] for j=1:k if j!=z)
303+
dx[z] += yk * dy[z]
304+
for k in (z+1):hi
305+
yk *= x[k]
306+
dx[z] += yk * dy[k]
307+
end
308+
end
309+
return dx
310+
end

test/rulesets/Base/mapreduce.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,3 +154,38 @@
154154
end
155155
end # prod
156156
end
157+
158+
159+
160+
@testset "cumprod" begin
161+
v = randn(9)
162+
test_rrule(cumprod, v)
163+
v[3] = 0
164+
test_rrule(cumprod, v)
165+
v[6] = 0
166+
test_rrule(cumprod, v)
167+
168+
@testset "higher dimensions, dims=$dims" for dims in (1,2,3)
169+
m = rand(4,5)
170+
test_rrule(cumprod, m; fkwargs=(;dims=dims))
171+
m[2,2] = 0
172+
m[2,4] = 0
173+
test_rrule(cumprod, m; fkwargs=(;dims=dims))
174+
175+
t = randn(3,3,3)
176+
test_rrule(cumprod, x; fkwargs=(;dims=dims))
177+
end
178+
179+
@testset "types" begin
180+
back = unthunk(rrule(cumprod, [1, 2, 3])[2])
181+
@test unthunk(back(fill(0.5, 3))[2]) == [9/2, 2, 1]
182+
183+
back = unthunk(rrule(cumprod, PermutedDimsArray([1 2; 3 4], (2,1)); dims=1)[2])
184+
@test unthunk(back(ones(Float32, 2,2))[2]) == [3 5; 1 3]
185+
186+
@test_throws Exception cumprod(Symmetric([1 2; 3 4]), dims=1)
187+
188+
back = unthunk(rrule(cumprod, Diagonal([1, 2]); dims=1)[2])
189+
@test unthunk(back(fill(0.5, 2, 2))[2]) [1/2 3/2; 1/2 0]
190+
end
191+
end

0 commit comments

Comments
 (0)