Skip to content

Commit cc4952c

Browse files
author
Miha Zgubic
committed
Merge branch 'main' into mz/nograd
2 parents 31f5e7b + e9d7d0a commit cc4952c

File tree

3 files changed

+13
-8
lines changed

3 files changed

+13
-8
lines changed

src/rulesets/Base/indexing.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ function ∇eachslice(dys_raw, x::AbstractArray, vd::Val{dim}) where {dim}
9393
end
9494
return ProjectTo(x)(dx)
9595
end
96+
∇eachslice(dys::AbstractZero, x::AbstractArray, vd::Val{dim}) where {dim} = dys
9697

9798
_zero_fill!(dx::AbstractArray{<:Number}) = fill!(dx, zero(eltype(dx)))
9899
_zero_fill!(dx::AbstractArray) = map!(zero, dx, dx)

src/rulesets/Base/mapreduce.jl

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -186,25 +186,28 @@ end
186186
#####
187187

188188
function frule((_, xdot), ::typeof(cumsum), x::AbstractArray; dims::Integer)
189-
return cumsum(x; dims=dims), cumsum(xdot; dims=dims)
189+
return cumsum(x; dims), cumsum(xdot; dims)
190190
end
191191
frule(tang, ::typeof(cumsum), x::AbstractVector) = frule(tang, cumsum, x; dims=1)
192192

193193
function frule((_, ydot, xdot), ::typeof(cumsum!), y::AbstractArray, x::AbstractArray; dims::Integer)
194-
return cumsum!(y, x; dims=dims), cumsum!(ydot, xdot; dims=dims)
194+
return cumsum!(y, x; dims), cumsum!(ydot, xdot; dims)
195195
end
196196
frule(t, ::typeof(cumsum!), y::AbstractVector, x::AbstractVector) = frule(t, cumsum!, y, x; dims=1)
197197

198-
function rrule(::typeof(cumsum), x::AbstractArray; dims::Integer)
198+
function rrule(::typeof(cumsum), x::AbstractArray{T,N}; dims::Integer) where {T,N}
199199
project = ProjectTo(x)
200200
function cumsum_pullback(dy)
201+
if dims > N # trivial case, for which reverse fails
202+
return (NoTangent(), project(unthunk(dy)))
203+
end
201204
step1 = reverse(unthunk(dy); dims=dims)
202-
if ChainRulesCore.is_inplaceable_destination(step1) && VERSION >= v"1.6"
203-
step2 = cumsum!(step1, step1; dims=dims)
204-
step3 = reverse!(step2; dims=dims)
205+
if ChainRulesCore.is_inplaceable_destination(step1)
206+
step2 = cumsum!(step1, step1; dims)
207+
step3 = reverse!(step2; dims)
205208
else
206-
step2 = cumsum(step1; dims=dims)
207-
step3 = reverse(step2; dims=dims)
209+
step2 = cumsum(step1; dims)
210+
step3 = reverse(step2; dims)
208211
end
209212
return (NoTangent(), project(step3))
210213
end

test/rulesets/Base/mapreduce.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,7 @@ end
277277
test_rrule(cumsum, v)
278278
test_rrule(cumsum, v; fkwargs=(;dims=1))
279279
test_rrule(cumsum, m; fkwargs=(;dims=2))
280+
test_rrule(cumsum, m; fkwargs=(;dims=3)) # trivial
280281
end
281282
@testset "cumprod" begin
282283
v = round.(10 .* randn(9), sigdigits=3)

0 commit comments

Comments
 (0)