Skip to content

Commit 63bffb2

Browse files
authored
more general division by triangular matrices (#837)
1 parent b761bd5 commit 63bffb2

File tree

2 files changed

+65
-189
lines changed

2 files changed

+65
-189
lines changed

src/triangular.jl

Lines changed: 46 additions & 179 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,19 @@
1-
@inline transpose(A::LinearAlgebra.LowerTriangular{<:Any,<:StaticMatrix}) =
2-
LinearAlgebra.UpperTriangular(transpose(A.data))
3-
@inline adjoint(A::LinearAlgebra.LowerTriangular{<:Any,<:StaticMatrix}) =
4-
LinearAlgebra.UpperTriangular(adjoint(A.data))
5-
@inline transpose(A::LinearAlgebra.UpperTriangular{<:Any,<:StaticMatrix}) =
6-
LinearAlgebra.LowerTriangular(transpose(A.data))
7-
@inline adjoint(A::LinearAlgebra.UpperTriangular{<:Any,<:StaticMatrix}) =
8-
LinearAlgebra.LowerTriangular(adjoint(A.data))
1+
@inline transpose(A::LowerTriangular{<:Any,<:StaticMatrix}) =
2+
UpperTriangular(transpose(A.data))
3+
@inline adjoint(A::LowerTriangular{<:Any,<:StaticMatrix}) =
4+
UpperTriangular(adjoint(A.data))
5+
@inline transpose(A::UnitLowerTriangular{<:Any,<:StaticMatrix}) =
6+
UnitUpperTriangular(transpose(A.data))
7+
@inline adjoint(A::UnitLowerTriangular{<:Any,<:StaticMatrix}) =
8+
UnitUpperTriangular(adjoint(A.data))
9+
@inline transpose(A::UpperTriangular{<:Any,<:StaticMatrix}) =
10+
LowerTriangular(transpose(A.data))
11+
@inline adjoint(A::UpperTriangular{<:Any,<:StaticMatrix}) =
12+
LowerTriangular(adjoint(A.data))
13+
@inline transpose(A::UnitUpperTriangular{<:Any,<:StaticMatrix}) =
14+
UnitLowerTriangular(transpose(A.data))
15+
@inline adjoint(A::UnitUpperTriangular{<:Any,<:StaticMatrix}) =
16+
UnitLowerTriangular(adjoint(A.data))
917
@inline Base.:*(A::Adjoint{<:Any,<:StaticVector}, B::LinearAlgebra.AbstractTriangular{<:Any,<:StaticMatrix}) =
1018
adjoint(adjoint(B) * adjoint(A))
1119
@inline Base.:*(A::Transpose{<:Any,<:StaticVector}, B::LinearAlgebra.AbstractTriangular{<:Any,<:StaticMatrix}) =
@@ -15,193 +23,52 @@
1523
@inline Base.:*(A::LinearAlgebra.AbstractTriangular{<:Any,<:StaticMatrix}, B::Transpose{<:Any,<:StaticVector}) =
1624
transpose(transpose(B) * transpose(A))
1725

18-
const StaticULT = Union{UpperTriangular{<:Any,<:StaticMatrix},LowerTriangular{<:Any,<:StaticMatrix}}
26+
const StaticULT{TA} = Union{UpperTriangular{TA,<:StaticMatrix},LowerTriangular{TA,<:StaticMatrix},UnitUpperTriangular{TA,<:StaticMatrix},UnitLowerTriangular{TA,<:StaticMatrix}}
1927

20-
@inline Base.:\(A::StaticULT, B::StaticVecOrMat) = _A_ldiv_B(Size(A), Size(B), A, B)
28+
@inline Base.:\(A::StaticULT, B::StaticVecOrMatLike) = _A_ldiv_B(Size(A), Size(B), A, B)
29+
@inline Base.:/(A::StaticVecOrMatLike, B::StaticULT) = transpose(transpose(B) \ transpose(A))
2130

22-
@generated function _A_ldiv_B(::Size{sa}, ::Size{sb}, A::UpperTriangular{<:TA,<:StaticMatrix}, B::StaticVecOrMat{TB}) where {sa,sb,TA,TB}
31+
@generated function _A_ldiv_B(::Size{sa}, ::Size{sb}, A::StaticULT{TA}, B::StaticVecOrMatLike{TB}) where {sa,sb,TA,TB}
2332
m = sb[1]
2433
n = length(sb) > 1 ? sb[2] : 1
2534
if m != sa[1]
2635
throw(DimensionMismatch("right hand side B needs first dimension of size $(sa[1]), has size $m"))
2736
end
2837

2938
X = [Symbol("X_$(i)_$(j)") for i = 1:m, j = 1:n]
30-
init = [:($(X[i,j]) = B[$(LinearIndices(sb)[i, j])]) for i = 1:m, j = 1:n]
3139

32-
code = Expr(:block)
33-
for k = 1:n
34-
for j = m:-1:1
35-
if k == 1
36-
push!(code.args, :(A.data[$(LinearIndices(sa)[j, j])] == zero(A.data[$(LinearIndices(sa)[j, j])]) && throw(LinearAlgebra.SingularException($j))))
40+
isunitdiag = A <: Union{UnitUpperTriangular, UnitLowerTriangular}
41+
isupper = A <: Union{UnitUpperTriangular, UpperTriangular}
42+
43+
j_range = isupper ? (m:-1:1) : (1:m)
44+
45+
init = gen_by_access(B, :B) do access_b
46+
init_exprs = [:($(X[i,j]) = $(uplo_access(sb, :b, i, j, access_b))) for i = 1:m, j = 1:n]
47+
code = Expr(:block, init_exprs...)
48+
for k = 1:n
49+
for j = j_range
50+
if !isunitdiag && k == 1
51+
push!(code.args, :(A.data[$(LinearIndices(sa)[j, j])] == zero(A.data[$(LinearIndices(sa)[j, j])]) && throw(LinearAlgebra.SingularException($j))))
52+
end
53+
if isunitdiag
54+
push!(code.args, :($(X[j,k]) = oneunit(TA) \ $(X[j,k])))
55+
else
56+
push!(code.args, :($(X[j,k]) = A.data[$(LinearIndices(sa)[j, j])] \ $(X[j,k])))
57+
end
58+
i_range = isupper ? (j-1:-1:1) : (j+1:m)
59+
for i = i_range
60+
push!(code.args, :($(X[i,k]) -= A.data[$(LinearIndices(sa)[i, j])]*$(X[j,k])))
61+
end
3762
end
38-
push!(code.args, :($(X[j,k]) = A.data[$(LinearIndices(sa)[j, j])] \ $(X[j,k])))
39-
for i = j-1:-1:1
40-
push!(code.args, :($(X[i,k]) -= A.data[$(LinearIndices(sa)[i, j])]*$(X[j,k])))
41-
end
42-
end
43-
end
44-
45-
return quote
46-
@_inline_meta
47-
@inbounds $(Expr(:block, init...))
48-
@inbounds $code
49-
TAB = typeof((zero(TA)*zero(TB) + zero(TA)*zero(TB))/one(TA))
50-
@inbounds return similar_type(B, TAB)(tuple($(X...)))
51-
end
52-
end
53-
54-
@generated function _A_ldiv_B(::Size{sa}, ::Size{sb}, A::LowerTriangular{<:TA,<:StaticMatrix}, B::StaticVecOrMat{TB}) where {sa,sb,TA,TB}
55-
m = sb[1]
56-
n = length(sb) > 1 ? sb[2] : 1
57-
if m != sa[1]
58-
throw(DimensionMismatch("right hand side B needs first dimension of size $(sa[1]), has size $m"))
59-
end
60-
61-
X = [Symbol("X_$(i)_$(j)") for i = 1:m, j = 1:n]
62-
init = [:($(X[i,j]) = B[$(LinearIndices(sb)[i, j])]) for i = 1:m, j = 1:n]
63-
64-
code = Expr(:block)
65-
for k = 1:n
66-
for j = 1:m
67-
if k == 1
68-
push!(code.args, :(A.data[$(LinearIndices(sa)[j, j])] == zero(A.data[$(LinearIndices(sa)[j, j])]) && throw(LinearAlgebra.SingularException($j))))
69-
end
70-
push!(code.args, :($(X[j,k]) = A.data[$(LinearIndices(sa)[j, j])] \ $(X[j,k])))
71-
for i = j+1:m
72-
push!(code.args, :($(X[i,k]) -= A.data[$(LinearIndices(sa)[i, j])]*$(X[j,k])))
73-
end
74-
end
75-
end
76-
77-
return quote
78-
@_inline_meta
79-
@inbounds $(Expr(:block, init...))
80-
@inbounds $code
81-
TAB = typeof((zero(TA)*zero(TB) + zero(TA)*zero(TB))/one(TA))
82-
@inbounds return similar_type(B, TAB)(tuple($(X...)))
83-
end
84-
end
85-
86-
@generated function _Ac_ldiv_B(::Size{sa}, ::Size{sb}, A::UpperTriangular{<:TA,<:StaticMatrix}, B::StaticVecOrMat{TB}) where {sa,sb,TA,TB}
87-
m = sb[1]
88-
n = length(sb) > 1 ? sb[2] : 1
89-
if m != sa[1]
90-
throw(DimensionMismatch("right hand side B needs first dimension of size $(sa[1]), has size $m"))
91-
end
92-
93-
X = [Symbol("X_$(i)_$(j)") for i = 1:m, j = 1:n]
94-
95-
code = Expr(:block)
96-
for k = 1:n
97-
for j = 1:m
98-
ex = :(B[$(LinearIndices(sb)[j, k])])
99-
for i = 1:j-1
100-
ex = :($ex - A.data[$(LinearIndices(sa)[i, j])]'*$(X[i,k]))
101-
end
102-
if k == 1
103-
push!(code.args, :(A.data[$(LinearIndices(sa)[j, j])] == zero(A.data[$(LinearIndices(sa)[j, j])]) && throw(LinearAlgebra.SingularException($j))))
104-
end
105-
push!(code.args, :($(X[j,k]) = A.data[$(LinearIndices(sa)[j, j])]' \ $ex))
106-
end
107-
end
108-
109-
return quote
110-
@_inline_meta
111-
@inbounds $code
112-
TAB = typeof((zero(TA)*zero(TB) + zero(TA)*zero(TB))/one(TA))
113-
@inbounds return similar_type(B, TAB)(tuple($(X...)))
114-
end
115-
end
116-
117-
@generated function _At_ldiv_B(::Size{sa}, ::Size{sb}, A::UpperTriangular{<:TA,<:StaticMatrix}, B::StaticVecOrMat{TB}) where {sa,sb,TA,TB}
118-
m = sb[1]
119-
n = length(sb) > 1 ? sb[2] : 1
120-
if m != sa[1]
121-
throw(DimensionMismatch("right hand side B needs first dimension of size $(sa[1]), has size $m"))
122-
end
123-
124-
X = [Symbol("X_$(i)_$(j)") for i = 1:m, j = 1:n]
125-
126-
code = Expr(:block)
127-
for k = 1:n
128-
for j = 1:m
129-
ex = :(B[$(LinearIndices(sb)[j, k])])
130-
for i = 1:j-1
131-
ex = :($ex - A.data[$(LinearIndices(sa)[i, j])]*$(X[i,k]))
132-
end
133-
if k == 1
134-
push!(code.args, :(A.data[$(LinearIndices(sa)[j, j])] == zero(A.data[$(LinearIndices(sa)[j, j])]) && throw(LinearAlgebra.SingularException($j))))
135-
end
136-
push!(code.args, :($(X[j,k]) = A.data[$(LinearIndices(sa)[j, j])] \ $ex))
137-
end
138-
end
139-
140-
return quote
141-
@_inline_meta
142-
@inbounds $code
143-
TAB = typeof((zero(TA)*zero(TB) + zero(TA)*zero(TB))/one(TA))
144-
@inbounds return similar_type(B, TAB)(tuple($(X...)))
145-
end
146-
end
147-
148-
@generated function _Ac_ldiv_B(::Size{sa}, ::Size{sb}, A::LowerTriangular{<:TA,<:StaticMatrix}, B::StaticVecOrMat{TB}) where {sa,sb,TA,TB}
149-
m = sb[1]
150-
n = length(sb) > 1 ? sb[2] : 1
151-
if m != sa[1]
152-
throw(DimensionMismatch("right hand side B needs first dimension of size $(sa[1]), has size $m"))
153-
end
154-
155-
X = [Symbol("X_$(i)_$(j)") for i = 1:m, j = 1:n]
156-
157-
code = Expr(:block)
158-
for k = 1:n
159-
for j = m:-1:1
160-
ex = :(B[$(LinearIndices(sb)[j, k])])
161-
for i = m:-1:j+1
162-
ex = :($ex - A.data[$(LinearIndices(sa)[i, j])]'*$(X[i,k]))
163-
end
164-
if k == 1
165-
push!(code.args, :(A.data[$(LinearIndices(sa)[j, j])] == zero(A.data[$(LinearIndices(sa)[j, j])]) && throw(LinearAlgebra.SingularException($j))))
166-
end
167-
push!(code.args, :($(X[j,k]) = A.data[$(LinearIndices(sa)[j, j])]' \ $ex))
168-
end
169-
end
170-
171-
return quote
172-
@_inline_meta
173-
@inbounds $code
174-
TAB = typeof((zero(TA)*zero(TB) + zero(TA)*zero(TB))/one(TA))
175-
@inbounds return similar_type(B, TAB)(tuple($(X...)))
176-
end
177-
end
178-
179-
@generated function _At_ldiv_B(::Size{sa}, ::Size{sb}, A::LowerTriangular{<:TA,<:StaticMatrix}, B::StaticVecOrMat{TB}) where {sa,sb,TA,TB}
180-
m = sb[1]
181-
n = length(sb) > 1 ? sb[2] : 1
182-
if m != sa[1]
183-
throw(DimensionMismatch("right hand side B needs first dimension of size $(sa[1]), has size $m"))
184-
end
185-
186-
X = [Symbol("X_$(i)_$(j)") for i = 1:m, j = 1:n]
187-
188-
code = Expr(:block)
189-
for k = 1:n
190-
for j = m:-1:1
191-
ex = :(B[$(LinearIndices(sb)[j, k])])
192-
for i = m:-1:j+1
193-
ex = :($ex - A.data[$(LinearIndices(sa)[i, j])]*$(X[i,k]))
194-
end
195-
if k == 1
196-
push!(code.args, :(A.data[$(LinearIndices(sa)[j, j])] == zero(A.data[$(LinearIndices(sa)[j, j])]) && throw(LinearAlgebra.SingularException($j))))
197-
end
198-
push!(code.args, :($(X[j,k]) = A.data[$(LinearIndices(sa)[j, j])] \ $ex))
19963
end
64+
return code
20065
end
20166

20267
return quote
20368
@_inline_meta
204-
@inbounds $code
69+
b = mul_parent(B)
70+
Tb = TB
71+
@inbounds $init
20572
TAB = typeof((zero(TA)*zero(TB) + zero(TA)*zero(TB))/one(TA))
20673
@inbounds return similar_type(B, TAB)(tuple($(X...)))
20774
end

test/triangular.jl

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -102,27 +102,36 @@ end
102102
@testset "Triangular-matrix division" begin
103103
for n in (1, 2, 3, 4),
104104
eltyA in (Float64, ComplexF64, Int),
105-
(t, uplo) in ((UpperTriangular, :U), (LowerTriangular, :L)),
106-
eltyB in (Float64, ComplexF64)
105+
(t, uplo) in ((UpperTriangular, :U), (LowerTriangular, :L), (UnitUpperTriangular, :U)),
106+
eltyB in (Float64, ComplexF64),
107+
tb in (identity, LowerTriangular, Symmetric)
107108

108109
A = t(eltyA == Int ? rand(1:7, n, n) : convert(Matrix{eltyA}, (eltyA <: Complex ? complex.(randn(n, n), randn(n, n)) : randn(n, n)) |> t -> cholesky(t't).U |> t -> uplo == :U ? t : adjoint(t)))
109-
B = convert(Matrix{eltyB}, eltyA <: Complex ? real(A)*ones(n, n) : A*ones(n, n))
110+
B = tb(convert(Matrix{eltyB}, eltyA <: Complex ? real(A)*ones(n, n) : A*ones(n, n)))
110111
SA = t(SMatrix{n,n}(A.data))
111-
SB = SMatrix{n,n}(B)
112+
SB = tb(SMatrix{n,n}(parent(B)))
112113

113-
@test (SA\SB[:,1])::SVector{n} A\B[:,1]
114+
if tb === identity
115+
@test (SA\SB[:,1])::SVector{n} A\B[:,1]
116+
@test (transpose(SA)\SB[:,1])::SVector{n} transpose(A)\B[:,1]
117+
@test (SA'\SB[:,1])::SVector{n} A'\B[:,1]
118+
end
114119
@test (SA\SB)::SMatrix{n,n} A\B
115-
@test (transpose(SA)\SB[:,1])::SVector{n} transpose(A)\B[:,1]
116120
@test (transpose(SA)\SB)::SMatrix{n,n} transpose(A)\B
117-
@test (SA'\SB[:,1])::SVector{n} A'\B[:,1]
118121
@test (SA'\SB)::SMatrix{n,n} A'\B
119122

120123
@test_throws DimensionMismatch SA\ones(SVector{n+2,eltyB})
121124
@test_throws DimensionMismatch transpose(SA)\ones(SVector{n+2,eltyB})
122125
@test_throws DimensionMismatch SA'\ones(SVector{n+2,eltyB})
123126

124-
@test_throws LinearAlgebra.SingularException t(zeros(SMatrix{n,n,eltyA}))\ones(SVector{n,eltyB})
125-
@test_throws LinearAlgebra.SingularException t(transpose(zeros(SMatrix{n,n,eltyA})))\ones(SVector{n,eltyB})
126-
@test_throws LinearAlgebra.SingularException t(zeros(SMatrix{n,n,eltyA}))'\ones(SVector{n,eltyB})
127+
if t != UnitUpperTriangular
128+
@test_throws LinearAlgebra.SingularException t(zeros(SMatrix{n,n,eltyA}))\ones(SVector{n,eltyB})
129+
@test_throws LinearAlgebra.SingularException t(transpose(zeros(SMatrix{n,n,eltyA})))\ones(SVector{n,eltyB})
130+
@test_throws LinearAlgebra.SingularException t(zeros(SMatrix{n,n,eltyA}))'\ones(SVector{n,eltyB})
131+
end
132+
133+
@test (SB/SA)::SMatrix{n,n} B/A
134+
@test (SB/transpose(SA))::SMatrix{n,n} B/transpose(A)
135+
@test (SB/SA')::SMatrix{n,n} B/A'
127136
end
128137
end

0 commit comments

Comments
 (0)