@@ -6,6 +6,7 @@ function frule((_, Δx), ::typeof(norm), x)
6
6
y = norm (x)
7
7
return y, _norm2_forward (x, Δx, norm (x))
8
8
end
9
+
9
10
function frule ((_, Δx), :: typeof (norm), x:: Number , p:: Real )
10
11
y = norm (x, p)
11
12
∂y = if iszero (Δx) || iszero (p)
@@ -17,15 +18,12 @@ function frule((_, Δx), ::typeof(norm), x::Number, p::Real)
17
18
return y, ∂y
18
19
end
19
20
20
- function rrule (
21
- :: typeof (norm),
22
- x:: Union{StridedArray, LinearAlgebra.AbstractTriangular, Diagonal} ,
23
- p:: Real ,
24
- )
21
+ function rrule (:: typeof (norm), x:: AbstractArray{<:Number} , p:: Real )
25
22
y = LinearAlgebra. norm (x, p)
26
- function norm_pullback (Δy)
27
- ∂x = Thunk () do
28
- return if isempty (x) || p == 0
23
+ function norm_pullback_p (Δy)
24
+ ∂x = InplaceableThunk (
25
+ # out-of-place versions
26
+ @thunk (if isempty (x) || p == 0
29
27
zero .(x) .* (zero (y) * zero (real (Δy)))
30
28
elseif p == 2
31
29
_norm2_back (x, y, Δy)
@@ -37,35 +35,52 @@ function rrule(
37
35
_normInf_back (x, y, Δy)
38
36
else
39
37
_normp_back_x (x, p, y, Δy)
38
+ end )
39
+ , # in-place versions -- can be fixed when actually useful?
40
+ dx -> if isempty (x) || p == 0
41
+ dx
42
+ elseif p == 2
43
+ _norm2_back! (dx, x, y, Δy)
44
+ elseif p == 1
45
+ _norm1_back! (dx, x, y, Δy)
46
+ elseif p == Inf
47
+ dx .+ = _normInf_back (x, y, Δy) # not really in-place! could perhaps be improved
48
+ elseif p == - Inf
49
+ dx .+ = _normInf_back (x, y, Δy)
50
+ else
51
+ dx .+ = _normp_back_x (x, p, y, Δy)
40
52
end
41
- end
53
+ )
42
54
∂p = @thunk _normp_back_p (x, p, y, Δy)
43
55
return (NO_FIELDS, ∂x, ∂p)
44
56
end
45
- norm_pullback (:: Zero ) = (NO_FIELDS, Zero (), Zero ())
46
- return y, norm_pullback
57
+ norm_pullback_p (:: Zero ) = (NO_FIELDS, Zero (), Zero ())
58
+ return y, norm_pullback_p
47
59
end
48
# Reverse rule for the default (2-)norm of a numeric array.
# Returns the primal `y` together with a pullback; the pullback maps the output
# cotangent `Δy` to `(NO_FIELDS, ∂x)`, where `∂x` is an `InplaceableThunk`.
function rrule(::typeof(norm), x::AbstractArray{<:Number})
    y = LinearAlgebra.norm(x)
    function norm_pullback_2(Δy)
        # Out-of-place form, evaluated lazily. For empty `x` it yields an
        # (empty) zero array built from the element types of `x`, `y` and `Δy`.
        lazy = @thunk(
            isempty(x) ? zero.(x) .* (zero(y) * zero(real(Δy))) : _norm2_back(x, y, Δy)
        )
        # In-place form: add the cotangent into `dx` and hand `dx` back.
        function add_norm2!(dx)
            isempty(x) && return dx
            return _norm2_back!(dx, x, y, Δy)
        end
        return (NO_FIELDS, InplaceableThunk(lazy, add_norm2!))
    end
    # A structurally zero output cotangent short-circuits to `Zero()`.
    norm_pullback_2(::Zero) = (NO_FIELDS, Zero())
    return y, norm_pullback_2
end
64
- function rrule (
65
- :: typeof (norm),
66
- x:: Union{LinearAlgebra.TransposeAbsVec, LinearAlgebra.AdjointAbsVec} ,
67
- p:: Real ,
68
- )
82
+
83
+ function rrule (:: typeof (norm), x:: LinearAlgebra.AdjOrTransAbsVec{<:Number} , p:: Real )
69
84
y, inner_pullback = rrule (norm, parent (x), p)
70
85
function norm_pullback (Δy)
71
86
(∂self, ∂x′, ∂p) = inner_pullback (Δy)
@@ -75,6 +90,7 @@ function rrule(
75
90
end
76
91
return y, norm_pullback
77
92
end
93
+
78
94
function rrule (:: typeof (norm), x:: Number , p:: Real )
79
95
y = norm (x, p)
80
96
function norm_pullback (Δy)
94
110
# #### `normp`
95
111
# ####
96
112
97
- function rrule (
98
- :: typeof (LinearAlgebra. normp),
99
- x:: Union{StridedArray, LinearAlgebra.AbstractTriangular, Diagonal} ,
100
- p,
101
- )
113
+ function rrule (:: typeof (LinearAlgebra. normp), x:: AbstractArray{<:Number} , p)
102
114
y = LinearAlgebra. normp (x, p)
103
115
function normp_pullback (Δy)
104
116
∂x = @thunk _normp_back_x (x, p, y, Δy)
@@ -111,15 +123,24 @@ end
111
123
112
124
# Cotangent of `x` for the p-norm `y`, given the output cotangent `Δy`.
# Each element contributes `xi * (norm(xi) / y)^(p - 2) * real(Δy) / y`;
# any non-finite contribution (e.g. from a zero element with `p < 2`, or from
# `y == 0`) is replaced by a zero of the same type.
function _normp_back_x(x, p, y, Δy)
    scale = real(Δy) / y
    exponent = p - 2
    return map(x) do xi
        g = xi * ((norm(xi) / y)^exponent * scale)
        return isfinite(g) ? g : zero(g)
    end
end
122
133
134
# p-norm cotangent for `WithSomeZeros` wrappers (Diagonal, UpperTriangular, etc.).
# The element-wise formula is applied to `parent(x)` and the result is passed
# to `withsomezeros_rewrap` together with the original `x`.
function _normp_back_x(x::WithSomeZeros, p, y, Δy)
    scale = real(Δy) / y
    ∂x_data = map(parent(x)) do xi
        g = xi * ((norm(xi) / y)^(p - 2) * scale)
        # Non-finite contributions are replaced by a zero of the same type.
        return isfinite(g) ? g : zero(g)
    end
    return withsomezeros_rewrap(x, ∂x_data)
end
143
+
123
144
function _normp_back_p (x, p, y, Δy)
124
145
y > 0 && isfinite (y) && ! iszero (p) || return zero (real (Δy)) * zero (y) / one (p)
125
146
s = sum (x) do xi
@@ -135,20 +156,14 @@ end
135
156
# #### `normMinusInf`/`normInf`
136
157
# ####
137
158
138
# Reverse rule for `LinearAlgebra.normMinusInf` on a numeric array.
# The pullback delegates to `_normInf_back`; a `Zero` cotangent maps to `Zero()`.
function rrule(::typeof(LinearAlgebra.normMinusInf), x::AbstractArray{<:Number})
    y = LinearAlgebra.normMinusInf(x)
    function normMinusInf_pullback(Δy)
        ∂x = _normInf_back(x, y, Δy)
        return (NO_FIELDS, ∂x)
    end
    function normMinusInf_pullback(::Zero)
        return (NO_FIELDS, Zero())
    end
    return y, normMinusInf_pullback
end
147
165
148
- function rrule (
149
- :: typeof (LinearAlgebra. normInf),
150
- x:: Union{StridedArray,LinearAlgebra.AbstractTriangular,Diagonal} ,
151
- )
166
+ function rrule (:: typeof (LinearAlgebra. normInf), x:: AbstractArray{<:Number} )
152
167
y = LinearAlgebra. normInf (x)
153
168
normInf_pullback (Δy) = (NO_FIELDS, _normInf_back (x, y, Δy))
154
169
normInf_pullback (:: Zero ) = (NO_FIELDS, Zero ())
@@ -172,19 +187,26 @@ end
172
187
# #### `norm1`
173
188
# ####
174
189
175
# Reverse rule for `LinearAlgebra.norm1` on a numeric array.
# The cotangent of `x` is an `InplaceableThunk` pairing a lazy out-of-place
# computation (`_norm1_back`) with an in-place accumulation (`_norm1_back!`).
function rrule(::typeof(LinearAlgebra.norm1), x::AbstractArray{<:Number})
    y = LinearAlgebra.norm1(x)
    function norm1_pullback(Δy)
        lazy = @thunk(_norm1_back(x, y, Δy))
        add_to = dx -> _norm1_back!(dx, x, y, Δy)
        return (NO_FIELDS, InplaceableThunk(lazy, add_to))
    end
    # A structurally zero output cotangent short-circuits to `Zero()`.
    norm1_pullback(::Zero) = (NO_FIELDS, Zero())
    return y, norm1_pullback
end
184
199
185
200
# Out-of-place cotangent of `x` for the 1-norm: `sign(xi) * real(Δy)` per element.
# `y` is accepted for signature parity with the other `_norm*_back` helpers but
# is not used here.
function _norm1_back(x, y, Δy)
    weight = real(Δy)
    return @. sign(x) * weight
end
204
# 1-norm cotangent for `WithSomeZeros` wrappers: the element-wise formula is
# applied to `parent(x)` and the result is passed to `withsomezeros_rewrap`
# together with the original `x`. `y` is not used.
function _norm1_back(x::WithSomeZeros, y, Δy)
    weight = real(Δy)
    ∂x_data = @. sign($parent(x)) * weight
    return withsomezeros_rewrap(x, ∂x_data)
end
208
# In-place 1-norm cotangent: adds `sign(xi) * real(Δy)` into `∂x` element-wise
# and returns the mutated `∂x`. `y` is not used.
function _norm1_back!(∂x, x, y, Δy)
    weight = real(Δy)
    @. ∂x += sign(x) * weight
    return ∂x
end
190
212
@@ -197,12 +219,12 @@ function frule((_, Δx), ::typeof(LinearAlgebra.norm2), x)
197
219
return y, _norm2_forward (x, Δx, y)
198
220
end
199
221
200
# Reverse rule for `LinearAlgebra.norm2` on a numeric array.
# The cotangent of `x` is an `InplaceableThunk` pairing a lazy out-of-place
# computation (`_norm2_back`) with an in-place accumulation (`_norm2_back!`).
function rrule(::typeof(LinearAlgebra.norm2), x::AbstractArray{<:Number})
    y = LinearAlgebra.norm2(x)
    function norm2_pullback(Δy)
        lazy = @thunk(_norm2_back(x, y, Δy))
        add_to = dx -> _norm2_back!(dx, x, y, Δy)
        return (NO_FIELDS, InplaceableThunk(lazy, add_to))
    end
    # A structurally zero output cotangent short-circuits to `Zero()`.
    norm2_pullback(::Zero) = (NO_FIELDS, Zero())
    return y, norm2_pullback
end
@@ -212,16 +234,24 @@ function _norm2_forward(x, Δx, y)
212
234
return ∂y
213
235
end
214
236
# Out-of-place cotangent of `x` for the 2-norm: `x * real(Δy) / y`.
# `pinv(y)` is used instead of `inv(y)` so that `y == 0` gives a zero
# cotangent rather than a division by zero.
function _norm2_back(x, y, Δy)
    scale = real(Δy) * pinv(y)
    return x .* scale
end
240
# 2-norm cotangent for `WithSomeZeros` wrappers: the scaling `real(Δy) * pinv(y)`
# is applied to `parent(x)` and the result is passed to `withsomezeros_rewrap`
# together with the original `x`.
# `pinv(y)` is used so that `y == 0` gives a zero cotangent rather than a
# division by zero.
function _norm2_back(x::WithSomeZeros, y, Δy)
    ∂x_data = parent(x) .* (real(Δy) * pinv(y))
    return withsomezeros_rewrap(x, ∂x_data)
end
245
# In-place 2-norm cotangent: adds `x * real(Δy) * pinv(y)` into `∂x`.
function _norm2_back!(∂x, x, y, Δy)
    scale = real(Δy) * pinv(y)
    @. ∂x += x * scale
    return ∂x  # must return after mutating
end
219
249
220
250
# ####
221
251
# #### `normalize`
222
252
# ####
223
253
224
- function rrule (:: typeof (normalize), x:: AbstractVector , p:: Real )
254
+ function rrule (:: typeof (normalize), x:: AbstractVector{<:Number} , p:: Real )
225
255
nrm, inner_pullback = rrule (norm, x, p)
226
256
Ty = typeof (first (x) / nrm)
227
257
y = copyto! (similar (x, Ty), x)
@@ -236,7 +266,8 @@ function rrule(::typeof(normalize), x::AbstractVector, p::Real)
236
266
normalize_pullback (:: Zero ) = (NO_FIELDS, Zero (), Zero ())
237
267
return y, normalize_pullback
238
268
end
239
- function rrule (:: typeof (normalize), x:: AbstractVector )
269
+
270
+ function rrule (:: typeof (normalize), x:: AbstractVector{<:Number} )
240
271
nrm = LinearAlgebra. norm2 (x)
241
272
Ty = typeof (first (x) / nrm)
242
273
y = copyto! (similar (x, Ty), x)
0 commit comments