Skip to content

Commit 3355ae7

Browse files
mcabbott and oxinabox authored
Rules for eachslice (#561)
* rules for eachslice * test only 1.6+ * move methods, use projection, allow 2nd derivatives * fixup * avoid limitations of ComposedFunction on 1.6 * fixup rebase, and bump version * change to at-assert * Apply 5 suggestions Co-authored-by: Lyndon White <oxinabox@ucc.asn.au> Co-authored-by: Lyndon White <oxinabox@ucc.asn.au>
1 parent d0b8d92 commit 3355ae7

File tree

3 files changed

+96
-1
lines changed

3 files changed

+96
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "1.24"
3+
version = "1.25"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/rulesets/Base/indexing.jl

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,4 +48,68 @@ function frule((_, ẋ, v̇), ::typeof(setindex!), x::AbstractArray, v, inds...)
4848
end
4949

5050

51+
#####
52+
##### `eachslice` and friends
53+
#####
54+
55+
# Reverse rule for `eachrow`. The primal result is the collected vector of row slices;
# the pullback unthunks the incoming cotangent and hands it to `∇eachslice`, slicing
# along dimension 1.
function rrule(::typeof(eachrow), x::AbstractVecOrMat)
    rows = collect(eachrow(x))
    function allrows(dy)
        dx = ∇eachslice(unthunk(dy), x, Val(1))
        return (NoTangent(), dx)
    end
    return rows, allrows
end
59+
60+
# Reverse rule for `eachcol`. The primal result is the collected vector of column
# slices; the pullback unthunks the incoming cotangent and hands it to `∇eachslice`,
# slicing along dimension 2.
function rrule(::typeof(eachcol), x::AbstractVecOrMat)
    cols = collect(eachcol(x))
    function allcols(dy)
        dx = ∇eachslice(unthunk(dy), x, Val(2))
        return (NoTangent(), dx)
    end
    return cols, allcols
end
64+
65+
# Reverse rule for `eachslice(x; dims)`. The slices are collected for the primal
# result. Only a single slicing dimension is supported: `dims` must have length 1,
# and its sole entry is lifted into a `Val` when the pullback calls `∇eachslice`.
function rrule(::typeof(eachslice), x::AbstractArray; dims)
    slices = collect(eachslice(x; dims=dims))
    @assert length(dims) == 1 """That's amazing, after many years JuliaLang/julia#32310
        actually landed. Sadly, the gradient rule for `eachslice` is unable to handle this
        case right now, please make an issue at https://github.com/JuliaDiff/ChainRules.jl"""
    slicedim = only(dims)
    function allslices(dy)
        dx = ∇eachslice(unthunk(dy), x, Val(slicedim))
        return (NoTangent(), dx)
    end
    return slices, allslices
end
74+
75+
# Using Val(dim) here is worth a factor of 2 in this, on Julia 1.8-
# @btime rrule(eachcol, $([1 2; 3 4]))[2]($([[10, 20], [30, 40]]))
#
# Assemble the cotangent of `x` from the per-slice cotangents `dys_raw`, where slice
# `i` was taken along dimension `dim` of `x`. Entries that are `AbstractZero` become
# zero slices. If no entry is an `AbstractArray`, a zero-filled array with a floating
# point element type and the axes of `x` is returned directly; otherwise the assembled
# array is passed through `ProjectTo(x)`.
function ∇eachslice(dys_raw, x::AbstractArray, vd::Val{dim}) where {dim}
    slice_grads = unthunk(dys_raw)
    first_array = findfirst(g -> g isa AbstractArray, slice_grads)
    if first_array === nothing  # all slices are Zero!
        blank = similar(x, float(eltype(x)), axes(x))
        return _zero_fill!(blank)
    end
    # Element type wide enough for both the incoming slice cotangents and `x` itself.
    T = promote_type(eltype(slice_grads[first_array]), eltype(x))
    # The whole point of this gradient is that we can allocate one `dx` array:
    dx = similar(x, T, axes(x))
    for k in axes(x, dim)
        target = selectdim(dx, dim, k)
        g = slice_grads[k]
        if g isa AbstractZero
            # Avoids this: copyto!([1,2,3], ZeroTangent()) == [0,2,3]
            _zero_fill!(target)
        else
            copyto!(target, g)
        end
    end
    return ProjectTo(x)(dx)
end
96+
97+
# Overwrite every element of `dx` with a zero, in place, and return `dx`.
# Numeric arrays are filled with the zero of their element type.
function _zero_fill!(dx::AbstractArray{T}) where {T<:Number}
    return fill!(dx, zero(T))
end

# Arrays with non-numeric elements (e.g. arrays of arrays) get `zero` applied to each
# existing element, so every element keeps its own shape.
function _zero_fill!(dx::AbstractArray)
    return map!(zero, dx, dx)
end
99+
100+
# Second-derivative rule: reverse rule for `∇eachslice` itself. The pullback slices
# the incoming cotangent `dz` along the dimension encoded in `vd` and returns the
# collected slices as the cotangent for `dys`; the function, `x` and `vd` all get
# `NoTangent()`.
function rrule(::typeof(∇eachslice), dys, x, vd::Val)
    function ∇∇eachslice(dz_raw)
        dz = unthunk(dz_raw)
        # eachslice(dz; dims=_val(vd)) does not make @code_warntype happy
        slices = if vd == Val(1)
            eachrow(dz)
        elseif vd == Val(2)
            eachcol(dz)
        else
            eachslice(dz; dims=_val(vd))
        end
        return (NoTangent(), collect(slices), NoTangent(), NoTangent())
    end
    primal = ∇eachslice(dys, x, vd)
    return primal, ∇∇eachslice
end
51109

110+
# These rules help with testing, and won't hurt:
# They are correct as we always `collect` the primal result as we need that
# information for the reverse pass
#
# Each method dispatches on the composed function `collect∘each…` and forwards to the
# rule for the un-composed `each…` function.
ChainRules.rrule(::typeof(collect∘eachrow), x) = rrule(eachrow, x)
ChainRules.rrule(::typeof(collect∘eachcol), x) = rrule(eachcol, x)
ChainRules.rrule(::typeof(collect∘eachslice), x; dims) = rrule(eachslice, x; dims=dims)

test/rulesets/Base/indexing.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,3 +68,34 @@ end
6868
test_frule(setindex!, rand(3, 4), rand(), 1, 2)
6969
test_frule(setindex!, rand(3, 4), [1,10,100.0], :, 3)
7070
end
71+
72+
@testset "eachslice" begin
    # Covers the reverse rules for eachrow / eachcol / eachslice (tested through
    # `collect∘f`), pullbacks that mix AbstractZeros with arrays, and the
    # second-derivative rule for `∇eachslice`.

    # Testing eachrow not collect∘eachrow leads to errors, e.g.
    # test_rrule: eachrow on Vector{Float64}: Error During Test at /Users/me/.julia/packages/ChainRulesTestUtils/8dFTY/src/testers.jl:195
    #   Got exception outside of a @test
    #   DimensionMismatch("second dimension of A, 6, does not match length of x, 5")
    # Probably similar to https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/234 (about Broadcasted not Generator)

    test_rrule(collect∘eachrow, rand(5))
    test_rrule(collect∘eachrow, rand(3, 4))

    test_rrule(collect∘eachcol, rand(3, 4))
    @test_skip test_rrule(collect∘eachcol, Diagonal(rand(5)))  # works locally!

    if VERSION >= v"1.7"
        # On 1.6, ComposedFunction doesn't take keywords. Only affects this testing strategy, not real use.
        test_rrule(collect∘eachslice, rand(3, 4, 5); fkwargs = (; dims = 3))
        test_rrule(collect∘eachslice, rand(3, 4, 5); fkwargs = (; dims = (2,)))
    end

    # Make sure pulling back an array that mixes some AbstractZeros in works right
    _, back = rrule(eachcol, rand(3, 4))
    @test back([1:3, ZeroTangent(), 7:9, NoTangent()]) == (NoTangent(), [1 0 7 0; 2 0 8 0; 3 0 9 0])
    @test back([1:3, ZeroTangent(), 7:9, NoTangent()])[2] isa Matrix{Float64}
    @test back([ZeroTangent(), ZeroTangent(), NoTangent(), NoTangent()]) == (NoTangent(), [0 0 0 0; 0 0 0 0; 0 0 0 0])

    # Second derivative rule
    test_rrule(ChainRules.∇eachslice, [rand(4) for _ in 1:3], rand(3, 4), Val(1))
    test_rrule(ChainRules.∇eachslice, [rand(3) for _ in 1:4], rand(3, 4), Val(2))
    test_rrule(ChainRules.∇eachslice, [rand(2, 3) for _ in 1:4], rand(2, 3, 4), Val(3), check_inferred=false)
end

0 commit comments

Comments
 (0)