@@ -48,4 +48,68 @@ function frule((_, ẋ, v̇), ::typeof(setindex!), x::AbstractArray, v, inds...)
48
48
end
49
49
50
50
51
+ # ####
52
+ # #### `eachslice` and friends
53
+ # ####
54
+
55
# Reverse rule for `eachrow` on a vector or matrix.
# The primal output is the collected rows; the pullback hands the per-row cotangents
# to `∇eachslice` with `Val(1)` as the slicing dimension.
function rrule(::typeof(eachrow), x::AbstractVecOrMat)
    function allrows(dy)
        dx = ∇eachslice(unthunk(dy), x, Val(1))
        return (NoTangent(), dx)
    end
    rows = collect(eachrow(x))
    return rows, allrows
end
59
+
60
# Reverse rule for `eachcol` on a vector or matrix.
# The primal output is the collected columns; the pullback hands the per-column
# cotangents to `∇eachslice` with `Val(2)` as the slicing dimension.
function rrule(::typeof(eachcol), x::AbstractVecOrMat)
    function allcols(dy)
        dx = ∇eachslice(unthunk(dy), x, Val(2))
        return (NoTangent(), dx)
    end
    cols = collect(eachcol(x))
    return cols, allcols
end
64
+
65
# Reverse rule for `eachslice(x; dims)`.
# The primal output is the collected slices; the pullback hands the per-slice cotangents
# to `∇eachslice` together with `x` and `Val(dim)`.
# Only a single slicing dimension is supported: anything else throws an `AssertionError`.
function rrule(::typeof(eachslice), x::AbstractArray; dims)
    # Validate before doing any work, and with an explicit `throw` rather than `@assert`:
    # assertions may be compiled out, and there is no point collecting the slices of a
    # call that is going to be rejected anyway. The exception type is kept unchanged.
    if length(dims) != 1
        throw(AssertionError("""That's amazing, after many years JuliaLang/julia#32310
            actually landed. Sadly, the gradient rule for `eachslice` is unable to handle this
            case right now, please make an issue at https://github.com/JuliaDiff/ChainRules.jl"""))
    end
    dim = only(dims)
    y = collect(eachslice(x; dims=dims))
    allslices(dy) = (NoTangent(), ∇eachslice(unthunk(dy), x, Val(dim)))
    return y, allslices
end
74
+
75
# Gradient of `eachslice`-like functions: reassembles the per-slice cotangents `dys_raw`
# into a single array with the axes of `x`, then projects it with `ProjectTo(x)`.
#
# Using Val(dim) here is worth a factor of 2 in this, on Julia 1.8-
# @btime rrule(eachcol, $([1 2; 3 4]))[2]($([[10, 20], [30, 40]]))
function ∇eachslice(dys_raw, x::AbstractArray, vd::Val{dim}) where {dim}
    dys = unthunk(dys_raw)
    first_array = findfirst(dy -> dy isa AbstractArray, dys)
    # No array-valued slice at all means every slice is a zero tangent:
    # the answer is simply an all-zero array shaped like `x`.
    first_array === nothing &&
        return _zero_fill!(similar(x, float(eltype(x)), axes(x)))
    T = promote_type(eltype(dys[first_array]), eltype(x))
    # The whole point of this gradient is that we can allocate one `dx` array:
    dx = similar(x, T, axes(x))
    for i in axes(x, dim)
        dy = dys[i]
        slice = selectdim(dx, dim, i)
        if dy isa AbstractZero
            # Avoids this: copyto!([1,2,3], ZeroTangent()) == [0,2,3]
            _zero_fill!(slice)
        else
            copyto!(slice, dy)
        end
    end
    return ProjectTo(x)(dx)
end
96
+
97
# Overwrite every element of `dx` with zero, in place, and return `dx`.
# Numeric arrays are filled with the zero of their element type.
function _zero_fill!(dx::AbstractArray{<:Number})
    z = zero(eltype(dx))
    return fill!(dx, z)
end

# Any other array has each element replaced by `zero` of that element.
function _zero_fill!(dx::AbstractArray)
    return map!(zero, dx, dx)
end
99
+
100
# Second-order rule: reverse rule for `∇eachslice` itself.
# The primal output is `∇eachslice(dys, x, vd)`. The pullback slices the incoming cotangent
# `dz` along the dimension carried by `vd` and returns those slices as the cotangent of
# `dys`; the function, `x` and `vd` get `NoTangent()`.
function rrule(::typeof(∇eachslice), dys, x, vd::Val)
    function ∇∇eachslice(dz_raw)
        dz = unthunk(dz_raw)
        # eachslice(dz; dims=_val(vd)) does not make @code_warntype happy, so the two
        # common dimensions use `eachrow` / `eachcol`. Those only accept vectors and
        # matrices, hence the `isa` guard: a higher-dimensional `dz` has to go through
        # `eachslice` instead of hitting a MethodError.
        # NOTE(review): `_val` is defined elsewhere; presumably it unwraps `Val{d}` to `d`.
        iter = if dz isa AbstractVecOrMat && vd == Val(1)
            eachrow(dz)
        elseif dz isa AbstractVecOrMat && vd == Val(2)
            eachcol(dz)
        else
            eachslice(dz; dims=_val(vd))
        end
        return (NoTangent(), collect(iter), NoTangent(), NoTangent())
    end
    return ∇eachslice(dys, x, vd), ∇∇eachslice
end
51
109
110
# These rules help with testing, and won't hurt.
# They are correct because the rules for `eachrow`, `eachcol` and `eachslice` always
# `collect` the primal result (that information is needed for the reverse pass), so
# composing with `collect` gives the same rule.
function ChainRules.rrule(::typeof(collect ∘ eachrow), x)
    return rrule(eachrow, x)
end

function ChainRules.rrule(::typeof(collect ∘ eachcol), x)
    return rrule(eachcol, x)
end

function ChainRules.rrule(::typeof(collect ∘ eachslice), x; dims)
    return rrule(eachslice, x; dims=dims)
end
0 commit comments