1
+
2
+ rrule (f:: Function , args... ; kwargs... ) = (@error " no rrule defined!" f args ; nothing )
3
+
1
4
# ####
2
5
# #### `norm`
3
6
# ####
@@ -17,20 +20,25 @@ function frule((_, Δx), ::typeof(norm), x::Number, p::Real)
17
20
return y, ∂y
18
21
end
19
22
20
- function rrule (
21
- :: typeof (norm),
22
- x:: Union{StridedArray, LinearAlgebra.AbstractTriangular, Diagonal} ,
23
- p:: Real ,
24
- )
23
+ function rrule (:: typeof (norm), x:: AbstractArray , p:: Real )
25
24
y = LinearAlgebra. norm (x, p)
26
- function norm_pullback (Δy)
25
+ function norm_pullback_p (Δy)
27
26
∂x = Thunk () do
28
27
return if isempty (x) || p == 0
29
- zero .(x) .* (zero (y) * zero (real (Δy)))
28
+ InplaceableThunk (
29
+ @thunk (zero .(x) .* (zero (y) * zero (real (Δy)))),
30
+ dx -> dx .= zero (eltype (dx)),
31
+ )
30
32
elseif p == 2
31
- _norm2_back (x, y, Δy)
33
+ InplaceableThunk (
34
+ @thunk (_norm2_back (x, y, Δy)),
35
+ dx -> _norm2_back! (dx, x, y, Δy),
36
+ )
32
37
elseif p == 1
33
- _norm1_back (x, y, Δy)
38
+ InplaceableThunk (
39
+ @thunk (_norm1_back (x, y, Δy)),
40
+ dx -> _norm1_back! (dx, x, y, Δy),
41
+ )
34
42
elseif p == Inf
35
43
_normInf_back (x, y, Δy)
36
44
elseif p == - Inf
@@ -42,24 +50,24 @@ function rrule(
42
50
∂p = @thunk _normp_back_p (x, p, y, Δy)
43
51
return (NO_FIELDS, ∂x, ∂p)
44
52
end
45
- norm_pullback (:: Zero ) = (NO_FIELDS, Zero (), Zero ())
46
- return y, norm_pullback
53
+ norm_pullback_p (:: Zero ) = (NO_FIELDS, Zero (), Zero ())
54
+ return y, norm_pullback_p
47
55
end
48
- function rrule (
49
- :: typeof (norm),
50
- x:: Union{StridedArray, LinearAlgebra.AbstractTriangular, Diagonal} ,
51
- )
56
+ function rrule (:: typeof (norm), x:: AbstractArray )
52
57
y = LinearAlgebra. norm (x)
53
- function norm_pullback (Δy)
58
+ function norm_pullback_2 (Δy)
54
59
∂x = if isempty (x)
55
60
zero .(x) .* (zero (y) * zero (real (Δy)))
56
61
else
57
- _norm2_back (x, y, Δy)
62
+ InplaceableThunk (
63
+ @thunk (_norm2_back (x, y, Δy)),
64
+ dx -> _norm2_back! (dx, x, y, Δy),
65
+ )
58
66
end
59
67
return (NO_FIELDS, ∂x)
60
68
end
61
- norm_pullback (:: Zero ) = (NO_FIELDS, Zero ())
62
- return y, norm_pullback
69
+ norm_pullback_2 (:: Zero ) = (NO_FIELDS, Zero ())
70
+ return y, norm_pullback_2
63
71
end
64
72
function rrule (
65
73
:: typeof (norm),
94
102
# #### `normp`
95
103
# ####
96
104
97
- function rrule (
98
- :: typeof (LinearAlgebra. normp),
99
- x:: Union{StridedArray, LinearAlgebra.AbstractTriangular, Diagonal} ,
100
- p,
101
- )
105
+ function rrule (:: typeof (LinearAlgebra. normp),x:: AbstractArray , p)
102
106
y = LinearAlgebra. normp (x, p)
103
107
function normp_pullback (Δy)
104
108
∂x = @thunk _normp_back_x (x, p, y, Δy)
@@ -135,20 +139,14 @@ end
135
139
# #### `normMinusInf`/`normInf`
136
140
# ####
137
141
138
- function rrule (
139
- :: typeof (LinearAlgebra. normMinusInf),
140
- x:: Union{StridedArray, LinearAlgebra.AbstractTriangular, Diagonal} ,
141
- )
142
+ function rrule (:: typeof (LinearAlgebra. normMinusInf), x:: AbstractArray )
142
143
y = LinearAlgebra. normMinusInf (x)
143
144
normMinusInf_pullback (Δy) = (NO_FIELDS, _normInf_back (x, y, Δy))
144
145
normMinusInf_pullback (:: Zero ) = (NO_FIELDS, Zero ())
145
146
return y, normMinusInf_pullback
146
147
end
147
148
148
- function rrule (
149
- :: typeof (LinearAlgebra. normInf),
150
- x:: Union{StridedArray,LinearAlgebra.AbstractTriangular,Diagonal} ,
151
- )
149
+ function rrule (:: typeof (LinearAlgebra. normInf), x:: AbstractArray )
152
150
y = LinearAlgebra. normInf (x)
153
151
normInf_pullback (Δy) = (NO_FIELDS, _normInf_back (x, y, Δy))
154
152
normInf_pullback (:: Zero ) = (NO_FIELDS, Zero ())
@@ -172,12 +170,12 @@ end
172
170
# #### `norm1`
173
171
# ####
174
172
175
- function rrule (
176
- :: typeof (LinearAlgebra. norm1),
177
- x:: Union{StridedArray,LinearAlgebra.AbstractTriangular,Diagonal} ,
178
- )
173
+ function rrule (:: typeof (LinearAlgebra. norm1), x:: AbstractArray )
179
174
y = LinearAlgebra. norm1 (x)
180
- norm1_pullback (Δy) = (NO_FIELDS, _norm1_back (x, y, Δy))
175
+ norm1_pullback (Δy) = (NO_FIELDS, InplaceableThunk (
176
+ @thunk (_norm1_back (x, y, Δy)),
177
+ dx -> _norm1_back! (dx, x, y, Δy),
178
+ ))
181
179
norm1_pullback (:: Zero ) = (NO_FIELDS, Zero ())
182
180
return y, norm1_pullback
183
181
end
@@ -187,6 +185,10 @@ function _norm1_back(x, y, Δy)
187
185
∂x .= sign .(x) .* real (Δy)
188
186
return ∂x
189
187
end
188
+ function _norm1_back! (∂x, x, y, Δy)
189
+ ∂x .+ = sign .(x) .* real (Δy)
190
+ return ∂x
191
+ end
190
192
191
193
# ####
192
194
# #### `norm2`
@@ -197,12 +199,12 @@ function frule((_, Δx), ::typeof(LinearAlgebra.norm2), x)
197
199
return y, _norm2_forward (x, Δx, y)
198
200
end
199
201
200
- function rrule (
201
- :: typeof (LinearAlgebra. norm2),
202
- x:: Union{StridedArray,LinearAlgebra.AbstractTriangular,Diagonal} ,
203
- )
202
+ function rrule (:: typeof (LinearAlgebra. norm2), x:: AbstractArray )
204
203
y = LinearAlgebra. norm2 (x)
205
- norm2_pullback (Δy) = (NO_FIELDS, _norm2_back (x, y, Δy))
204
+ norm2_pullback (Δy) = (NO_FIELDS, InplaceableThunk (
205
+ @thunk (_norm2_back (x, y, Δy)),
206
+ dx -> _norm2_back! (dx, x, y, Δy),
207
+ ))
206
208
norm2_pullback (:: Zero ) = (NO_FIELDS, Zero ())
207
209
return y, norm2_pullback
208
210
end
@@ -216,6 +218,10 @@ function _norm2_back(x, y, Δy)
216
218
∂x .= x .* (real (Δy) * pinv (y))
217
219
return ∂x
218
220
end
221
+ function _norm2_back! (∂x, x, y, Δy)
222
+ ∂x .+ = x .* (real (Δy) * pinv (y))
223
+ return ∂x # must return after mutating
224
+ end
219
225
220
226
# ####
221
227
# #### `normalize`
0 commit comments