Skip to content

Commit 335d025

Browse files
mcabbottoxinabox
andauthored
Rules for cumsum (#573)
* cumsum, with variants, and mystery test failures * fixup * bang * digits Co-authored-by: Lyndon White <oxinabox@ucc.asn.au> Co-authored-by: Lyndon White <oxinabox@ucc.asn.au>
1 parent 8706f69 commit 335d025

File tree

3 files changed

+47
-1
lines changed

3 files changed

+47
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "1.20.1"
3+
version = "1.21"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/rulesets/Base/mapreduce.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,37 @@ for Config in (RuleConfig, RuleConfig{>:HasReverseMode})
153153
end
154154
end
155155

156+
#####
157+
##### `cumsum`
158+
#####
159+
160+
function frule((_, xdot), ::typeof(cumsum), x::AbstractArray; dims::Integer)
161+
return cumsum(x; dims=dims), cumsum(xdot; dims=dims)
162+
end
163+
frule(tang, ::typeof(cumsum), x::AbstractVector) = frule(tang, cumsum, x; dims=1)
164+
165+
function frule((_, ydot, xdot), ::typeof(cumsum!), y::AbstractArray, x::AbstractArray; dims::Integer)
166+
return cumsum!(y, x; dims=dims), cumsum!(ydot, xdot; dims=dims)
167+
end
168+
frule(t, ::typeof(cumsum!), y::AbstractVector, x::AbstractVector) = frule(t, cumsum!, y, x; dims=1)
169+
170+
function rrule(::typeof(cumsum), x::AbstractArray; dims::Integer)
171+
project = ProjectTo(x)
172+
function cumsum_pullback(dy)
173+
step1 = reverse(unthunk(dy); dims=dims)
174+
if ChainRulesCore.is_inplaceable_destination(step1) && VERSION >= v"1.6"
175+
step2 = cumsum!(step1, step1; dims=dims)
176+
step3 = reverse!(step2; dims=dims)
177+
else
178+
step2 = cumsum(step1; dims=dims)
179+
step3 = reverse(step2; dims=dims)
180+
end
181+
return (NoTangent(), project(step3))
182+
end
183+
return cumsum(x; dims=dims), cumsum_pullback
184+
end
185+
rrule(::typeof(cumsum), x::AbstractVector) = rrule(cumsum, x; dims=1)
186+
156187
#####
157188
##### `prod`
158189
#####

test/rulesets/Base/mapreduce.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,21 @@ const CFG = ChainRulesTestUtils.ADviaRuleConfig()
238238
end
239239

240240
@testset "Accumulations" begin
241+
@testset "cumsum" begin
242+
v = round.(10 .* randn(9), digits=3)
243+
m = round.(10 .* randn(4, 5), digits=3)
244+
245+
# Forward
246+
test_frule(cumsum, v)
247+
test_frule(cumsum, m; fkwargs=(;dims=1))
248+
test_frule(cumsum!, rand(9), v)
249+
test_frule(cumsum!, rand(4, 5), m; fkwargs=(;dims=1))
250+
251+
# Reverse
252+
test_rrule(cumsum, v)
253+
test_rrule(cumsum, v; fkwargs=(;dims=1))
254+
test_rrule(cumsum, m; fkwargs=(;dims=2))
255+
end
241256
@testset "cumprod" begin
242257
v = round.(10 .* randn(9), sigdigits=3)
243258
test_rrule(cumprod, v)

0 commit comments

Comments
 (0)