Skip to content

Commit fa4b93a

Browse files
Add rules for specialized norm functions and normalize (#226)
* Add generic normp * Reorganize * Add frule for normp * Make more readable * Use norm instead of abs * Add rules for normMinusInf * Also cover normInf * Finish early when passed a Zero * Add norm1 rules * Add norm2 rules * Separate pullbacks * Remove whitespace * Add comment * Only compute log if necessary * Simplify logic * Make normInf cotangent one-hot * Constrain rrules to arrays * Add comment * Split pullback functions * Don't broadcast over shared denom * Add frules for norm * Generalize rrule for norm * Reimplement rrule for number norm * Use correct variable name * Release type constraint * Add more special cases * Split forward passes into own functions * don't ignore (co)tangents on p * Add rules for normalize and normalize! * Special-case normalize with p=2 * Add overloads for transpose and adjoint * Generalize by going through norm * Bump ChainRulesCore compat * Don't assume has eltype * Import rand_tangent * Test normalize and normalize! * Don't unnecessarily thunk * Bump required version number To ensure we get TestIterator * Add tests for norm functions * Restrict types for rrules * Move norm functions to their own file * Remove frules for norm * Ensure real multiplied first * Lower precision of test * Revert accidental commit * Remove signatures with default args * Reuse variable * Reorganize normalize tests * Test scalar frule * Test transpose/adjoint rules * Add back in frules for norm2 * Add back rrule for norm no p * Test norm without p * Ensure normalize pulls back Zero * Apply suggestions from code review Co-authored-by: willtebbutt <wt0881@my.bristol.ac.uk> * Increase tolerance for infinite norms * Always define kwargs * Make normp pullback more stable for p = +/- inf * Test with higher power * Use norm2 forward for empty x * Combine checks * Test norm2 frule * Test norm for empty array * Test structured matrices * Update Project.toml Co-authored-by: willtebbutt <wt0881@my.bristol.ac.uk>
1 parent 6eb1027 commit fa4b93a

File tree

7 files changed

+411
-35
lines changed

7 files changed

+411
-35
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "0.7.34"
3+
version = "0.7.35"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -14,7 +14,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1414

1515
[compat]
1616
ChainRulesCore = "0.9.21"
17-
ChainRulesTestUtils = "0.5"
17+
ChainRulesTestUtils = "0.5.1"
1818
Compat = "3"
1919
FiniteDifferences = "0.11.4"
2020
Reexport = "0.2"

src/ChainRules.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ include("rulesets/Statistics/statistics.jl")
4343
include("rulesets/LinearAlgebra/utils.jl")
4444
include("rulesets/LinearAlgebra/blas.jl")
4545
include("rulesets/LinearAlgebra/dense.jl")
46+
include("rulesets/LinearAlgebra/norm.jl")
4647
include("rulesets/LinearAlgebra/structured.jl")
4748
include("rulesets/LinearAlgebra/factorization.jl")
4849

src/rulesets/LinearAlgebra/dense.jl

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -239,27 +239,3 @@ function rrule(::typeof(pinv), A::AbstractMatrix{T}; kwargs...) where {T}
239239
end
240240
return Y, pinv_pullback
241241
end
242-
243-
#####
244-
##### `norm`
245-
#####
246-
247-
function rrule(::typeof(norm), A::AbstractArray{<:Real}, p::Real=2)
248-
y = norm(A, p)
249-
function norm_pullback(ȳ)
250-
u = y^(1-p)
251-
∂A = @thunk ȳ .* u .* abs.(A).^p ./ A
252-
∂p = @thunk ȳ * (u * sum(a->abs(a)^p * log(abs(a)), A) - y * log(y)) / p
253-
(NO_FIELDS, ∂A, ∂p)
254-
end
255-
return y, norm_pullback
256-
end
257-
258-
function rrule(::typeof(norm), x::Real, p::Real=2)
259-
function norm_pullback(ȳ)
260-
∂x = @thunk ȳ * sign(x)
261-
∂p = @thunk zero(x) # TODO: should this be Zero()?
262-
(NO_FIELDS, ∂x, ∂p)
263-
end
264-
return norm(x, p), norm_pullback
265-
end

src/rulesets/LinearAlgebra/norm.jl

Lines changed: 241 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
#####
2+
##### `norm`
3+
#####
4+
5+
# Forward-mode rule for `norm(x)` called without `p`.
# Returns the norm `y` together with the tangent computed by `_norm2_forward`.
function frule((_, Δx), ::typeof(norm), x)
    y = norm(x)
    # Reuse `y` rather than evaluating `norm(x)` a second time.
    return y, _norm2_forward(x, Δx, y)
end
9+
# Forward-mode rule for `norm(x, p)` with a scalar `x`.
# Returns the norm `y` and its tangent; no tangent is consumed for `p`.
function frule((_, Δx), ::typeof(norm), x::Number, p::Real)
    y = norm(x, p)
    # A zero input tangent or `p == 0` gives a zero tangent built from the real parts
    # of `x` and `Δx`.
    if iszero(Δx) || iszero(p)
        return y, zero(real(x)) * zero(real(Δx))
    end
    direction = if x isa Real
        sign(x)
    else
        x * pinv(y)
    end
    return y, _realconjtimes(direction, Δx)
end
19+
20+
# Reverse-mode rule for `norm(x, p)` on strided, triangular and diagonal arrays.
# The cotangent of `x` is selected by the value of `p` and delegated to the matching
# `_norm*_back` helper; the cotangent of `p` always comes from `_normp_back_p`.
# Both cotangents are wrapped in thunks.
function rrule(
    ::typeof(norm),
    x::Union{StridedArray, LinearAlgebra.AbstractTriangular, Diagonal},
    p::Real,
)
    y = LinearAlgebra.norm(x, p)
    function norm_pullback(Δy)
        ∂x = @thunk begin
            if isempty(x) || p == 0
                # zero array with the element type of the would-be cotangent
                zero.(x) .* (zero(y) * zero(real(Δy)))
            elseif p == 2
                _norm2_back(x, y, Δy)
            elseif p == 1
                _norm1_back(x, y, Δy)
            elseif p == Inf || p == -Inf
                # both infinite norms share the same backward helper
                _normInf_back(x, y, Δy)
            else
                _normp_back_x(x, p, y, Δy)
            end
        end
        ∂p = @thunk _normp_back_p(x, p, y, Δy)
        return (NO_FIELDS, ∂x, ∂p)
    end
    norm_pullback(::Zero) = (NO_FIELDS, Zero(), Zero())
    return y, norm_pullback
end
48+
# Reverse-mode rule for `norm(x)` (no `p`) on strided, triangular and diagonal arrays.
# Empty inputs get a zero array as cotangent; everything else goes through
# `_norm2_back`.
function rrule(
    ::typeof(norm),
    x::Union{StridedArray, LinearAlgebra.AbstractTriangular, Diagonal},
)
    y = LinearAlgebra.norm(x)
    function norm_pullback(Δy)
        isempty(x) && return (NO_FIELDS, zero.(x) .* (zero(y) * zero(real(Δy))))
        return (NO_FIELDS, _norm2_back(x, y, Δy))
    end
    norm_pullback(::Zero) = (NO_FIELDS, Zero())
    return y, norm_pullback
end
64+
# Reverse-mode rule for `norm(x, p)` where `x` is a transposed or adjoint vector.
# The work is delegated to the rule for `parent(x)`; the resulting cotangent is then
# wrapped again with `transpose` or `adjoint` to match `x`.
function rrule(
    ::typeof(norm),
    x::Union{LinearAlgebra.TransposeAbsVec, LinearAlgebra.AdjointAbsVec},
    p::Real,
)
    y, parent_pullback = rrule(norm, parent(x), p)
    rewrap = x isa Transpose ? transpose : adjoint
    function norm_pullback(Δy)
        (∂self, ∂parent, ∂p) = parent_pullback(Δy)
        ∂x = @thunk rewrap(unthunk(∂parent))
        return (∂self, ∂x, ∂p)
    end
    return y, norm_pullback
end
78+
# Reverse-mode rule for `norm(x, p)` with a scalar `x`.
# The cotangent returned for `p` is always `Zero()`.
function rrule(::typeof(norm), x::Number, p::Real)
    y = norm(x, p)
    function norm_pullback(Δy)
        # zero incoming cotangent or `p == 0`: return a zero of the matching type
        if iszero(Δy) || iszero(p)
            return (NO_FIELDS, zero(x) * zero(real(Δy)), Zero())
        end
        direction = x isa Real ? sign(x) : x * pinv(y)
        return (NO_FIELDS, direction * real(Δy), Zero())
    end
    norm_pullback(::Zero) = (NO_FIELDS, Zero(), Zero())
    return y, norm_pullback
end
92+
93+
#####
94+
##### `normp`
95+
#####
96+
97+
# Reverse-mode rule for `LinearAlgebra.normp(x, p)`.
# Cotangents for `x` and `p` are thunks around `_normp_back_x` and `_normp_back_p`.
function rrule(
    ::typeof(LinearAlgebra.normp),
    x::Union{StridedArray, LinearAlgebra.AbstractTriangular, Diagonal},
    p,
)
    y = LinearAlgebra.normp(x, p)
    function normp_pullback(Δy)
        return (
            NO_FIELDS,
            @thunk(_normp_back_x(x, p, y, Δy)),
            @thunk(_normp_back_p(x, p, y, Δy)),
        )
    end
    normp_pullback(::Zero) = (NO_FIELDS, Zero(), Zero())
    return y, normp_pullback
end
111+
112+
# Cotangent of `x` for the `p`-norm `y` of `x`, given the output cotangent `Δy`.
# Only the real part of `Δy` is used. Any element whose result is not finite is
# replaced by a zero of the same type.
function _normp_back_x(x, p, y, Δy)
    scale = real(Δy) / y
    function elementgrad(xi)
        g = xi * ((norm(xi) / y)^(p - 2) * scale)
        return isfinite(g) ? g : zero(g)
    end
    return elementgrad.(x)
end
121+
122+
# Cotangent of `p` for the `p`-norm `y` of `x`, given the output cotangent `Δy`.
# Returns a zero when `y` is not strictly positive, `y` is not finite, or `p` is zero.
# Summands that are not finite are replaced by zero before summing.
function _normp_back_p(x, p, y, Δy)
    if !(y > 0 && isfinite(y) && !iszero(p))
        return zero(real(Δy)) * zero(y) / one(p)
    end
    function weightedlog(xi)
        a = norm(xi)
        term = (a / y)^(p - 1) * a * log(a)
        return isfinite(term) ? term : zero(term)
    end
    total = sum(weightedlog, x)
    return real(Δy) * (total - y * log(y)) / p
end
132+
133+
#####
134+
##### `normMinusInf`/`normInf`
135+
#####
136+
137+
# Reverse-mode rule for `LinearAlgebra.normMinusInf`.
# The input cotangent is produced by `_normInf_back`.
function rrule(
    ::typeof(LinearAlgebra.normMinusInf),
    x::Union{StridedArray, LinearAlgebra.AbstractTriangular, Diagonal},
)
    y = LinearAlgebra.normMinusInf(x)
    function normMinusInf_pullback(Δy)
        ∂x = _normInf_back(x, y, Δy)
        return (NO_FIELDS, ∂x)
    end
    normMinusInf_pullback(::Zero) = (NO_FIELDS, Zero())
    return y, normMinusInf_pullback
end
146+
147+
# Reverse-mode rule for `LinearAlgebra.normInf`.
# The input cotangent is produced by `_normInf_back`.
function rrule(
    ::typeof(LinearAlgebra.normInf),
    x::Union{StridedArray,LinearAlgebra.AbstractTriangular,Diagonal},
)
    y = LinearAlgebra.normInf(x)
    function normInf_pullback(Δy)
        ∂x = _normInf_back(x, y, Δy)
        return (NO_FIELDS, ∂x)
    end
    normInf_pullback(::Zero) = (NO_FIELDS, Zero())
    return y, normInf_pullback
end
156+
157+
# Cotangent of `x` for an infinity-type norm `y` of `x`, given output cotangent `Δy`.
# The result is zero everywhere except at one index whose element norm equals `y`.
# Throws `ArgumentError` when no element of `x` has norm exactly `y`.
function _normInf_back(x, y, Δy)
    Δreal = real(Δy)
    T = typeof(zero(float(eltype(x))) * zero(Δreal))
    ∂x = similar(x, T)
    fill!(∂x, 0)
    # When several elements tie for the norm, only the last matching index receives
    # the cotangent.
    idx = findlast(xi -> norm(xi) == y, x)
    if idx === nothing
        throw(ArgumentError("y is not the correct norm of x"))
    end
    @inbounds ∂x[idx] = sign(x[idx]) * Δreal
    return ∂x
end
169+
170+
#####
171+
##### `norm1`
172+
#####
173+
174+
# Reverse-mode rule for `LinearAlgebra.norm1`.
# The input cotangent is produced by `_norm1_back`.
function rrule(
    ::typeof(LinearAlgebra.norm1),
    x::Union{StridedArray,LinearAlgebra.AbstractTriangular,Diagonal},
)
    y = LinearAlgebra.norm1(x)
    function norm1_pullback(Δy)
        ∂x = _norm1_back(x, y, Δy)
        return (NO_FIELDS, ∂x)
    end
    norm1_pullback(::Zero) = (NO_FIELDS, Zero())
    return y, norm1_pullback
end
183+
184+
# Cotangent of `x` for the 1-norm: the elementwise sign of `x` scaled by the real part
# of `Δy`. The argument `y` is accepted but not used.
function _norm1_back(x, y, Δy)
    Δreal = real(Δy)
    return sign.(x) .* Δreal
end
185+
186+
#####
187+
##### `norm2`
188+
#####
189+
190+
# Forward-mode rule for `LinearAlgebra.norm2`.
# Returns the norm and the tangent computed by `_norm2_forward`.
function frule((_, Δx), ::typeof(LinearAlgebra.norm2), x)
    nrm = LinearAlgebra.norm2(x)
    ∂nrm = _norm2_forward(x, Δx, nrm)
    return nrm, ∂nrm
end
194+
195+
# Reverse-mode rule for `LinearAlgebra.norm2`.
# The input cotangent is produced by `_norm2_back`.
function rrule(
    ::typeof(LinearAlgebra.norm2),
    x::Union{StridedArray,LinearAlgebra.AbstractTriangular,Diagonal},
)
    y = LinearAlgebra.norm2(x)
    function norm2_pullback(Δy)
        ∂x = _norm2_back(x, y, Δy)
        return (NO_FIELDS, ∂x)
    end
    norm2_pullback(::Zero) = (NO_FIELDS, Zero())
    return y, norm2_pullback
end
204+
205+
# Tangent of the 2-norm `y` of `x` in direction `Δx`: the real part of `dot(x, Δx)`
# multiplied by `pinv(y)`.
_norm2_forward(x, Δx, y) = real(dot(x, Δx)) * pinv(y)
209+
# Cotangent of `x` for the 2-norm `y`: `x` scaled by the real part of `Δy` times
# `pinv(y)`.
function _norm2_back(x, y, Δy)
    scale = real(Δy) * pinv(y)
    return x .* scale
end
210+
211+
#####
212+
##### `normalize`
213+
#####
214+
215+
# Reverse-mode rule for `normalize(x, p)`, built on top of the `norm` rule for the
# same `p`. Returns the normalized copy `y` and a pullback producing cotangents for
# `x` and `p`.
function rrule(::typeof(normalize), x::AbstractVector, p::Real)
    nrm, inner_pullback = rrule(norm, x, p)
    # `first(x)` throws on an empty vector, so derive the output element type from
    # `eltype(x)` in that case instead of erroring out.
    Ty = isempty(x) ? typeof(zero(eltype(x)) / nrm) : typeof(first(x) / nrm)
    y = copyto!(similar(x, Ty), x)
    LinearAlgebra.__normalize!(y, nrm)
    function normalize_pullback(Δy)
        invnrm = pinv(nrm)
        # cotangent flowing into the norm, pulled back through the `norm` rule
        ∂nrm = -dot(y, Δy) * invnrm
        (_, ∂xnorm, ∂p) = inner_pullback(∂nrm)
        # add the direct contribution of `x` through the scaling by `invnrm`
        ∂x = @thunk unthunk(∂xnorm) .+ Δy .* invnrm
        return (NO_FIELDS, ∂x, ∂p)
    end
    normalize_pullback(::Zero) = (NO_FIELDS, Zero(), Zero())
    return y, normalize_pullback
end
230+
# Reverse-mode rule for `normalize(x)` (2-norm). Returns the normalized copy `y` and a
# pullback producing the cotangent for `x`.
function rrule(::typeof(normalize), x::AbstractVector)
    # Empty input: `first(x)` throws, so fall back to `norm(x)` and `eltype(x)`.
    # NOTE(review): `norm2` is assumed not to accept empty collections — confirm.
    # Non-empty input follows the same path as before.
    nrm = isempty(x) ? norm(x) : LinearAlgebra.norm2(x)
    Ty = isempty(x) ? typeof(zero(eltype(x)) / nrm) : typeof(first(x) / nrm)
    y = copyto!(similar(x, Ty), x)
    LinearAlgebra.__normalize!(y, nrm)
    function normalize_pullback(Δy)
        # remove the component of `Δy` along `y`, then rescale by `pinv(nrm)`
        ∂x = (Δy .- real(dot(y, Δy)) .* y) .* pinv(nrm)
        return (NO_FIELDS, ∂x)
    end
    normalize_pullback(::Zero) = (NO_FIELDS, Zero())
    return y, normalize_pullback
end

test/rulesets/LinearAlgebra/dense.jl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -130,12 +130,4 @@
130130
frule_test(tr, (randn(N, N), randn(N, N)))
131131
rrule_test(tr, randn(), (randn(N, N), randn(N, N)))
132132
end
133-
@testset "norm" begin
134-
for dims in [(), (5,), (3, 2), (7, 3, 2)]
135-
A = randn(dims...)
136-
p = randn()
137-
ȳ = randn()
138-
rrule_test(norm, ȳ, (A, randn(dims...)), (p, randn()))
139-
end
140-
end
141133
end

0 commit comments

Comments
 (0)