Skip to content

Commit 8c34f19

Browse files
mcabbottoxinabox
andauthored
Add many frules (#565)
* drop 1.0, now that LTS == 1.6 * revert to one Project * rm Compat * turns out this does still need Compat * add many frules * in-place frules * reshape + dropdims too * tests * 5-arg mul * notation changes * rm 2nd order rules * don't skip setindex * AbstractArray constructors * reshape tests * Apply 4 suggestions Co-authored-by: Lyndon White <oxinabox@ucc.asn.au> * fixup, bump * several comments, and one rule for PermutedDimsArray * in fact sortslices is fine with offsets Co-authored-by: Lyndon White <oxinabox@ucc.asn.au>
1 parent 13a362c commit 8c34f19

File tree

13 files changed

+418
-37
lines changed

13 files changed

+418
-37
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "1.23"
3+
version = "1.24"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/rulesets/Base/array.jl

Lines changed: 140 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,37 @@
44

55
ChainRules.@non_differentiable (::Type{T} where {T<:Array})(::UndefInitializer, args...)
66

7+
function frule((_, ẋ), ::Type{T}, x::AbstractArray) where {T<:Array}
8+
return T(x), T(ẋ)
9+
end
10+
11+
function frule((_, ẋ), ::Type{AbstractArray{T}}, x::AbstractArray) where {T}
12+
return AbstractArray{T}(x), AbstractArray{T}(ẋ)
13+
end
14+
715
function rrule(::Type{T}, x::AbstractArray) where {T<:Array}
816
project_x = ProjectTo(x)
917
Array_pullback(ȳ) = (NoTangent(), project_x(ȳ))
1018
return T(x), Array_pullback
1119
end
1220

21+
# This abstract one is used for `float(x)` and other float conversion purposes:
22+
function rrule(::Type{AbstractArray{T}}, x::AbstractArray) where {T}
23+
project_x = ProjectTo(x)
24+
AbstractArray_pullback(ȳ) = (NoTangent(), project_x(ȳ))
25+
return AbstractArray{T}(x), AbstractArray_pullback
26+
end
27+
1328
#####
1429
##### `vect`
1530
#####
1631

1732
@non_differentiable Base.vect()
1833

34+
function frule((_, ẋs...), ::typeof(Base.vect), xs::Number...)
35+
return Base.vect(xs...), Base.vect(_instantiate_zeros(ẋs, xs)...)
36+
end
37+
1938
# Case of uniform type `T`: the data passes straight through,
2039
# so no projection should be required.
2140
function rrule(::typeof(Base.vect), X::Vararg{T, N}) where {T, N}
@@ -43,32 +62,84 @@ function rrule(::typeof(Base.vect), X::Vararg{Any,N}) where {N}
4362
return Base.vect(X...), vect_pullback
4463
end
4564

65+
"""
66+
_instantiate_zeros(ẋs, xs)
67+
68+
Forward rules for `vect`, `cat` etc may receive a mixture of data and `ZeroTangent`s.
69+
To avoid `vect(1, ZeroTangent(), 3)` or worse `vcat([1,2], ZeroTangent(), [6,7])`, this
70+
materialises each zero `ẋ` to be `zero(x)`.
71+
"""
72+
_instantiate_zeros(ẋs, xs) = map(_i_zero, ẋs, xs)
73+
# Fallback: a tangent that is already concrete data is passed through unchanged.
_i_zero(ẋ, x) = ẋ
74+
_i_zero(ẋ::AbstractZero, x) = zero(x)
75+
# Possibly this won't work for partly non-diff arrays, something like `gradient(x -> ["abc", x][end], 1)`
76+
# may give a MethodError for `zero` but won't be wrong.
77+
78+
# Fast paths. Should it also collapse all-Zero cases?
79+
# Fast path: every tangent is already a number, so there is nothing to materialise.
# `Tuple{Vararg{Number}}` matches the same tuples as `Tuple{Vararg{<:Number}}` (tuple
# parameters are covariant) without wrapping `Vararg` directly in a `UnionAll`,
# which newer Julia versions deprecate.
_instantiate_zeros(ẋs::Tuple{Vararg{Number}}, xs) = ẋs
80+
# Fast path: every tangent is already an array, so the tuple is returned as-is.
# `Tuple{Vararg{AbstractArray}}` matches the same tuples as the `Vararg{<:AbstractArray}`
# spelling (tuple parameters are covariant) while avoiding the deprecated form that
# wraps `Vararg` directly in a `UnionAll`.
_instantiate_zeros(ẋs::Tuple{Vararg{AbstractArray}}, xs) = ẋs
81+
_instantiate_zeros(ẋs::AbstractArray{<:Number}, xs) = ẋs
82+
_instantiate_zeros(ẋs::AbstractArray{<:AbstractArray}, xs) = ẋs
83+
84+
#####
85+
##### `copyto!`
86+
#####
87+
88+
function frule((_, ẏ, ẋ), ::typeof(copyto!), y::AbstractArray, x)
89+
return copyto!(y, x), copyto!(ẏ, ẋ)
90+
end
91+
92+
function frule((_, ẏ, _, ẋ), ::typeof(copyto!), y::AbstractArray, i::Integer, x, js::Integer...)
93+
return copyto!(y, i, x, js...), copyto!(ẏ, i, ẋ, js...)
94+
end
95+
4696
#####
4797
##### `reshape`
4898
#####
4999

50-
function rrule(::typeof(reshape), A::AbstractArray, dims::Tuple{Vararg{Union{Colon,Int}}})
51-
A_dims = size(A)
52-
function reshape_pullback(Ȳ)
53-
return (NoTangent(), reshape(Ȳ, A_dims), NoTangent())
54-
end
55-
return reshape(A, dims), reshape_pullback
100+
function frule((_, ẋ), ::typeof(reshape), x::AbstractArray, dims...)
101+
return reshape(x, dims...), reshape(ẋ, dims...)
56102
end
57103

58-
function rrule(::typeof(reshape), A::AbstractArray, dims::Union{Colon,Int}...)
59-
A_dims = size(A)
60-
function reshape_pullback(Ȳ)
61-
∂A = reshape(Ȳ, A_dims)
62-
∂dims = broadcast(Returns(NoTangent()), dims)
63-
return (NoTangent(), ∂A, ∂dims...)
64-
end
104+
function rrule(::typeof(reshape), A::AbstractArray, dims...)
105+
ax = axes(A)
106+
project = ProjectTo(A) # Projection is here for e.g. reshape(::Diagonal, :)
107+
∂dims = broadcast(Returns(NoTangent()), dims)
108+
reshape_pullback(Ȳ) = (NoTangent(), project(reshape(Ȳ, ax)), ∂dims...)
65109
return reshape(A, dims...), reshape_pullback
66110
end
67111

112+
#####
113+
##### `dropdims`
114+
#####
115+
116+
function frule((_, ẋ), ::typeof(dropdims), x::AbstractArray; dims)
117+
return dropdims(x; dims), dropdims(ẋ; dims)
118+
end
119+
120+
function rrule(::typeof(dropdims), A::AbstractArray; dims)
121+
ax = axes(A)
122+
project = ProjectTo(A)
123+
dropdims_pullback(Ȳ) = (NoTangent(), project(reshape(Ȳ, ax)))
124+
return dropdims(A; dims), dropdims_pullback
125+
end
126+
68127
#####
69128
##### `permutedims`
70129
#####
71130

131+
function frule((_, ẋ), ::typeof(permutedims), x::AbstractArray, perm...)
132+
return permutedims(x, perm...), permutedims(ẋ, perm...)
133+
end
134+
135+
function frule((_, ẏ, ẋ), ::typeof(permutedims!), y::AbstractArray, x::AbstractArray, perm...)
136+
return permutedims!(y, x, perm...), permutedims!(ẏ, ẋ, perm...)
137+
end
138+
139+
function frule((_, ẋ), ::Type{<:PermutedDimsArray}, x::AbstractArray, perm)
140+
return PermutedDimsArray(x, perm), PermutedDimsArray(ẋ, perm)
141+
end
142+
72143
function rrule(::typeof(permutedims), x::AbstractVector)
73144
project = ProjectTo(x)
74145
permutedims_pullback_1(dy) = (NoTangent(), project(permutedims(unthunk(dy))))
@@ -91,6 +162,10 @@ end
91162
##### `repeat`
92163
#####
93164

165+
function frule((_, ẋs), ::typeof(repeat), xs::AbstractArray, cnt...; kw...)
166+
return repeat(xs, cnt...; kw...), repeat(ẋs, cnt...; kw...)
167+
end
168+
94169
function rrule(::typeof(repeat), xs::AbstractArray; inner=ntuple(Returns(1), ndims(xs)), outer=ntuple(Returns(1), ndims(xs)))
95170

96171
project_Xs = ProjectTo(xs)
@@ -130,6 +205,10 @@ end
130205
##### `hcat`
131206
#####
132207

208+
function frule((_, ẋs...), ::typeof(hcat), xs...)
209+
return hcat(xs...), hcat(_instantiate_zeros(ẋs, xs)...)
210+
end
211+
133212
function rrule(::typeof(hcat), Xs::Union{AbstractArray, Number}...)
134213
Y = hcat(Xs...) # note that Y always has 1-based indexing, even if X isa OffsetArray
135214
ndimsY = Val(ndims(Y)) # this avoids closing over Y, Val() is essential for type-stability
@@ -164,6 +243,10 @@ function rrule(::typeof(hcat), Xs::Union{AbstractArray, Number}...)
164243
return Y, hcat_pullback
165244
end
166245

246+
function frule((_, _, Ȧs), ::typeof(reduce), ::typeof(hcat), As::AbstractVector{<:AbstractVecOrMat})
247+
return reduce(hcat, As), reduce(hcat, _instantiate_zeros(Ȧs, As))
248+
end
249+
167250
function rrule(::typeof(reduce), ::typeof(hcat), As::AbstractVector{<:AbstractVecOrMat})
168251
widths = map(A -> size(A,2), As)
169252
function reduce_hcat_pullback_2(dY)
@@ -192,6 +275,10 @@ end
192275
##### `vcat`
193276
#####
194277

278+
function frule((_, ẋs...), ::typeof(vcat), xs...)
279+
return vcat(xs...), vcat(_instantiate_zeros(ẋs, xs)...)
280+
end
281+
195282
function rrule(::typeof(vcat), Xs::Union{AbstractArray, Number}...)
196283
Y = vcat(Xs...)
197284
ndimsY = Val(ndims(Y))
@@ -224,6 +311,10 @@ function rrule(::typeof(vcat), Xs::Union{AbstractArray, Number}...)
224311
return Y, vcat_pullback
225312
end
226313

314+
function frule((_, _, Ȧs), ::typeof(reduce), ::typeof(vcat), As::AbstractVector{<:AbstractVecOrMat})
315+
return reduce(vcat, As), reduce(vcat, _instantiate_zeros(Ȧs, As))
316+
end
317+
227318
function rrule(::typeof(reduce), ::typeof(vcat), As::AbstractVector{<:AbstractVecOrMat})
228319
Y = reduce(vcat, As)
229320
ndimsY = Val(ndims(Y))
@@ -247,6 +338,10 @@ end
247338

248339
_val(::Val{x}) where {x} = x
249340

341+
function frule((_, ẋs...), ::typeof(cat), xs...; dims)
342+
return cat(xs...; dims), cat(_instantiate_zeros(ẋs, xs)...; dims)
343+
end
344+
250345
function rrule(::typeof(cat), Xs::Union{AbstractArray, Number}...; dims)
251346
Y = cat(Xs...; dims=dims)
252347
cdims = dims isa Val ? Int(_val(dims)) : dims isa Integer ? Int(dims) : Tuple(dims)
@@ -285,6 +380,10 @@ end
285380
##### `hvcat`
286381
#####
287382

383+
function frule((_, _, ẋs...), ::typeof(hvcat), rows, xs...)
384+
return hvcat(rows, xs...), hvcat(rows, _instantiate_zeros(ẋs, xs)...)
385+
end
386+
288387
function rrule(::typeof(hvcat), rows, values::Union{AbstractArray, Number}...)
289388
Y = hvcat(rows, values...)
290389
cols = size(Y,2)
@@ -321,8 +420,12 @@ end
321420
# 1-dim case allows start/stop, N-dim case takes dims keyword
322421
# whose defaults changed in Julia 1.6... just pass them all through:
323422

324-
function frule((_, xdot), ::typeof(reverse), x::Union{AbstractArray, Tuple}, args...; kw...)
325-
return reverse(x, args...; kw...), reverse(xdot, args...; kw...)
423+
function frule((_, ẋ), ::typeof(reverse), x::Union{AbstractArray, Tuple}, args...; kw...)
424+
return reverse(x, args...; kw...), reverse(ẋ, args...; kw...)
425+
end
426+
427+
function frule((_, ẋ), ::typeof(reverse!), x::Union{AbstractArray, Tuple}, args...; kw...)
428+
return reverse!(x, args...; kw...), reverse!(ẋ, args...; kw...)
326429
end
327430

328431
function rrule(::typeof(reverse), x::Union{AbstractArray, Tuple}, args...; kw...)
@@ -338,8 +441,12 @@ end
338441
##### `circshift`
339442
#####
340443

341-
function frule((_, xdot), ::typeof(circshift), x::AbstractArray, shifts)
342-
return circshift(x, shifts), circshift(xdot, shifts)
444+
function frule((_, ẋ), ::typeof(circshift), x::AbstractArray, shifts)
445+
return circshift(x, shifts), circshift(ẋ, shifts)
446+
end
447+
448+
function frule((_, ẏ, ẋ), ::typeof(circshift!), y::AbstractArray, x::AbstractArray, shifts)
449+
return circshift!(y, x, shifts), circshift!(ẏ, ẋ, shifts)
343450
end
344451

345452
function rrule(::typeof(circshift), x::AbstractArray, shifts)
@@ -355,8 +462,12 @@ end
355462
##### `fill`
356463
#####
357464

358-
function frule((_, xdot), ::typeof(fill), x::Any, dims...)
359-
return fill(x, dims...), fill(xdot, dims...)
465+
function frule((_, ẋ), ::typeof(fill), x::Any, dims...)
466+
return fill(x, dims...), fill(ẋ, dims...)
467+
end
468+
469+
function frule((_, ẏ, ẋ), ::typeof(fill!), y::AbstractArray, x::Any)
470+
return fill!(y, x), fill!(ẏ, ẋ)
360471
end
361472

362473
function rrule(::typeof(fill), x::Any, dims...)
@@ -370,9 +481,9 @@ end
370481
##### `filter`
371482
#####
372483

373-
function frule((_, _, xdot), ::typeof(filter), f, x::AbstractArray)
484+
function frule((_, _, ẋ), ::typeof(filter), f, x::AbstractArray)
374485
inds = findall(f, x)
375-
return x[inds], xdot[inds]
486+
return x[inds], ẋ[inds]
376487
end
377488

378489
function rrule(::typeof(filter), f, x::AbstractArray)
@@ -392,9 +503,9 @@ end
392503
for findm in (:findmin, :findmax)
393504
findm_pullback = Symbol(findm, :_pullback)
394505

395-
@eval function frule((_, xdot), ::typeof($findm), x; dims=:)
506+
@eval function frule((_, ẋ), ::typeof($findm), x; dims=:)
396507
y, ind = $findm(x; dims=dims)
397-
return (y, ind), Tangent{typeof((y, ind))}(xdot[ind], NoTangent())
508+
return (y, ind), Tangent{typeof((y, ind))}(ẋ[ind], NoTangent())
398509
end
399510

400511
@eval function rrule(::typeof($findm), x::AbstractArray; dims=:)
@@ -441,8 +552,8 @@ end
441552
# Allow for second derivatives, by writing rules for `_zerolike_writeat`;
442553
# these rules are the reason it takes a `dims` argument.
443554

444-
function frule((_, _, dydot), ::typeof(_zerolike_writeat), x, dy, dims, inds...)
445-
return _zerolike_writeat(x, dy, dims, inds...), _zerolike_writeat(x, dydot, dims, inds...)
555+
function frule((_, _, dẏ), ::typeof(_zerolike_writeat), x, dy, dims, inds...)
556+
return _zerolike_writeat(x, dy, dims, inds...), _zerolike_writeat(x, dẏ, dims, inds...)
446557
end
447558

448559
function rrule(::typeof(_zerolike_writeat), x, dy, dims, inds...)
@@ -457,9 +568,9 @@ end
457568

458569
# These rules for `maximum` pick the same subgradient as `findmax`:
459570

460-
function frule((_, xdot), ::typeof(maximum), x; dims=:)
571+
function frule((_, ẋ), ::typeof(maximum), x; dims=:)
461572
y, ind = findmax(x; dims=dims)
462-
return y, xdot[ind]
573+
return y, ẋ[ind]
463574
end
464575

465576
function rrule(::typeof(maximum), x::AbstractArray; dims=:)
@@ -468,9 +579,9 @@ function rrule(::typeof(maximum), x::AbstractArray; dims=:)
468579
return y, maximum_pullback
469580
end
470581

471-
function frule((_, xdot), ::typeof(minimum), x; dims=:)
582+
function frule((_, ẋ), ::typeof(minimum), x; dims=:)
472583
y, ind = findmin(x; dims=dims)
473-
return y, xdot[ind]
584+
return y, ẋ[ind]
474585
end
475586

476587
function rrule(::typeof(minimum), x::AbstractArray; dims=:)

src/rulesets/Base/arraymath.jl

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ end
1919
##### `*`
2020
#####
2121

22+
# Forward rule for two-argument `*`: returns the product together with the
# product-rule tangent `ΔA * B + A * ΔB`, accumulated through `muladd`.
function frule((_, ΔA, ΔB), ::typeof(*), A, B)
    Ω = A * B
    Ω̇ = muladd(ΔA, B, A * ΔB)
    return Ω, Ω̇
end
23+
24+
# Forward rule for three-argument `*`: the primal is the triple product and the
# tangent is the sum of the three terms obtained by perturbing one factor at a time.
function frule((_, ΔA, ΔB, ΔC), ::typeof(*), A, B, C)
    Ω = A*B*C
    Ω̇ = ΔA*B*C + A*ΔB*C + A*B*ΔC
    return Ω, Ω̇
end
25+
2226

2327
function rrule(
2428
::typeof(*),
@@ -88,7 +92,9 @@ function rrule(
8892
end
8993

9094

91-
95+
#####
96+
##### `*` matrix-scalar_rule
97+
#####
9298

9399
function rrule(
94100
::typeof(*), A::CommutativeMulNumber, B::AbstractArray{<:CommutativeMulNumber}
@@ -204,6 +210,11 @@ end # VERSION
204210
##### `muladd`
205211
#####
206212

213+
function frule((_, ΔA, ΔB, Δz), ::typeof(muladd), A, B, z)
214+
Ω = muladd(A, B, z)
215+
return Ω, ΔA * B .+ A * ΔB .+ Δz
216+
end
217+
207218
function rrule(
208219
::typeof(muladd),
209220
A::AbstractMatrix{<:CommutativeMulNumber},
@@ -351,6 +362,13 @@ end
351362
##### `\`, `/` matrix-scalar_rule
352363
#####
353364

365+
function frule((_, ΔA, Δb), ::typeof(/), A::AbstractArray{<:CommutativeMulNumber}, b::CommutativeMulNumber)
366+
return A/b, ΔA/b - A*(Δb/b^2)
367+
end
368+
function frule((_, Δa, ΔB), ::typeof(\), a::CommutativeMulNumber, B::AbstractArray{<:CommutativeMulNumber})
369+
return B/a, ΔB/a - B*(Δa/a^2)
370+
end
371+
354372
function rrule(::typeof(/), A::AbstractArray{<:CommutativeMulNumber}, b::CommutativeMulNumber)
355373
Y = A/b
356374
function slash_pullback_scalar(ȳ)
@@ -378,6 +396,8 @@ end
378396
##### Negation (Unary -)
379397
#####
380398

399+
# Forward rule for unary minus on arrays: negate both the primal and its tangent.
function frule((_, ΔA), ::typeof(-), A::AbstractArray)
    return -A, -ΔA
end
400+
381401
function rrule(::typeof(-), x::AbstractArray)
382402
function negation_pullback(ȳ)
383403
return NoTangent(), InplaceableThunk(ā -> ā .-= ȳ, @thunk(-ȳ))
@@ -390,6 +410,8 @@ end
390410
##### Addition (Multiarg `+`)
391411
#####
392412

413+
# Forward rule for multi-argument array `+`: addition is linear, so the tangent of
# the sum is the sum of the tangents.
function frule((_, ΔAs...), ::typeof(+), As::AbstractArray...)
    return +(As...), +(ΔAs...)
end
414+
393415
function rrule(::typeof(+), arrs::AbstractArray...)
394416
y = +(arrs...)
395417
arr_axs = map(axes, arrs)

0 commit comments

Comments
 (0)