Skip to content

Commit ca2b47b

Browse files
authored
Merge pull request #420 from mcabbott/cumprod
Add rule for `cumprod`
2 parents 7caf869 + 841b802 commit ca2b47b

File tree

3 files changed

+132
-1
lines changed

3 files changed

+132
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "1.10.0"
3+
version = "1.11.0"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/rulesets/Base/mapreduce.jl

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,3 +218,96 @@ function ∇prod_one_zero!(dx, x, dy::Number=1) # Assumes exactly one x is zero
218218
dx[i_zero] += p_rest * dy
219219
return
220220
end
221+
222+
#####
223+
##### `cumprod`
224+
#####
225+
226+
function rrule(::typeof(cumprod), x::AbstractVector{<:Real}; dims::Integer=1)
227+
y = cumprod(x; dims=dims) # does nothing unless dims == 1
228+
project_x = ProjectTo(x)
229+
function cumprod_pullback_1(dy_raw)
230+
dy = unthunk(dy_raw)
231+
dx_thunk = InplaceableThunk(
232+
dx -> if dims == 1
233+
∇cumprod!(dx, x, dy, y)
234+
else
235+
dx .+= dy
236+
end
237+
,
238+
@thunk project_x(if dims == 1
239+
∇cumprod(x, dy, y)
240+
else
241+
dy
242+
end)
243+
)
244+
return (NoTangent(), dx_thunk)
245+
end
246+
return y, cumprod_pullback_1
247+
end
248+
249+
function rrule(::typeof(cumprod), x::AbstractArray{<:Real}; dims::Integer)
250+
y = cumprod(x; dims=dims)
251+
project_x = ProjectTo(x)
252+
function cumprod_pullback_2(dy_raw)
253+
dy = unthunk(dy_raw)
254+
dx_thunk = InplaceableThunk(
255+
dx -> if dims <= ndims(x)
256+
vald = Val(Int(dims))
257+
∇cumprod_dim!(dx, vald, x, dy, y)
258+
else
259+
dx .+= dy
260+
end
261+
,
262+
@thunk project_x(if dims <= ndims(x)
263+
vald = Val(Int(dims))
264+
∇cumprod_dim(vald, x, dy, y)
265+
else
266+
dy
267+
end)
268+
)
269+
return (NoTangent(), dx_thunk)
270+
end
271+
return y, cumprod_pullback_2
272+
end
273+
274+
function ∇cumprod_dim(vald::Val{dim}, x::AbstractArray, dy=fill!(zero(x),1), y=cumprod(x; dims=dim)) where {dim}
275+
T = promote_type(eltype(x), eltype(dy))
276+
dx = fill!(similar(x, T, axes(x)), zero(T))
277+
∇cumprod_dim!(dx, vald, x, dy, y)
278+
return dx
279+
end
280+
281+
@inline function ∇cumprod_dim!(dx::AbstractArray, ::Val{dim}, x::AbstractArray, dy, y) where {dim}
282+
iters = ntuple(k -> k==dim ? Ref(:) : axes(x,k), ndims(x))
283+
for ind in Iterators.product(iters...)
284+
@views ∇cumprod!(dx[ind...], x[ind...], dy[ind...], y[ind...])
285+
end
286+
return dx
287+
end
288+
289+
function ∇cumprod(x::AbstractVector, dy=one(x), y=cumprod(x))
290+
T = promote_type(eltype(x), eltype(dy)) # really needs to allow dy * y / x
291+
dx = fill!(similar(x, T, axes(x)), zero(T)) # axes(x) makes MArray on StaticArrays, Array for structured matrices
292+
∇cumprod!(dx, x, dy, y)
293+
return dx
294+
end
295+
296+
@inline function ∇cumprod!(dx::AbstractVector, x::AbstractVector, dy, y)
297+
lo, hi = firstindex(x), lastindex(x)
298+
z = something(findfirst(iszero, x), hi+1)
299+
acc = zero(eltype(dy))
300+
@inbounds for k in z-1:-1:lo
301+
acc += y[k] * dy[k]
302+
dx[k] += acc / x[k]
303+
end
304+
@inbounds if z != hi+1
305+
yk = z==1 ? one(eltype(y)) : y[z-1] # will be prod(x[j] for j=1:k if j!=z)
306+
dx[z] += yk * dy[z]
307+
for k in (z+1):hi
308+
yk *= x[k]
309+
dx[z] += yk * dy[k]
310+
end
311+
end
312+
return dx
313+
end

test/rulesets/Base/mapreduce.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,3 +154,41 @@
154154
end
155155
end # prod
156156
end
157+
158+
@testset "Accumulations" begin
159+
@testset "cumprod" begin
160+
v = round.(10 .* randn(9), sigdigits=3)
161+
test_rrule(cumprod, v)
162+
v[3] = 0
163+
test_rrule(cumprod, v)
164+
v[6] = 0
165+
test_rrule(cumprod, v)
166+
167+
@testset "higher dimensions, dims=$dims" for dims in (1,2,3)
168+
m = round.(10 .* randn(4,5), sigdigits=3)
169+
test_rrule(cumprod, m; fkwargs=(;dims=dims), atol=0.1)
170+
m[2,2] = 0
171+
m[2,4] = 0
172+
test_rrule(cumprod, m; fkwargs=(;dims=dims))
173+
174+
t = round.(10 .* randn(3,3,3), sigdigits=3)
175+
test_rrule(cumprod, t; fkwargs=(;dims=dims))
176+
t[2,2,2] = 0
177+
t[2,3,3] = 0
178+
test_rrule(cumprod, t; fkwargs=(;dims=dims))
179+
end
180+
181+
@testset "types" begin
182+
back = rrule(cumprod, [1, 2, 3])[2] # rule allows integer input, but test_rrule does not
183+
@test unthunk(back(fill(0.5, 3))[2]) == [9/2, 2, 1]
184+
185+
back = rrule(cumprod, PermutedDimsArray([1 2; 3 4], (2,1)); dims=1)[2]
186+
@test unthunk(back(ones(Float32, 2,2))[2]) == [3 5; 1 3]
187+
188+
@test_throws Exception cumprod(Symmetric([1 2; 3 4]), dims=1) # forward pass fails, so can't test gradient
189+
190+
back = rrule(cumprod, Diagonal([1, 2]); dims=1)[2]
191+
@test unthunk(back(fill(0.5, 2, 2))[2]) [1/2 0; 0 0] # ProjectTo'd to Diagonal now
192+
end
193+
end
194+
end

0 commit comments

Comments
 (0)