4
4
5
5
ChainRules. @non_differentiable (:: Type{T} where {T<: Array })(:: UndefInitializer , args... )
6
6
7
# Forward rule for the `Array` constructors: the tangent `ẋ` is converted with the
# same constructor `T` as the primal `x`.
function frule((_, ẋ), ::Type{T}, x::AbstractArray) where {T<:Array}
    y = T(x)
    ẏ = T(ẋ)
    return y, ẏ
end
10
+
11
# Forward rule for the element-type conversion `AbstractArray{T}(x)`: primal and
# tangent are both converted to element type `T`.
function frule((_, ẋ), ::Type{AbstractArray{T}}, x::AbstractArray) where {T}
    y = AbstractArray{T}(x)
    ẏ = AbstractArray{T}(ẋ)
    return y, ẏ
end
14
+
7
15
# Reverse rule for the `Array` constructors: the cotangent is projected back onto `x`.
function rrule(::Type{T}, x::AbstractArray) where {T<:Array}
    project_x = ProjectTo(x)
    function Array_pullback(ȳ)
        x̄ = project_x(ȳ)
        return (NoTangent(), x̄)
    end
    return T(x), Array_pullback
end
12
20
21
# The abstract constructor is what `float(x)` and other float conversions go through.
# The pullback projects the cotangent back onto `x`.
function rrule(::Type{AbstractArray{T}}, x::AbstractArray) where {T}
    project_x = ProjectTo(x)
    function AbstractArray_pullback(ȳ)
        x̄ = project_x(ȳ)
        return (NoTangent(), x̄)
    end
    return AbstractArray{T}(x), AbstractArray_pullback
end
27
+
13
28
# ####
14
29
# #### `vect`
15
30
# ####
16
31
17
32
@non_differentiable Base. vect ()
18
33
34
# Forward rule for `Base.vect` on numbers: zero tangents are materialised by
# `_instantiate_zeros` before the tangent vector is built.
function frule((_, ẋs...), ::typeof(Base.vect), xs::Number...)
    y = Base.vect(xs...)
    ẋs_full = _instantiate_zeros(ẋs, xs)
    return y, Base.vect(ẋs_full...)
end
37
+
19
38
# Case of uniform type `T`: the data passes straight through,
20
39
# so no projection should be required.
21
40
function rrule (:: typeof (Base. vect), X:: Vararg{T, N} ) where {T, N}
@@ -43,32 +62,84 @@ function rrule(::typeof(Base.vect), X::Vararg{Any,N}) where {N}
43
62
return Base. vect (X... ), vect_pullback
44
63
end
45
64
65
"""
    _instantiate_zeros(ẋs, xs)

Apply `_i_zero` to each pair of entries of `ẋs` and `xs`, so that every zero tangent `ẋ`
is materialised as `zero(x)`.

Forward rules for `vect`, `cat` etc. may receive a mixture of data and `ZeroTangent`s.
This avoids calls like `vect(1, ZeroTangent(), 3)`, or worse
`vcat([1,2], ZeroTangent(), [6,7])`.
"""
function _instantiate_zeros(ẋs, xs)
    return map(_i_zero, ẋs, xs)
end
73
# Ordinary tangents pass through untouched.
function _i_zero(ẋ, x)
    return ẋ
end

# An `AbstractZero` tangent is replaced by `zero(x)` of the matching primal.
function _i_zero(ẋ::AbstractZero, x)
    return zero(x)
end
75
+ # Possibly this won't work for partly non-diff arrays, something like `gradient(x -> ["abc", x][end], 1)`
76
+ # may give a MethodError for `zero` but won't be wrong.
77
+
78
+ # Fast paths. Should it also collapse all-Zero cases?
79
# Fast paths: tangents that are already all numbers, or all arrays, are returned
# untouched. `Tuple{Vararg{Number}}` matches the same tuples as the earlier
# `Tuple{Vararg{<:Number}}` (tuple types are covariant) but does not wrap `Vararg`
# in a `UnionAll`, which newer Julia versions deprecate.
_instantiate_zeros(ẋs::Tuple{Vararg{Number}}, xs) = ẋs
_instantiate_zeros(ẋs::Tuple{Vararg{AbstractArray}}, xs) = ẋs
_instantiate_zeros(ẋs::AbstractArray{<:Number}, xs) = ẋs
_instantiate_zeros(ẋs::AbstractArray{<:AbstractArray}, xs) = ẋs
# The empty tuple matches both tuple methods above; without this method a
# zero-argument call such as `hcat()` raises an ambiguity `MethodError`.
_instantiate_zeros(ẋs::Tuple{}, xs) = ẋs
83
+
84
+ # ####
85
+ # #### `copyto!`
86
+ # ####
87
+
88
# Forward rule for `copyto!(y, x)`: the same copy is applied to the tangents,
# overwriting `ẏ` with `ẋ` after the primal copy.
function frule((_, ẏ, ẋ), ::typeof(copyto!), y::AbstractArray, x)
    dest = copyto!(y, x)
    ∂dest = copyto!(ẏ, ẋ)
    return dest, ∂dest
end
91
+
92
# Forward rule for the offset form `copyto!(y, i, x, js...)`: the tangents are
# copied with the same offsets as the primal.
function frule((_, ẏ, _, ẋ), ::typeof(copyto!), y::AbstractArray, i::Integer, x, js::Integer...)
    dest = copyto!(y, i, x, js...)
    ∂dest = copyto!(ẏ, i, ẋ, js...)
    return dest, ∂dest
end
95
+
46
96
# ####
47
97
# #### `reshape`
48
98
# ####
49
99
50
- function rrule (:: typeof (reshape), A:: AbstractArray , dims:: Tuple{Vararg{Union{Colon,Int}}} )
51
- A_dims = size (A)
52
- function reshape_pullback (Ȳ)
53
- return (NoTangent (), reshape (Ȳ, A_dims), NoTangent ())
54
- end
55
- return reshape (A, dims), reshape_pullback
100
# Forward rule for `reshape`: the tangent is reshaped with the same `dims` as the primal.
function frule((_, ẋ), ::typeof(reshape), x::AbstractArray, dims...)
    y = reshape(x, dims...)
    ẏ = reshape(ẋ, dims...)
    return y, ẏ
end
57
103
58
- function rrule (:: typeof (reshape), A:: AbstractArray , dims:: Union{Colon,Int} ...)
59
- A_dims = size (A)
60
- function reshape_pullback (Ȳ)
61
- ∂A = reshape (Ȳ, A_dims)
62
- ∂dims = broadcast (Returns (NoTangent ()), dims)
63
- return (NoTangent (), ∂A, ∂dims... )
64
- end
104
# Reverse rule for `reshape`: the cotangent is reshaped back to the axes of `A` and
# projected onto `A`; every `dims` argument gets a `NoTangent()`.
function rrule(::typeof(reshape), A::AbstractArray, dims...)
    ax = axes(A)
    project = ProjectTo(A)  # Projection is here for e.g. reshape(::Diagonal, :)
    ∂dims = map(Returns(NoTangent()), dims)
    function reshape_pullback(Ȳ)
        ∂A = project(reshape(Ȳ, ax))
        return (NoTangent(), ∂A, ∂dims...)
    end
    return reshape(A, dims...), reshape_pullback
end
67
111
112
+ # ####
113
+ # #### `dropdims`
114
+ # ####
115
+
116
# Forward rule for `dropdims`: the same dimensions are dropped from the tangent.
function frule((_, ẋ), ::typeof(dropdims), x::AbstractArray; dims)
    y = dropdims(x; dims=dims)
    ẏ = dropdims(ẋ; dims=dims)
    return y, ẏ
end
119
+
120
# Reverse rule for `dropdims`: the cotangent is reshaped back to the axes of `A`
# and projected onto `A`.
function rrule(::typeof(dropdims), A::AbstractArray; dims)
    ax = axes(A)
    project = ProjectTo(A)
    function dropdims_pullback(Ȳ)
        ∂A = project(reshape(Ȳ, ax))
        return (NoTangent(), ∂A)
    end
    return dropdims(A; dims=dims), dropdims_pullback
end
126
+
68
127
# ####
69
128
# #### `permutedims`
70
129
# ####
71
130
131
# Forward rule for `permutedims`: the tangent is permuted exactly like the primal.
function frule((_, ẋ), ::typeof(permutedims), x::AbstractArray, perm...)
    y = permutedims(x, perm...)
    ẏ = permutedims(ẋ, perm...)
    return y, ẏ
end
134
+
135
# Forward rule for `permutedims!`: `ẏ` is overwritten with the permuted `ẋ`,
# after `y` has been overwritten with the permuted `x`.
function frule((_, ẏ, ẋ), ::typeof(permutedims!), y::AbstractArray, x::AbstractArray, perm...)
    dest = permutedims!(y, x, perm...)
    ∂dest = permutedims!(ẏ, ẋ, perm...)
    return dest, ∂dest
end
138
+
139
# Forward rule for the `PermutedDimsArray` constructor: the tangent is wrapped
# with the same permutation as the primal.
function frule((_, ẋ), ::Type{<:PermutedDimsArray}, x::AbstractArray, perm)
    y = PermutedDimsArray(x, perm)
    ẏ = PermutedDimsArray(ẋ, perm)
    return y, ẏ
end
142
+
72
143
function rrule (:: typeof (permutedims), x:: AbstractVector )
73
144
project = ProjectTo (x)
74
145
permutedims_pullback_1 (dy) = (NoTangent (), project (permutedims (unthunk (dy))))
91
162
# #### `repeat`
92
163
# ####
93
164
165
# Forward rule for `repeat`: the tangent is repeated with the same positional
# counts and keywords as the primal.
function frule((_, ẋs), ::typeof(repeat), xs::AbstractArray, cnt...; kw...)
    ys = repeat(xs, cnt...; kw...)
    ẏs = repeat(ẋs, cnt...; kw...)
    return ys, ẏs
end
168
+
94
169
function rrule (:: typeof (repeat), xs:: AbstractArray ; inner= ntuple (Returns (1 ), ndims (xs)), outer= ntuple (Returns (1 ), ndims (xs)))
95
170
96
171
project_Xs = ProjectTo (xs)
130
205
# #### `hcat`
131
206
# ####
132
207
208
# Forward rule for `hcat`: zero tangents are materialised by `_instantiate_zeros`
# before the tangents are concatenated.
function frule((_, ẋs...), ::typeof(hcat), xs...)
    y = hcat(xs...)
    ẋs_full = _instantiate_zeros(ẋs, xs)
    return y, hcat(ẋs_full...)
end
211
+
133
212
function rrule (:: typeof (hcat), Xs:: Union{AbstractArray, Number} ...)
134
213
Y = hcat (Xs... ) # note that Y always has 1-based indexing, even if X isa OffsetArray
135
214
ndimsY = Val (ndims (Y)) # this avoids closing over Y, Val() is essential for type-stability
@@ -164,6 +243,10 @@ function rrule(::typeof(hcat), Xs::Union{AbstractArray, Number}...)
164
243
return Y, hcat_pullback
165
244
end
166
245
246
# Forward rule for `reduce(hcat, As)`: the tangents are reduced the same way,
# after `_instantiate_zeros` has been applied to them.
function frule((_, _, Ȧs), ::typeof(reduce), ::typeof(hcat), As::AbstractVector{<:AbstractVecOrMat})
    Y = reduce(hcat, As)
    Ȧs_full = _instantiate_zeros(Ȧs, As)
    return Y, reduce(hcat, Ȧs_full)
end
249
+
167
250
function rrule (:: typeof (reduce), :: typeof (hcat), As:: AbstractVector{<:AbstractVecOrMat} )
168
251
widths = map (A -> size (A,2 ), As)
169
252
function reduce_hcat_pullback_2 (dY)
192
275
# #### `vcat`
193
276
# ####
194
277
278
# Forward rule for `vcat`: zero tangents are materialised by `_instantiate_zeros`
# before the tangents are concatenated.
function frule((_, ẋs...), ::typeof(vcat), xs...)
    y = vcat(xs...)
    ẋs_full = _instantiate_zeros(ẋs, xs)
    return y, vcat(ẋs_full...)
end
281
+
195
282
function rrule (:: typeof (vcat), Xs:: Union{AbstractArray, Number} ...)
196
283
Y = vcat (Xs... )
197
284
ndimsY = Val (ndims (Y))
@@ -224,6 +311,10 @@ function rrule(::typeof(vcat), Xs::Union{AbstractArray, Number}...)
224
311
return Y, vcat_pullback
225
312
end
226
313
314
# Forward rule for `reduce(vcat, As)`: the tangents are reduced the same way,
# after `_instantiate_zeros` has been applied to them.
function frule((_, _, Ȧs), ::typeof(reduce), ::typeof(vcat), As::AbstractVector{<:AbstractVecOrMat})
    Y = reduce(vcat, As)
    Ȧs_full = _instantiate_zeros(Ȧs, As)
    return Y, reduce(vcat, Ȧs_full)
end
317
+
227
318
function rrule (:: typeof (reduce), :: typeof (vcat), As:: AbstractVector{<:AbstractVecOrMat} )
228
319
Y = reduce (vcat, As)
229
320
ndimsY = Val (ndims (Y))
247
338
248
339
# Extract the value stored in the type parameter of a `Val`.
function _val(::Val{v}) where {v}
    return v
end
249
340
341
# Forward rule for `cat`: zero tangents are materialised by `_instantiate_zeros`
# before the tangents are concatenated along the same `dims`.
function frule((_, ẋs...), ::typeof(cat), xs...; dims)
    y = cat(xs...; dims=dims)
    ẋs_full = _instantiate_zeros(ẋs, xs)
    return y, cat(ẋs_full...; dims=dims)
end
344
+
250
345
function rrule (:: typeof (cat), Xs:: Union{AbstractArray, Number} ...; dims)
251
346
Y = cat (Xs... ; dims= dims)
252
347
cdims = dims isa Val ? Int (_val (dims)) : dims isa Integer ? Int (dims) : Tuple (dims)
285
380
# #### `hvcat`
286
381
# ####
287
382
383
# Forward rule for `hvcat`: zero tangents are materialised by `_instantiate_zeros`
# before the tangents are laid out with the same `rows`.
function frule((_, _, ẋs...), ::typeof(hvcat), rows, xs...)
    y = hvcat(rows, xs...)
    ẋs_full = _instantiate_zeros(ẋs, xs)
    return y, hvcat(rows, ẋs_full...)
end
386
+
288
387
function rrule (:: typeof (hvcat), rows, values:: Union{AbstractArray, Number} ...)
289
388
Y = hvcat (rows, values... )
290
389
cols = size (Y,2 )
321
420
# 1-dim case allows start/stop, N-dim case takes dims keyword
322
421
# whose defaults changed in Julia 1.6... just pass them all through:
323
422
324
- function frule ((_, xdot), :: typeof (reverse), x:: Union{AbstractArray, Tuple} , args... ; kw... )
325
- return reverse (x, args... ; kw... ), reverse (xdot, args... ; kw... )
423
# Forward rule for `reverse`: the tangent is reversed with the same positional
# and keyword arguments as the primal.
function frule((_, ẋ), ::typeof(reverse), x::Union{AbstractArray, Tuple}, args...; kw...)
    y = reverse(x, args...; kw...)
    ẏ = reverse(ẋ, args...; kw...)
    return y, ẏ
end
426
+
427
# Forward rule for `reverse!`: `x` is reversed in place first, then `ẋ` is
# reversed in place with the same arguments.
function frule((_, ẋ), ::typeof(reverse!), x::Union{AbstractArray, Tuple}, args...; kw...)
    y = reverse!(x, args...; kw...)
    ẏ = reverse!(ẋ, args...; kw...)
    return y, ẏ
end
327
430
328
431
function rrule (:: typeof (reverse), x:: Union{AbstractArray, Tuple} , args... ; kw... )
338
441
# #### `circshift`
339
442
# ####
340
443
341
- function frule ((_, xdot), :: typeof (circshift), x:: AbstractArray , shifts)
342
- return circshift (x, shifts), circshift (xdot, shifts)
444
# Forward rule for `circshift`: the tangent is shifted by the same `shifts`.
function frule((_, ẋ), ::typeof(circshift), x::AbstractArray, shifts)
    y = circshift(x, shifts)
    ẏ = circshift(ẋ, shifts)
    return y, ẏ
end
447
+
448
# Forward rule for `circshift!`: `ẏ` is overwritten with the shifted `ẋ`,
# after `y` has been overwritten with the shifted `x`.
function frule((_, ẏ, ẋ), ::typeof(circshift!), y::AbstractArray, x::AbstractArray, shifts)
    dest = circshift!(y, x, shifts)
    ∂dest = circshift!(ẏ, ẋ, shifts)
    return dest, ∂dest
end
344
451
345
452
function rrule (:: typeof (circshift), x:: AbstractArray , shifts)
355
462
# #### `fill`
356
463
# ####
357
464
358
- function frule ((_, xdot), :: typeof (fill), x:: Any , dims... )
359
- return fill (x, dims... ), fill (xdot, dims... )
465
# Forward rule for `fill`: the tangent is an array of the same `dims` filled with `ẋ`.
function frule((_, ẋ), ::typeof(fill), x::Any, dims...)
    y = fill(x, dims...)
    ẏ = fill(ẋ, dims...)
    return y, ẏ
end
468
+
469
# Forward rule for `fill!`: `y` is filled with `x` first, then `ẏ` is filled with `ẋ`.
function frule((_, ẏ, ẋ), ::typeof(fill!), y::AbstractArray, x::Any)
    dest = fill!(y, x)
    ∂dest = fill!(ẏ, ẋ)
    return dest, ∂dest
end
361
472
362
473
function rrule (:: typeof (fill), x:: Any , dims... )
370
481
# #### `filter`
371
482
# ####
372
483
373
- function frule ((_, _, xdot ), :: typeof (filter), f, x:: AbstractArray )
484
# Forward rule for `filter`: the indices where `f` holds are found once and used
# to index both the primal and the tangent.
function frule((_, _, ẋ), ::typeof(filter), f, x::AbstractArray)
    keep = findall(f, x)
    y = x[keep]
    ẏ = ẋ[keep]
    return y, ẏ
end
377
488
378
489
function rrule (:: typeof (filter), f, x:: AbstractArray )
392
503
for findm in (:findmin , :findmax )
393
504
findm_pullback = Symbol (findm, :_pullback )
394
505
395
- @eval function frule ((_, xdot ), :: typeof ($ findm), x; dims= :)
506
# Forward rule for `findmin` / `findmax`: the tangent of the extreme value is the
# entry of `ẋ` at the index found; the index itself gets `NoTangent()`.
@eval function frule((_, ẋ), ::typeof($findm), x; dims=:)
    y, ind = $findm(x; dims=dims)
    primal = (y, ind)
    ∂primal = Tangent{typeof(primal)}(ẋ[ind], NoTangent())
    return primal, ∂primal
end
399
510
400
511
@eval function rrule (:: typeof ($ findm), x:: AbstractArray ; dims= :)
441
552
# Allow for second derivatives, by writing rules for `_zerolike_writeat`;
442
553
# these rules are the reason it takes a `dims` argument.
443
554
444
- function frule ((_, _, dydot ), :: typeof (_zerolike_writeat), x, dy, dims, inds... )
445
- return _zerolike_writeat (x, dy, dims, inds... ), _zerolike_writeat (x, dydot , dims, inds... )
555
# Forward rule for `_zerolike_writeat` with respect to `dy`: the same call is
# made with the tangent `dẏ` in place of `dy`.
function frule((_, _, dẏ), ::typeof(_zerolike_writeat), x, dy, dims, inds...)
    z = _zerolike_writeat(x, dy, dims, inds...)
    ż = _zerolike_writeat(x, dẏ, dims, inds...)
    return z, ż
end
447
558
448
559
function rrule (:: typeof (_zerolike_writeat), x, dy, dims, inds... )
457
568
458
569
# These rules for `maximum` pick the same subgradient as `findmax`:
459
570
460
- function frule ((_, xdot ), :: typeof (maximum), x; dims= :)
571
# Forward rule for `maximum`: the tangent is the entry of `ẋ` at the index that
# `findmax` reports.
function frule((_, ẋ), ::typeof(maximum), x; dims=:)
    y, where_max = findmax(x; dims=dims)
    ẏ = ẋ[where_max]
    return y, ẏ
end
464
575
465
576
function rrule (:: typeof (maximum), x:: AbstractArray ; dims= :)
@@ -468,9 +579,9 @@ function rrule(::typeof(maximum), x::AbstractArray; dims=:)
468
579
return y, maximum_pullback
469
580
end
470
581
471
- function frule ((_, xdot ), :: typeof (minimum), x; dims= :)
582
# Forward rule for `minimum`: the tangent is the entry of `ẋ` at the index that
# `findmin` reports.
function frule((_, ẋ), ::typeof(minimum), x; dims=:)
    y, where_min = findmin(x; dims=dims)
    ẏ = ẋ[where_min]
    return y, ẏ
end
475
586
476
587
function rrule (:: typeof (minimum), x:: AbstractArray ; dims= :)
0 commit comments