@@ -108,15 +108,22 @@ end
108
108
109
109
# _normp_back_x(x, p, y, Δy)
#
# Map every element `xi` of `x` to `xi * ((norm(xi) / y)^(p - 2) * real(Δy) / y)`.
# Any element whose result is not finite is replaced by a zero of the same type.
# The output container is whatever `map` allocates for `x`; nothing is preallocated.
function _normp_back_x(x, p, y, Δy)
    # Factor shared by all elements; only the real part of Δy is used.
    c = real(Δy) / y
    ∂x = map(x) do xi
        a = norm(xi)
        ∂xi = xi * ((a / y)^(p - 2) * c)
        # e.g. a zero entry with p < 2 yields 0 * Inf = NaN; return zero instead.
        return ifelse(isfinite(∂xi), ∂xi, zero(∂xi))
    end
    return ∂x
end
118
# _normp_back_x(x::WithSomeZeros, p, y, Δy)
#
# Method for `WithSomeZeros` arguments (Diagonal, UpperTriangular, etc.).
# The per-element formula `xi * ((norm(xi) / y)^(p - 2) * real(Δy) / y)` is applied
# to `parent(x)` rather than to `x` itself, non-finite results are replaced by zero,
# and the mapped data is handed to `withsomezeros_rewrap` together with `x`.
# NOTE(review): presumably `withsomezeros_rewrap` restores the wrapper type of `x`
# around the new data -- confirm against its definition.
function _normp_back_x(x::WithSomeZeros, p, y, Δy)  # Diagonal, UpperTriangular, etc.
    # Factor shared by all elements; only the real part of Δy is used.
    c = real(Δy) / y
    ∂x_data = map(parent(x)) do xi
        a = norm(xi)
        ∂xi = xi * ((a / y)^(p - 2) * c)
        return ifelse(isfinite(∂xi), ∂xi, zero(∂xi))
    end
    return withsomezeros_rewrap(x, ∂x_data)
end
120
127
121
128
function _normp_back_p (x, p, y, Δy)
122
129
y > 0 && isfinite (y) && ! iszero (p) || return zero (real (Δy)) * zero (y) / one (p)
@@ -175,13 +182,13 @@ function rrule(::typeof(LinearAlgebra.norm1), x::AbstractArray)
175
182
end
176
183
177
184
# _norm1_back(x, y, Δy)
#
# Return `sign.(x) .* real(Δy)`, allocated by broadcasting.
# `y` is accepted but not used by this method.
function _norm1_back(x, y, Δy)
    ∂x = sign.(x) .* real(Δy)
    return ∂x
end
188
# _norm1_back(x::WithSomeZeros, y, Δy)
#
# Method for `WithSomeZeros` arguments: broadcasts `sign.(…) .* real(Δy)` over
# `parent(x)` instead of over `x`, then passes the result to `withsomezeros_rewrap`
# together with `x`. `y` is accepted but not used.
# NOTE(review): presumably `withsomezeros_rewrap` restores the wrapper type of `x`
# around the new data -- confirm against its definition.
function _norm1_back(x::WithSomeZeros, y, Δy)
    ∂x_data = sign.(parent(x)) .* real(Δy)
    return withsomezeros_rewrap(x, ∂x_data)
end
185
192
function _norm1_back! (∂x, x, y, Δy)
186
193
∂x .+ = sign .(x) .* real (Δy)
187
194
return ∂x
@@ -211,11 +218,14 @@ function _norm2_forward(x, Δx, y)
211
218
return ∂y
212
219
end
213
220
# _norm2_back(x, y, Δy)
#
# Return `x .* (real(Δy) * pinv(y))`, allocated by broadcasting.
# The scalar factor is computed once, outside the broadcast.
function _norm2_back(x, y, Δy)
    # `pinv(y)` rather than `inv(y)`: for a zero `y` the factor is zero, not Inf.
    ∂x = x .* (real(Δy) * pinv(y))
    return ∂x
end
224
# _norm2_back(x::WithSomeZeros, y, Δy)
#
# Method for `WithSomeZeros` arguments: scales `parent(x)` by the scalar
# `real(Δy) * pinv(y)` and passes the result to `withsomezeros_rewrap` together
# with `x`.
# NOTE(review): presumably `withsomezeros_rewrap` restores the wrapper type of `x`
# around the new data -- confirm against its definition.
function _norm2_back(x::WithSomeZeros, y, Δy)
    # The element type is left to broadcasting; no explicit promotion is needed.
    ∂x_data = parent(x) .* (real(Δy) * pinv(y))
    return withsomezeros_rewrap(x, ∂x_data)
end
219
229
function _norm2_back! (∂x, x, y, Δy)
220
230
∂x .+ = x .* (real (Δy) * pinv (y))
221
231
return ∂x # must return after mutating
0 commit comments