Skip to content

Commit 889ed20

Browse files
authored
Merge branch 'master' into patch-1
2 parents bb53229 + eb10848 commit 889ed20

File tree

11 files changed

+536
-157
lines changed

11 files changed

+536
-157
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "0.7.35"
3+
version = "0.7.37"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -14,7 +14,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1414

1515
[compat]
1616
ChainRulesCore = "0.9.21"
17-
ChainRulesTestUtils = "0.5"
17+
ChainRulesTestUtils = "0.5.1"
1818
Compat = "3"
1919
FiniteDifferences = "0.11.4"
2020
Reexport = "0.2"

src/ChainRules.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,9 @@ include("rulesets/Statistics/statistics.jl")
4343
include("rulesets/LinearAlgebra/utils.jl")
4444
include("rulesets/LinearAlgebra/blas.jl")
4545
include("rulesets/LinearAlgebra/dense.jl")
46+
include("rulesets/LinearAlgebra/norm.jl")
4647
include("rulesets/LinearAlgebra/structured.jl")
48+
include("rulesets/LinearAlgebra/symmetric.jl")
4749
include("rulesets/LinearAlgebra/factorization.jl")
4850

4951
include("rulesets/Random/random.jl")

src/rulesets/LinearAlgebra/dense.jl

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -239,27 +239,3 @@ function rrule(::typeof(pinv), A::AbstractMatrix{T}; kwargs...) where {T}
239239
end
240240
return Y, pinv_pullback
241241
end
242-
243-
#####
244-
##### `norm`
245-
#####
246-
247-
# Reverse-mode rule for `norm(A, p)` on real arrays.
# Returns the primal `y = norm(A, p)` and a pullback yielding cotangents for `A` and `p`.
function rrule(::typeof(norm), A::AbstractArray{<:Real}, p::Real=2)
    y = norm(A, p)
    function norm_pullback(ȳ)
        u = y^(1-p)
        # the incoming cotangent `ȳ` scales both thunks
        # NOTE(review): entries of `A` equal to zero evaluate `0 / 0` here for positive `p`
        ∂A = @thunk ȳ .* u .* abs.(A).^p ./ A
        ∂p = @thunk ȳ * (u * sum(a->abs(a)^p * log(abs(a)), A) - y * log(y)) / p
        return (NO_FIELDS, ∂A, ∂p)
    end
    return y, norm_pullback
end
257-
258-
# Reverse-mode rule for `norm(x, p)` with a real scalar `x`.
# The cotangent for `x` is the incoming cotangent `ȳ` times `sign(x)`.
function rrule(::typeof(norm), x::Real, p::Real=2)
    function norm_pullback(ȳ)
        ∂x = @thunk ȳ * sign(x)
        ∂p = @thunk zero(x) # TODO: should this be Zero()?
        return (NO_FIELDS, ∂x, ∂p)
    end
    return norm(x, p), norm_pullback
end

src/rulesets/LinearAlgebra/norm.jl

Lines changed: 241 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
#####
2+
##### `norm`
3+
#####
4+
5+
# Forward-mode rule for `norm(x)` (no order argument).
# Returns the primal `y = norm(x)` and the pushforward of the tangent `Δx`.
function frule((_, Δx), ::typeof(norm), x)
    y = norm(x)
    # reuse `y` instead of evaluating `norm(x)` a second time
    return y, _norm2_forward(x, Δx, y)
end
9+
# Forward-mode rule for `norm(x, p)` with a scalar `x`.
# A zero tangent or `p == 0` gives a zero pushforward of the matching real type.
function frule((_, Δx), ::typeof(norm), x::Number, p::Real)
    y = norm(x, p)
    if iszero(Δx) || iszero(p)
        return y, zero(real(x)) * zero(real(Δx))
    end
    signx = x isa Real ? sign(x) : x * pinv(y)
    return y, _realconjtimes(signx, Δx)
end
19+
20+
# Reverse-mode rule for `norm(x, p)` on strided, triangular and diagonal arrays.
# The cotangent for `x` is chosen by the value of `p`; both cotangents are lazy thunks.
function rrule(
    ::typeof(norm),
    x::Union{StridedArray, LinearAlgebra.AbstractTriangular, Diagonal},
    p::Real,
)
    y = LinearAlgebra.norm(x, p)
    function norm_pullback(Δy)
        ∂x = Thunk() do
            # empty input or `p == 0`: zero cotangent of the right shape and type
            (isempty(x) || p == 0) && return zero.(x) .* (zero(y) * zero(real(Δy)))
            p == 2 && return _norm2_back(x, y, Δy)
            p == 1 && return _norm1_back(x, y, Δy)
            # `Inf` and `-Inf` share one pullback helper
            p == Inf && return _normInf_back(x, y, Δy)
            p == -Inf && return _normInf_back(x, y, Δy)
            return _normp_back_x(x, p, y, Δy)
        end
        ∂p = @thunk _normp_back_p(x, p, y, Δy)
        return (NO_FIELDS, ∂x, ∂p)
    end
    norm_pullback(::Zero) = (NO_FIELDS, Zero(), Zero())
    return y, norm_pullback
end
48+
# Reverse-mode rule for `norm(x)` on strided, triangular and diagonal arrays.
function rrule(
    ::typeof(norm),
    x::Union{StridedArray, LinearAlgebra.AbstractTriangular, Diagonal},
)
    y = LinearAlgebra.norm(x)
    function norm_pullback(Δy)
        # empty input: zero cotangent of the right shape and type
        isempty(x) && return (NO_FIELDS, zero.(x) .* (zero(y) * zero(real(Δy))))
        return (NO_FIELDS, _norm2_back(x, y, Δy))
    end
    norm_pullback(::Zero) = (NO_FIELDS, Zero())
    return y, norm_pullback
end
64+
# Reverse-mode rule for `norm(x, p)` where `x` is a transposed or adjoint vector.
# Delegates to the rule for `parent(x)` and re-wraps the resulting cotangent.
function rrule(
    ::typeof(norm),
    x::Union{LinearAlgebra.TransposeAbsVec, LinearAlgebra.AdjointAbsVec},
    p::Real,
)
    y, inner_pullback = rrule(norm, parent(x), p)
    function norm_pullback(Δy)
        ∂self, ∂parent, ∂p = inner_pullback(Δy)
        # apply the same wrapper that `x` carries to the parent's cotangent
        rewrap = x isa Transpose ? transpose : adjoint
        ∂x = @thunk rewrap(unthunk(∂parent))
        return (∂self, ∂x, ∂p)
    end
    return y, norm_pullback
end
78+
# Reverse-mode rule for `norm(x, p)` with a scalar `x`.
# The cotangent for `p` is always `Zero()`.
function rrule(::typeof(norm), x::Number, p::Real)
    y = norm(x, p)
    function norm_pullback(Δy)
        if iszero(Δy) || iszero(p)
            return (NO_FIELDS, zero(x) * zero(real(Δy)), Zero())
        end
        signx = x isa Real ? sign(x) : x * pinv(y)
        return (NO_FIELDS, signx * real(Δy), Zero())
    end
    norm_pullback(::Zero) = (NO_FIELDS, Zero(), Zero())
    return y, norm_pullback
end
92+
93+
#####
94+
##### `normp`
95+
#####
96+
97+
# Reverse-mode rule for `LinearAlgebra.normp(x, p)`.
# Both cotangents are deferred as thunks.
function rrule(
    ::typeof(LinearAlgebra.normp),
    x::Union{StridedArray, LinearAlgebra.AbstractTriangular, Diagonal},
    p,
)
    y = LinearAlgebra.normp(x, p)
    function normp_pullback(Δy)
        x̄ = @thunk _normp_back_x(x, p, y, Δy)
        p̄ = @thunk _normp_back_p(x, p, y, Δy)
        return (NO_FIELDS, x̄, p̄)
    end
    normp_pullback(::Zero) = (NO_FIELDS, Zero(), Zero())
    return y, normp_pullback
end
111+
112+
# Cotangent of `x` for the p-norm `y`, given the output cotangent `Δy`.
# Each element is `xi * (norm(xi) / y)^(p - 2) * real(Δy) / y`; non-finite results
# are replaced by zero.
function _normp_back_x(x, p, y, Δy)
    scale = real(Δy) / y
    function elementgrad(xi)
        g = xi * ((norm(xi) / y)^(p - 2) * scale)
        return ifelse(isfinite(g), g, zero(g))
    end
    return elementgrad.(x)
end
121+
122+
# Cotangent of the order `p` for the p-norm `y`, given the output cotangent `Δy`.
# Returns a zero when `y` is not positive and finite or when `p` is zero.
function _normp_back_p(x, p, y, Δy)
    if !(y > 0 && isfinite(y) && !iszero(p))
        return zero(real(Δy)) * zero(y) / one(p)
    end
    weighted = sum(x) do xi
        a = norm(xi)
        term = (a / y)^(p - 1) * a * log(a)
        # non-finite terms (e.g. from `log(0)`) contribute zero
        return ifelse(isfinite(term), term, zero(term))
    end
    return real(Δy) * (weighted - y * log(y)) / p
end
132+
133+
#####
134+
##### `normMinusInf`/`normInf`
135+
#####
136+
137+
# Reverse-mode rule for `LinearAlgebra.normMinusInf(x)`.
function rrule(
    ::typeof(LinearAlgebra.normMinusInf),
    x::Union{StridedArray, LinearAlgebra.AbstractTriangular, Diagonal},
)
    y = LinearAlgebra.normMinusInf(x)
    function normMinusInf_pullback(Δy)
        return (NO_FIELDS, _normInf_back(x, y, Δy))
    end
    normMinusInf_pullback(::Zero) = (NO_FIELDS, Zero())
    return y, normMinusInf_pullback
end
146+
147+
# Reverse-mode rule for `LinearAlgebra.normInf(x)`.
function rrule(
    ::typeof(LinearAlgebra.normInf),
    x::Union{StridedArray,LinearAlgebra.AbstractTriangular,Diagonal},
)
    y = LinearAlgebra.normInf(x)
    function normInf_pullback(Δy)
        return (NO_FIELDS, _normInf_back(x, y, Δy))
    end
    normInf_pullback(::Zero) = (NO_FIELDS, Zero())
    return y, normInf_pullback
end
156+
157+
# Cotangent of `x` for the `Inf`/`-Inf` norms: zero everywhere except at the last
# element whose norm equals `y`.
# Throws `ArgumentError` when no element of `x` has norm `y`.
function _normInf_back(x, y, Δy)
    Δu = real(Δy)
    T = typeof(zero(float(eltype(x))) * zero(Δu))
    x̄ = fill!(similar(x, T), 0)
    # if multiple `xi`s have the exact same norm, then they must have been identically
    # produced, e.g. with `fill`. So we set only one to be non-zero.
    # we choose last index to match the `frule`.
    hit = findlast(xi -> norm(xi) == y, x)
    if hit === nothing
        throw(ArgumentError("y is not the correct norm of x"))
    end
    @inbounds x̄[hit] = sign(x[hit]) * Δu
    return x̄
end
169+
170+
#####
171+
##### `norm1`
172+
#####
173+
174+
# Reverse-mode rule for `LinearAlgebra.norm1(x)`.
function rrule(
    ::typeof(LinearAlgebra.norm1),
    x::Union{StridedArray,LinearAlgebra.AbstractTriangular,Diagonal},
)
    y = LinearAlgebra.norm1(x)
    function norm1_pullback(Δy)
        return (NO_FIELDS, _norm1_back(x, y, Δy))
    end
    norm1_pullback(::Zero) = (NO_FIELDS, Zero())
    return y, norm1_pullback
end
183+
184+
# Cotangent of `x` for the 1-norm: elementwise sign scaled by the real part of `Δy`.
function _norm1_back(x, y, Δy)
    Δu = real(Δy)
    return sign.(x) .* Δu
end
185+
186+
#####
187+
##### `norm2`
188+
#####
189+
190+
# Forward-mode rule for `LinearAlgebra.norm2(x)`.
function frule((_, Δx), ::typeof(LinearAlgebra.norm2), x)
    y = LinearAlgebra.norm2(x)
    ∂y = _norm2_forward(x, Δx, y)
    return y, ∂y
end
194+
195+
# Reverse-mode rule for `LinearAlgebra.norm2(x)`.
function rrule(
    ::typeof(LinearAlgebra.norm2),
    x::Union{StridedArray,LinearAlgebra.AbstractTriangular,Diagonal},
)
    y = LinearAlgebra.norm2(x)
    function norm2_pullback(Δy)
        return (NO_FIELDS, _norm2_back(x, y, Δy))
    end
    norm2_pullback(::Zero) = (NO_FIELDS, Zero())
    return y, norm2_pullback
end
204+
205+
# Pushforward of the 2-norm: `real(dot(x, Δx))` scaled by `pinv(y)`.
_norm2_forward(x, Δx, y) = real(dot(x, Δx)) * pinv(y)
209+
# Cotangent of `x` for the 2-norm: `x` scaled by `real(Δy) * pinv(y)`.
function _norm2_back(x, y, Δy)
    scale = real(Δy) * pinv(y)
    return x .* scale
end
210+
211+
#####
212+
##### `normalize`
213+
#####
214+
215+
# Reverse-mode rule for `normalize(x, p)`.
# The primal is rebuilt from the `norm` rule so that its pullback can be reused.
function rrule(::typeof(normalize), x::AbstractVector, p::Real)
    nrm, inner_pullback = rrule(norm, x, p)
    # `first(x)` throws on an empty vector, so derive the type from `eltype` there
    Ty = isempty(x) ? typeof(zero(eltype(x)) / nrm) : typeof(first(x) / nrm)
    y = copyto!(similar(x, Ty), x)
    LinearAlgebra.__normalize!(y, nrm)
    function normalize_pullback(Δy)
        invnrm = pinv(nrm)
        # cotangent flowing into the norm, pulled back through the `norm` rule
        ∂nrm = -dot(y, Δy) * invnrm
        (_, ∂xnorm, ∂p) = inner_pullback(∂nrm)
        ∂x = @thunk unthunk(∂xnorm) .+ Δy .* invnrm
        return (NO_FIELDS, ∂x, ∂p)
    end
    normalize_pullback(::Zero) = (NO_FIELDS, Zero(), Zero())
    return y, normalize_pullback
end
230+
# Reverse-mode rule for `normalize(x)` (2-norm).
function rrule(::typeof(normalize), x::AbstractVector)
    nrm = LinearAlgebra.norm2(x)
    # `first(x)` throws on an empty vector, so derive the type from `eltype` there
    Ty = isempty(x) ? typeof(zero(eltype(x)) / nrm) : typeof(first(x) / nrm)
    y = copyto!(similar(x, Ty), x)
    LinearAlgebra.__normalize!(y, nrm)
    function normalize_pullback(Δy)
        # remove the component of `Δy` along `y`, then rescale by `pinv(nrm)`
        ∂x = (Δy .- real(dot(y, Δy)) .* y) .* pinv(nrm)
        return (NO_FIELDS, ∂x)
    end
    normalize_pullback(::Zero) = (NO_FIELDS, Zero())
    return y, normalize_pullback
end

src/rulesets/LinearAlgebra/structured.jl

Lines changed: 0 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -86,86 +86,6 @@ function rrule(::typeof(*), D::Diagonal{<:Real}, V::AbstractVector{<:Real})
8686
return D * V, times_pullback
8787
end
8888

89-
#####
90-
##### `Symmetric`/`Hermitian`
91-
#####
92-
93-
# Forward-mode rule for the `Symmetric`/`Hermitian` constructors: the tangent is
# wrapped with the same constructor and `uplo`.
function frule((_, ΔA, _), T::Type{<:LinearAlgebra.HermOrSym}, A::AbstractMatrix, uplo)
    Ω = T(A, uplo)
    ∂Ω = T(ΔA, uplo)
    return Ω, ∂Ω
end
96-
97-
# Reverse-mode rule for the `Symmetric`/`Hermitian` constructors.
# `uplo` receives `DoesNotExist()` as its cotangent.
function rrule(T::Type{<:LinearAlgebra.HermOrSym}, A::AbstractMatrix, uplo)
    Ω = T(A, uplo)
    function HermOrSym_pullback(ΔΩ)
        ∂A = _symherm_back(T, ΔΩ, Ω.uplo)
        return (NO_FIELDS, ∂A, DoesNotExist())
    end
    return Ω, HermOrSym_pullback
end
104-
105-
# Forward-mode rules for converting a `Symmetric`/`Hermitian` matrix to a dense one.
function frule((_, ΔA), TM::Type{<:Matrix}, A::LinearAlgebra.HermOrSym)
    Ω = TM(A)
    ∂Ω = TM(_symherm_forward(A, ΔA))
    return Ω, ∂Ω
end
function frule((_, ΔA), ::Type{Array}, A::LinearAlgebra.HermOrSym)
    Ω = Array(A)
    ∂Ω = Array(_symherm_forward(A, ΔA))
    return Ω, ∂Ω
end
111-
112-
# Reverse-mode rules for converting a `Symmetric`/`Hermitian` matrix to a dense one.
# The cotangent is wrapped in the same wrapper type and `uplo` as `A`.
function rrule(TM::Type{<:Matrix}, A::LinearAlgebra.HermOrSym)
    function Matrix_pullback(ΔΩ)
        uplo = A.uplo
        wrapper = _symhermtype(A)
        Twrapped = wrapper{eltype(ΔΩ),typeof(ΔΩ)}
        return NO_FIELDS, Twrapped(_symherm_back(A, ΔΩ, uplo), uplo)
    end
    return TM(A), Matrix_pullback
end
function rrule(::Type{Array}, A::LinearAlgebra.HermOrSym)
    return rrule(Matrix, A)
end
123-
124-
# Get type (Symmetric or Hermitian) from type or matrix
125-
function _symhermtype(::Type{<:Symmetric})
    return Symmetric
end
function _symhermtype(::Type{<:Hermitian})
    return Hermitian
end
# fallback: dispatch on the type of the given value
function _symhermtype(A)
    return _symhermtype(typeof(A))
end
128-
129-
# for Ω = Matrix(A::HermOrSym), push forward ΔA to get ∂Ω
130-
# Returns `ΔA` unchanged when it already has the wrapper type of `A`; otherwise
# wraps it in that type using `A.uplo`.
function _symherm_forward(A, ΔA)
    TA = _symhermtype(A)
    ΔA isa TA && return ΔA
    return TA{eltype(ΔA),typeof(ΔA)}(ΔA, A.uplo)
end
138-
139-
# for Ω = HermOrSym(A, uplo), pull back ΔΩ to get ∂A
140-
function _symherm_back(::Type{<:Symmetric}, ΔΩ, uplo)
    return _symmetric_back(ΔΩ, uplo)
end
# a real-valued cotangent of a `Hermitian` takes the symmetric path
_symherm_back(::Type{<:Hermitian}, ΔΩ::AbstractMatrix{<:Real}, uplo) =
    _symmetric_back(ΔΩ, uplo)
function _symherm_back(::Type{<:Hermitian}, ΔΩ, uplo)
    return _hermitian_back(ΔΩ, uplo)
end
# fallback: dispatch on the type of the given value
function _symherm_back(Ω, ΔΩ, uplo)
    return _symherm_back(typeof(Ω), ΔΩ, uplo)
end
146-
147-
# Folds the cotangent `ΔΩ` onto the triangle selected by `uplo`: the other
# triangle is transposed and added, and the diagonal is subtracted once.
function _symmetric_back(ΔΩ, uplo)
    lower = LowerTriangular(ΔΩ)
    upper = UpperTriangular(ΔΩ)
    diagpart = Diagonal(ΔΩ)
    if uplo == 'U'
        return upper .+ transpose(lower) - diagpart
    end
    return lower .+ transpose(upper) - diagpart
end
function _symmetric_back(ΔΩ::Diagonal, uplo)
    return ΔΩ
end
function _symmetric_back(ΔΩ::UpperTriangular, uplo)
    return Matrix(uplo == 'U' ? ΔΩ : transpose(ΔΩ))
end
function _symmetric_back(ΔΩ::LowerTriangular, uplo)
    return Matrix(uplo == 'U' ? transpose(ΔΩ) : ΔΩ)
end
154-
155-
# Folds the cotangent `ΔΩ` onto the triangle selected by `uplo`: the other
# triangle is conjugate-transposed and added, and the real diagonal is subtracted once.
function _hermitian_back(ΔΩ, uplo)
    lower = LowerTriangular(ΔΩ)
    upper = UpperTriangular(ΔΩ)
    realdiag = real.(Diagonal(ΔΩ))
    if uplo == 'U'
        return upper .+ lower' - realdiag
    end
    return lower .+ upper' - realdiag
end
function _hermitian_back(ΔΩ::Diagonal, uplo)
    return real.(ΔΩ)
end
function _hermitian_back(ΔΩ::LinearAlgebra.AbstractTriangular, uplo)
    ∂UL = ΔΩ .- Diagonal(_extract_imag.(diag(ΔΩ)))
    wantupper = uplo == 'U'
    if istriu(ΔΩ)
        return Matrix(wantupper ? ∂UL : ∂UL')
    end
    return Matrix(wantupper ? ∂UL' : ∂UL)
end
168-
16989
#####
17090
##### `Adjoint`
17191
#####

0 commit comments

Comments
 (0)