Skip to content

Commit 76ef95c

Browse files
sethaxen and oxinabox authored
Add rules for solutions to Sylvester and Lyapunov equations (#384)
* Add rules for sylvester solution * Add rules for lyap solution * Add trsyl! rule * Increment version number * Remove overflowing test FiniteDifferences only consistently works during overflow, but locally it works sometimes, which is enough to know that the scaling is correctly handled. * Increment version number * Update test/rulesets/LinearAlgebra/lapack.jl Co-authored-by: Lyndon White <oxinabox@ucc.asn.au> Co-authored-by: Lyndon White <oxinabox@ucc.asn.au>
1 parent 99c58a9 commit 76ef95c

File tree

7 files changed

+156
-1
lines changed

7 files changed

+156
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "0.7.53"
3+
version = "0.7.54"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/ChainRules.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ include("rulesets/Statistics/statistics.jl")
4242

4343
include("rulesets/LinearAlgebra/utils.jl")
4444
include("rulesets/LinearAlgebra/blas.jl")
45+
include("rulesets/LinearAlgebra/lapack.jl")
4546
include("rulesets/LinearAlgebra/dense.jl")
4647
include("rulesets/LinearAlgebra/norm.jl")
4748
include("rulesets/LinearAlgebra/matfun.jl")

src/rulesets/LinearAlgebra/dense.jl

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,3 +239,86 @@ function rrule(::typeof(pinv), A::AbstractMatrix{T}; kwargs...) where {T}
239239
end
240240
return Y, pinv_pullback
241241
end
242+
243+
#####
244+
##### `sylvester`
245+
#####
246+
247+
# included because the primal uses `schur`, for which we don't have a rule
248+
249+
# Forward-mode rule for `sylvester(A, B, C)`.
#
# The primal is recomputed here from the Schur factorizations of `A` and `B` so that the
# same factors can be reused for the tangent solve. `LAPACK.trsyl!` overwrites its
# right-hand side with the solution and returns a `scale` factor; both solves divide it
# back out (and flip the sign) with `rmul!(…, -inv(scale))`.
function frule(
    (_, ΔA, ΔB, ΔC),
    ::typeof(sylvester),
    A::StridedMatrix{T},
    B::StridedMatrix{T},
    C::StridedMatrix{T},
) where {T<:BlasFloat}
    TA, UA = schur(A)
    TB, UB = schur(B)

    # Primal solve in the Schur basis, then rotate back.
    rhs = UA' * (C * UB)
    X, s = LAPACK.trsyl!('N', 'N', TA, TB, rhs)
    Ω = UA * (X * UB')
    rmul!(Ω, -inv(s))

    # Tangent right-hand side: ΔA * Ω + Ω * ΔB + ΔC, accumulated in place.
    ∂rhs_full = muladd(ΔA, Ω, ΔC)
    mul!(∂rhs_full, Ω, ΔB, true, true)

    # Same triangular solve as the primal, applied to the tangent right-hand side.
    ∂rhs = UA' * (∂rhs_full * UB)
    ∂X, ∂s = LAPACK.trsyl!('N', 'N', TA, TB, ∂rhs)
    ∂Ω = UA * (∂X * UB')
    rmul!(∂Ω, -inv(∂s))

    return Ω, ∂Ω
end
266+
267+
# included because the primal mutates and uses `schur` and LAPACK
268+
269+
# Reverse-mode rule for `sylvester(A, B, C)`, whose result Ω satisfies
# `A * Ω + Ω * B + C = 0`.
#
# The Schur factors computed for the primal are captured by the pullback and reused for
# the adjoint solve, which uses the (conjugate-)transposed triangular factors.
function rrule(
    ::typeof(sylvester), A::StridedMatrix{T}, B::StridedMatrix{T}, C::StridedMatrix{T}
) where {T<:BlasFloat}
    RA, QA = schur(A)
    RB, QB = schur(B)
    D = QA' * (C * QB)
    # `trsyl!` overwrites `D` with the solution and returns the scale factor it applied
    # to the right-hand side; dividing by `scale` below removes it from Ω entirely.
    Y, scale = LAPACK.trsyl!('N', 'N', RA, RB, D)
    Ω = rmul!(QA * (Y * QB'), -inv(scale))
    function sylvester_pullback(ΔΩ)
        # For real inputs only the real part of the cotangent is meaningful.
        ∂Ω = T <: Real ? real(ΔΩ) : ΔΩ
        ∂Y = QA' * (∂Ω * QB)
        trans = T <: Complex ? 'C' : 'T'
        ∂D, scale2 = LAPACK.trsyl!(trans, trans, RA, RB, ∂Y)
        # ∂Z solves the adjoint equation A' * ∂Z + ∂Z * B' = -∂Ω.
        ∂Z = rmul!(QA * (∂D * QB'), -inv(scale2))
        # Cotangents: ∂A = ∂Z * Ω', ∂B = Ω' * ∂Z, ∂C = ∂Z.
        # Ω does not depend on the primal `scale` (it was divided out above), so ∂C must
        # not be rescaled by it. A copy is returned so that callers accumulating into ∂C
        # in place cannot corrupt the `∂Z` captured by the other, still-lazy thunks.
        return NO_FIELDS, @thunk(∂Z * Ω'), @thunk(Ω' * ∂Z), @thunk(copy(∂Z))
    end
    return Ω, sylvester_pullback
end
287+
288+
#####
289+
##### `lyap`
290+
#####
291+
292+
# included because the primal uses `schur`, for which we don't have a rule
293+
294+
# Forward-mode rule for `lyap(A, C)`.
#
# A single Schur factorization of `A` serves both the primal and the tangent solve. The
# second operand of `trsyl!` is the (conjugate-)transposed factor, selected with 'C' for
# complex and 'T' for real element types.
function frule(
    (_, ΔA, ΔC), ::typeof(lyap), A::StridedMatrix{T}, C::StridedMatrix{T}
) where {T<:BlasFloat}
    S, U = schur(A)
    transb = T <: Complex ? 'C' : 'T'

    # Primal solve in the Schur basis, then rotate back and undo the LAPACK scaling.
    rhs = U' * (C * U)
    X, s = LAPACK.trsyl!('N', transb, S, S, rhs)
    Ω = U * (X * U')
    rmul!(Ω, -inv(s))

    # Tangent right-hand side: ΔA * Ω + Ω * ΔA' + ΔC, accumulated in place.
    ∂rhs_full = muladd(ΔA, Ω, ΔC)
    mul!(∂rhs_full, Ω, ΔA', true, true)

    ∂rhs = U' * (∂rhs_full * U)
    ∂X, ∂s = LAPACK.trsyl!('N', transb, S, S, ∂rhs)
    ∂Ω = U * (∂X * U')
    rmul!(∂Ω, -inv(∂s))

    return Ω, ∂Ω
end
306+
307+
# included because the primal mutates and uses `schur` and LAPACK
308+
309+
# Reverse-mode rule for `lyap(A, C)`, whose result Ω satisfies
# `A * Ω + Ω * A' + C = 0`.
#
# The Schur factors of `A` computed for the primal are captured by the pullback and
# reused for the adjoint solve.
function rrule(
    ::typeof(lyap), A::StridedMatrix{T}, C::StridedMatrix{T}
) where {T<:BlasFloat}
    R, Q = schur(A)
    D = Q' * (C * Q)
    # `trsyl!` overwrites `D` with the solution and returns the scale factor it applied
    # to the right-hand side; dividing by `scale` below removes it from Ω entirely.
    Y, scale = LAPACK.trsyl!('N', T <: Complex ? 'C' : 'T', R, R, D)
    Ω = rmul!(Q * (Y * Q'), -inv(scale))
    function lyap_pullback(ΔΩ)
        # For real inputs only the real part of the cotangent is meaningful.
        ∂Ω = T <: Real ? real(ΔΩ) : ΔΩ
        ∂Y = Q' * (∂Ω * Q)
        # Adjoint equation: the transposition moves from the second factor to the first.
        ∂D, scale2 = LAPACK.trsyl!(T <: Complex ? 'C' : 'T', 'N', R, R, ∂Y)
        ∂Z = rmul!(Q * (∂D * Q'), -inv(scale2))
        # Cotangents: ∂A = ∂Z * Ω' + ∂Z' * Ω, ∂C = ∂Z.
        # Ω does not depend on the primal `scale` (it was divided out above), so ∂C must
        # not be rescaled by it. A copy is returned so that callers accumulating into ∂C
        # in place cannot corrupt the `∂Z` captured by the still-lazy ∂A thunk.
        return NO_FIELDS, @thunk(mul!(∂Z * Ω', ∂Z', Ω, true, true)), @thunk(copy(∂Z))
    end
    return Ω, lyap_pullback
end

src/rulesets/LinearAlgebra/lapack.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
#####
2+
##### `LAPACK.trsyl!`
3+
#####
4+
5+
# Forward-mode rule for the in-place triangular Sylvester solver `LAPACK.trsyl!`.
#
# Both `C` and its tangent `ΔC` are overwritten: `C` with the primal solution and `ΔC`
# with the solution tangent. The returned `scale` is given a `Zero()` tangent.
function ChainRules.frule(
    (_, _, _, ΔA, ΔB, ΔC),
    ::typeof(LAPACK.trsyl!),
    transa::AbstractChar,
    transb::AbstractChar,
    A::AbstractMatrix{T},
    B::AbstractMatrix{T},
    C::AbstractMatrix{T},
    isgn::Int,
) where {T<:BlasFloat}
    X, scale = LAPACK.trsyl!(transa, transb, A, B, C, isgn)
    primal = (X, scale)

    # Apply to each tangent the same op (none / transpose / adjoint) as its primal.
    opΔA = if transa === 'T'
        transpose(ΔA)
    elseif transa === 'C'
        ΔA'
    else
        ΔA
    end
    opΔB = if transb === 'T'
        transpose(ΔB)
    elseif transb === 'C'
        ΔB'
    else
        ΔB
    end

    # Build the tangent right-hand side in place:
    #   ΔC ← scale * ΔC - op(ΔA) * X - isgn * X * op(ΔB)
    mul!(ΔC, opΔA, X, -1, scale)
    mul!(ΔC, X, opΔB, -isgn, true)

    # Solve the same triangular equation for the tangent and undo its own scaling.
    ∂X, ∂scale = LAPACK.trsyl!(transa, transb, A, B, ΔC, isgn)
    rmul!(∂X, inv(∂scale))

    tangent = Composite{typeof(primal)}(∂X, Zero())
    return primal, tangent
end

test/rulesets/LinearAlgebra/dense.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,4 +102,22 @@
102102
test_frule(tr, randn(4, 4))
103103
test_rrule(tr, randn(4, 4))
104104
end
105+
@testset "sylvester" begin
106+
@testset "T=$T, m=$m, n=$n" for T in (Float64, ComplexF64), m in (2, 3), n in (1, 3)
107+
A = randn(T, m, m)
108+
B = randn(T, n, n)
109+
C = randn(T, m, n)
110+
test_frule(sylvester, A, B, C)
111+
test_rrule(sylvester, A, B, C)
112+
end
113+
end
114+
@testset "lyap" begin
115+
n = 3
116+
@testset "Float64" for T in (Float64, ComplexF64)
117+
A = randn(T, n, n)
118+
C = randn(T, n, n)
119+
test_frule(lyap, A, C)
120+
test_rrule(lyap, A, C)
121+
end
122+
end
105123
end

test/rulesets/LinearAlgebra/lapack.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# Tests for the `LAPACK.trsyl!` forward rule over element type, transposition flags,
# problem size and sign.
@testset "LAPACK" begin
    @testset "trsyl!" begin
        @testset "T=$T, m=$m, n=$n, transa='$transa', transb='$transb', isgn=$isgn" for
            T in (Float64, ComplexF64),
            transa in (T <: Real ? ('N', 'C', 'T') : ('N', 'C')),
            transb in (T <: Real ? ('N', 'C', 'T') : ('N', 'C')),
            m in (2, 3),
            n in (1, 3),
            isgn in (1, -1)

            # make A and B quasi upper-triangular (or upper-triangular for complex)
            # and their tangents have the same sparsity pattern
            A = schur(randn(T, m, m)).T
            B = schur(randn(T, n, n)).T
            C = randn(T, m, n)
            # `x ⊢ ẋ` pairs a primal with an explicit tangent; the non-differentiable
            # arguments (transposition flags and sign) get `nothing`.
            test_frule(
                LAPACK.trsyl!,
                transa ⊢ nothing,
                transb ⊢ nothing,
                A ⊢ rand_tangent(A) .* (!iszero).(A), # Match sparsity pattern
                B ⊢ rand_tangent(B) .* (!iszero).(B),
                C,
                isgn ⊢ nothing,
            )
        end
    end
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ println("Testing ChainRules.jl")
4848
include_test("rulesets/LinearAlgebra/symmetric.jl")
4949
include_test("rulesets/LinearAlgebra/factorization.jl")
5050
include_test("rulesets/LinearAlgebra/blas.jl")
51+
include_test("rulesets/LinearAlgebra/lapack.jl")
5152
end
5253
println()
5354

0 commit comments

Comments
 (0)