Skip to content

Commit db07b7a

Browse files
authored
Merge pull request #337 from mcabbott/norm
Improvements to rules for `norm`
2 parents 987ee45 + 87e4313 commit db07b7a

File tree

4 files changed

+180
-71
lines changed

4 files changed

+180
-71
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "0.7.63"
3+
version = "0.7.64"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/rulesets/LinearAlgebra/norm.jl

Lines changed: 90 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ function frule((_, Δx), ::typeof(norm), x)
66
y = norm(x)
77
return y, _norm2_forward(x, Δx, norm(x))
88
end
9+
910
function frule((_, Δx), ::typeof(norm), x::Number, p::Real)
1011
y = norm(x, p)
1112
∂y = if iszero(Δx) || iszero(p)
@@ -17,15 +18,12 @@ function frule((_, Δx), ::typeof(norm), x::Number, p::Real)
1718
return y, ∂y
1819
end
1920

20-
function rrule(
21-
::typeof(norm),
22-
x::Union{StridedArray, LinearAlgebra.AbstractTriangular, Diagonal},
23-
p::Real,
24-
)
21+
function rrule(::typeof(norm), x::AbstractArray{<:Number}, p::Real)
2522
y = LinearAlgebra.norm(x, p)
26-
function norm_pullback(Δy)
27-
∂x = Thunk() do
28-
return if isempty(x) || p == 0
23+
function norm_pullback_p(Δy)
24+
∂x = InplaceableThunk(
25+
# out-of-place versions
26+
@thunk(if isempty(x) || p == 0
2927
zero.(x) .* (zero(y) * zero(real(Δy)))
3028
elseif p == 2
3129
_norm2_back(x, y, Δy)
@@ -37,35 +35,52 @@ function rrule(
3735
_normInf_back(x, y, Δy)
3836
else
3937
_normp_back_x(x, p, y, Δy)
38+
end)
39+
, # in-place versions -- can be fixed when actually useful?
40+
dx -> if isempty(x) || p == 0
41+
dx
42+
elseif p == 2
43+
_norm2_back!(dx, x, y, Δy)
44+
elseif p == 1
45+
_norm1_back!(dx, x, y, Δy)
46+
elseif p == Inf
47+
dx .+= _normInf_back(x, y, Δy) # not really in-place! could perhaps be improved
48+
elseif p == -Inf
49+
dx .+= _normInf_back(x, y, Δy)
50+
else
51+
dx .+= _normp_back_x(x, p, y, Δy)
4052
end
41-
end
53+
)
4254
∂p = @thunk _normp_back_p(x, p, y, Δy)
4355
return (NO_FIELDS, ∂x, ∂p)
4456
end
45-
norm_pullback(::Zero) = (NO_FIELDS, Zero(), Zero())
46-
return y, norm_pullback
57+
norm_pullback_p(::Zero) = (NO_FIELDS, Zero(), Zero())
58+
return y, norm_pullback_p
4759
end
48-
function rrule(
49-
::typeof(norm),
50-
x::Union{StridedArray, LinearAlgebra.AbstractTriangular, Diagonal},
51-
)
60+
61+
function rrule(::typeof(norm), x::AbstractArray{<:Number})
5262
y = LinearAlgebra.norm(x)
53-
function norm_pullback(Δy)
54-
∂x = if isempty(x)
55-
zero.(x) .* (zero(y) * zero(real(Δy)))
56-
else
57-
_norm2_back(x, y, Δy)
58-
end
63+
function norm_pullback_2(Δy)
64+
∂x = InplaceableThunk(
65+
@thunk(if isempty(x)
66+
zero.(x) .* (zero(y) * zero(real(Δy)))
67+
else
68+
_norm2_back(x, y, Δy)
69+
end)
70+
,
71+
dx -> if isempty(x)
72+
dx
73+
else
74+
_norm2_back!(dx, x, y, Δy)
75+
end
76+
)
5977
return (NO_FIELDS, ∂x)
6078
end
61-
norm_pullback(::Zero) = (NO_FIELDS, Zero())
62-
return y, norm_pullback
79+
norm_pullback_2(::Zero) = (NO_FIELDS, Zero())
80+
return y, norm_pullback_2
6381
end
64-
function rrule(
65-
::typeof(norm),
66-
x::Union{LinearAlgebra.TransposeAbsVec, LinearAlgebra.AdjointAbsVec},
67-
p::Real,
68-
)
82+
83+
function rrule(::typeof(norm), x::LinearAlgebra.AdjOrTransAbsVec{<:Number}, p::Real)
6984
y, inner_pullback = rrule(norm, parent(x), p)
7085
function norm_pullback(Δy)
7186
(∂self, ∂x′, ∂p) = inner_pullback(Δy)
@@ -75,6 +90,7 @@ function rrule(
7590
end
7691
return y, norm_pullback
7792
end
93+
7894
function rrule(::typeof(norm), x::Number, p::Real)
7995
y = norm(x, p)
8096
function norm_pullback(Δy)
@@ -94,11 +110,7 @@ end
94110
##### `normp`
95111
#####
96112

97-
function rrule(
98-
::typeof(LinearAlgebra.normp),
99-
x::Union{StridedArray, LinearAlgebra.AbstractTriangular, Diagonal},
100-
p,
101-
)
113+
function rrule(::typeof(LinearAlgebra.normp), x::AbstractArray{<:Number}, p)
102114
y = LinearAlgebra.normp(x, p)
103115
function normp_pullback(Δy)
104116
∂x = @thunk _normp_back_x(x, p, y, Δy)
@@ -111,15 +123,24 @@ end
111123

112124
function _normp_back_x(x, p, y, Δy)
113125
c = real(Δy) / y
114-
∂x = similar(x)
115-
broadcast!(∂x, x) do xi
126+
∂x = map(x) do xi
116127
a = norm(xi)
117128
∂xi = xi * ((a / y)^(p - 2) * c)
118129
return ifelse(isfinite(∂xi), ∂xi, zero(∂xi))
119130
end
120131
return ∂x
121132
end
122133

134+
function _normp_back_x(x::WithSomeZeros, p, y, Δy) # Diagonal, UpperTriangular, etc.
135+
c = real(Δy) / y
136+
∂x_data = map(parent(x)) do xi
137+
a = norm(xi)
138+
∂xi = xi * ((a / y)^(p - 2) * c)
139+
return ifelse(isfinite(∂xi), ∂xi, zero(∂xi))
140+
end
141+
return withsomezeros_rewrap(x, ∂x_data)
142+
end
143+
123144
function _normp_back_p(x, p, y, Δy)
124145
y > 0 && isfinite(y) && !iszero(p) || return zero(real(Δy)) * zero(y) / one(p)
125146
s = sum(x) do xi
@@ -135,20 +156,14 @@ end
135156
##### `normMinusInf`/`normInf`
136157
#####
137158

138-
function rrule(
139-
::typeof(LinearAlgebra.normMinusInf),
140-
x::Union{StridedArray, LinearAlgebra.AbstractTriangular, Diagonal},
141-
)
159+
function rrule(::typeof(LinearAlgebra.normMinusInf), x::AbstractArray{<:Number})
142160
y = LinearAlgebra.normMinusInf(x)
143161
normMinusInf_pullback(Δy) = (NO_FIELDS, _normInf_back(x, y, Δy))
144162
normMinusInf_pullback(::Zero) = (NO_FIELDS, Zero())
145163
return y, normMinusInf_pullback
146164
end
147165

148-
function rrule(
149-
::typeof(LinearAlgebra.normInf),
150-
x::Union{StridedArray,LinearAlgebra.AbstractTriangular,Diagonal},
151-
)
166+
function rrule(::typeof(LinearAlgebra.normInf), x::AbstractArray{<:Number})
152167
y = LinearAlgebra.normInf(x)
153168
normInf_pullback(Δy) = (NO_FIELDS, _normInf_back(x, y, Δy))
154169
normInf_pullback(::Zero) = (NO_FIELDS, Zero())
@@ -172,19 +187,26 @@ end
172187
##### `norm1`
173188
#####
174189

175-
function rrule(
176-
::typeof(LinearAlgebra.norm1),
177-
x::Union{StridedArray,LinearAlgebra.AbstractTriangular,Diagonal},
178-
)
190+
function rrule(::typeof(LinearAlgebra.norm1), x::AbstractArray{<:Number})
179191
y = LinearAlgebra.norm1(x)
180-
norm1_pullback(Δy) = (NO_FIELDS, _norm1_back(x, y, Δy))
192+
norm1_pullback(Δy) = (NO_FIELDS, InplaceableThunk(
193+
@thunk(_norm1_back(x, y, Δy)),
194+
dx -> _norm1_back!(dx, x, y, Δy),
195+
))
181196
norm1_pullback(::Zero) = (NO_FIELDS, Zero())
182197
return y, norm1_pullback
183198
end
184199

185200
function _norm1_back(x, y, Δy)
186-
∂x = similar(x)
187-
∂x .= sign.(x) .* real(Δy)
201+
∂x = sign.(x) .* real(Δy)
202+
return ∂x
203+
end
204+
function _norm1_back(x::WithSomeZeros, y, Δy)
205+
∂x_data = sign.(parent(x)) .* real(Δy)
206+
return withsomezeros_rewrap(x, ∂x_data)
207+
end
208+
function _norm1_back!(∂x, x, y, Δy)
209+
∂x .+= sign.(x) .* real(Δy)
188210
return ∂x
189211
end
190212

@@ -197,12 +219,12 @@ function frule((_, Δx), ::typeof(LinearAlgebra.norm2), x)
197219
return y, _norm2_forward(x, Δx, y)
198220
end
199221

200-
function rrule(
201-
::typeof(LinearAlgebra.norm2),
202-
x::Union{StridedArray,LinearAlgebra.AbstractTriangular,Diagonal},
203-
)
222+
function rrule(::typeof(LinearAlgebra.norm2), x::AbstractArray{<:Number})
204223
y = LinearAlgebra.norm2(x)
205-
norm2_pullback(Δy) = (NO_FIELDS, _norm2_back(x, y, Δy))
224+
norm2_pullback(Δy) = (NO_FIELDS, InplaceableThunk(
225+
@thunk(_norm2_back(x, y, Δy)),
226+
dx -> _norm2_back!(dx, x, y, Δy),
227+
))
206228
norm2_pullback(::Zero) = (NO_FIELDS, Zero())
207229
return y, norm2_pullback
208230
end
@@ -212,16 +234,24 @@ function _norm2_forward(x, Δx, y)
212234
return ∂y
213235
end
214236
function _norm2_back(x, y, Δy)
215-
∂x = similar(x)
216-
∂x .= x .* (real(Δy) * pinv(y))
237+
∂x = x .* (real(Δy) * pinv(y))
217238
return ∂x
218239
end
240+
function _norm2_back(x::WithSomeZeros, y, Δy)
241+
T = typeof(one(eltype(x)) / one(real(eltype(Δy))))
242+
∂x_data = parent(x) .* (real(Δy) * pinv(y))
243+
return withsomezeros_rewrap(x, ∂x_data)
244+
end
245+
function _norm2_back!(∂x, x, y, Δy)
246+
∂x .+= x .* (real(Δy) * pinv(y))
247+
return ∂x # must return after mutating
248+
end
219249

220250
#####
221251
##### `normalize`
222252
#####
223253

224-
function rrule(::typeof(normalize), x::AbstractVector, p::Real)
254+
function rrule(::typeof(normalize), x::AbstractVector{<:Number}, p::Real)
225255
nrm, inner_pullback = rrule(norm, x, p)
226256
Ty = typeof(first(x) / nrm)
227257
y = copyto!(similar(x, Ty), x)
@@ -236,7 +266,8 @@ function rrule(::typeof(normalize), x::AbstractVector, p::Real)
236266
normalize_pullback(::Zero) = (NO_FIELDS, Zero(), Zero())
237267
return y, normalize_pullback
238268
end
239-
function rrule(::typeof(normalize), x::AbstractVector)
269+
270+
function rrule(::typeof(normalize), x::AbstractVector{<:Number})
240271
nrm = LinearAlgebra.norm2(x)
241272
Ty = typeof(first(x) / nrm)
242273
y = copyto!(similar(x, Ty), x)

src/rulesets/LinearAlgebra/utils.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,36 @@ Symmetric
4343
````
4444
"""
4545
_unionall_wrapper(::Type{T}) where {T} = T.name.wrapper
46+
47+
"""
48+
WithSomeZeros{T}
49+
50+
This is a union of LinearAlgebra types, all of which are partly structral zeros,
51+
with a simple backing array given by `parent(x)`. All have methods of `_rewrap`
52+
to re-create.
53+
54+
This exists to solve a type instability, as broadcasting for instance
55+
`λ .* Diagonal(rand(3))` gives a dense matrix when `x==Inf`.
56+
But `withsomezeros_rewrap(x, λ .* parent(x))` is type-stable.
57+
"""
58+
WithSomeZeros{T} = Union{
59+
Diagonal{T},
60+
UpperTriangular{T},
61+
UnitUpperTriangular{T},
62+
# UpperHessenberg{T}, # doesn't exist in Julia 1.0
63+
LowerTriangular{T},
64+
UnitLowerTriangular{T},
65+
}
66+
for S in [
67+
:Diagonal,
68+
:UpperTriangular,
69+
:UnitUpperTriangular,
70+
# :UpperHessenberg,
71+
:LowerTriangular,
72+
:UnitLowerTriangular,
73+
]
74+
@eval withsomezeros_rewrap(::$S, x) = $S(x)
75+
end
76+
77+
# Bidiagonal, Tridiagonal have more complicated storage.
78+
# AdjOrTransUpperOrUnitUpperTriangular would need adjoint(parent(parent()))

0 commit comments

Comments
 (0)