Skip to content

Commit 62c8100

Browse files
jishnubdkarrasch
andauthored
convert to banded matrix types from AbstractMatrixes (#1212)
Co-authored-by: Daniel Karrasch <daniel.karrasch@posteo.de>
1 parent 907a202 commit 62c8100

File tree

5 files changed

+79
-10
lines changed

5 files changed

+79
-10
lines changed

src/bidiag.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,8 @@ julia> Bidiagonal(A, :L) # contains the main diagonal and first subdiagonal of A
109109
⋅ ⋅ 4 4
110110
```
111111
"""
112-
function Bidiagonal(A::AbstractMatrix, uplo::Symbol)
113-
Bidiagonal(diag(A, 0), diag(A, uplo === :U ? 1 : -1), uplo)
112+
function (::Type{Bi})(A::AbstractMatrix, uplo::Symbol) where {Bi<:Bidiagonal}
113+
Bi(diag(A, 0), diag(A, uplo === :U ? 1 : -1), uplo)
114114
end
115115

116116

@@ -220,7 +220,12 @@ promote_rule(::Type{<:Tridiagonal}, ::Type{<:Bidiagonal}) = Tridiagonal
220220
AbstractMatrix{T}(A::Bidiagonal) where {T} = Bidiagonal{T}(A)
221221
AbstractMatrix{T}(A::Bidiagonal{T}) where {T} = copy(A)
222222

223-
convert(::Type{T}, m::AbstractMatrix) where {T<:Bidiagonal} = m isa T ? m : T(m)::T
223+
function convert(::Type{T}, A::AbstractMatrix) where T<:Bidiagonal
224+
checksquare(A)
225+
isbanded(A, -1, 1) || throw(InexactError(:convert, T, A))
226+
iszero(diagview(A, 1)) ? T(A, :L) :
227+
iszero(diagview(A, -1)) ? T(A, :U) : throw(InexactError(:convert, T, A))
228+
end
224229

225230
similar(B::Bidiagonal, ::Type{T}) where {T} = Bidiagonal(similar(B.dv, T), similar(B.ev, T), B.uplo)
226231
similar(B::Bidiagonal, ::Type{T}, dims::Union{Dims{1},Dims{2}}) where {T} = similar(B.dv, T, dims)

src/tridiag.jl

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,18 +107,21 @@ julia> SymTridiagonal(B)
107107
[1 2; 3 4] [1 2; 2 3]
108108
```
109109
"""
110-
function SymTridiagonal(A::AbstractMatrix)
110+
function (::Type{SymTri})(A::AbstractMatrix) where {SymTri <: SymTridiagonal}
111111
checksquare(A)
112112
du = diag(A, 1)
113113
d = diag(A)
114114
dl = diag(A, -1)
115-
if all(((x, y),) -> x == transpose(y), zip(du, dl)) && all(issymmetric, d)
116-
SymTridiagonal(d, du)
115+
if _checksymmetric(d, du, dl)
116+
SymTri(d, du)
117117
else
118118
throw(ArgumentError("matrix is not symmetric; cannot convert to SymTridiagonal"))
119119
end
120120
end
121121

122+
_checksymmetric(d, du, dl) = all(((x, y),) -> x == transpose(y), zip(du, dl)) && all(issymmetric, d)
123+
_checksymmetric(A::AbstractMatrix) = _checksymmetric(diagview(A), diagview(A, 1), diagview(A, -1))
124+
122125
SymTridiagonal{T,V}(S::SymTridiagonal{T,V}) where {T,V<:AbstractVector{T}} = S
123126
SymTridiagonal{T,V}(S::SymTridiagonal) where {T,V<:AbstractVector{T}} =
124127
SymTridiagonal(convert(V, S.dv)::V, convert(V, S.ev)::V)
@@ -128,6 +131,11 @@ SymTridiagonal{T}(S::SymTridiagonal) where {T} =
128131
convert(AbstractVector{T}, S.ev)::AbstractVector{T})
129132
SymTridiagonal(S::SymTridiagonal) = S
130133

134+
function convert(::Type{T}, A::AbstractMatrix) where T<:SymTridiagonal
135+
checksquare(A)
136+
_checksymmetric(A) && isbanded(A, -1, 1) ? T(A) : throw(InexactError(:convert, T, A))
137+
end
138+
131139
AbstractMatrix{T}(S::SymTridiagonal) where {T} = SymTridiagonal{T}(S)
132140
AbstractMatrix{T}(S::SymTridiagonal{T}) where {T} = copy(S)
133141

@@ -597,7 +605,7 @@ julia> Tridiagonal(A)
597605
⋅ ⋅ 3 4
598606
```
599607
"""
600-
Tridiagonal(A::AbstractMatrix) = Tridiagonal(diag(A,-1), diag(A,0), diag(A,1))
608+
(::Type{Tri})(A::AbstractMatrix) where {Tri<:Tridiagonal} = Tri(diag(A,-1), diag(A,0), diag(A,1))
601609

602610
Tridiagonal(A::Tridiagonal) = A
603611
Tridiagonal{T}(A::Tridiagonal{T}) where {T} = A
@@ -619,6 +627,11 @@ function Tridiagonal{T,V}(A::Tridiagonal) where {T,V<:AbstractVector{T}}
619627
end
620628
end
621629

630+
function convert(::Type{T}, A::AbstractMatrix) where T<:Tridiagonal
631+
checksquare(A)
632+
isbanded(A, -1, 1) ? T(A) : throw(InexactError(:convert, T, A))
633+
end
634+
622635
size(M::Tridiagonal) = (n = length(M.d); (n, n))
623636
axes(M::Tridiagonal) = (ax = axes(M.d,1); (ax, ax))
624637

test/bidiag.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1158,4 +1158,21 @@ end
11581158
@test opnorm(B, Inf) == opnorm(Matrix(B), Inf)
11591159
end
11601160

1161+
@testset "convert to Bidiagonal" begin
1162+
M = diagm(0 => [1,2,3], 1=>[4,5])
1163+
B = convert(Bidiagonal, M)
1164+
@test B == Bidiagonal(M, :U)
1165+
M = diagm(0 => [1,2,3], -1=>[4,5])
1166+
B = convert(Bidiagonal, M)
1167+
@test B == Bidiagonal(M, :L)
1168+
B = convert(Bidiagonal{Int8}, M)
1169+
@test B == M
1170+
@test B isa Bidiagonal{Int8, Vector{Int8}}
1171+
B = convert(Bidiagonal{Int8, OffsetVector{Int8, Vector{Int8}}}, M)
1172+
@test B == M
1173+
@test B isa Bidiagonal{Int8, OffsetVector{Int8, Vector{Int8}}}
1174+
M = diagm(-1 => [1,2], 1=>[4,5])
1175+
@test_throws InexactError convert(Bidiagonal, M)
1176+
end
1177+
11611178
end # module TestBidiagonal

test/special.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ Random.seed!(1)
4343
@test Matrix(convert(newtype, A)) == Matrix(A)
4444
end
4545
for newtype in [Diagonal, Bidiagonal]
46-
@test_throws ArgumentError convert(newtype,A)
46+
@test_throws Union{ArgumentError,InexactError} convert(newtype,A)
4747
end
4848
A = SymTridiagonal(a, zeros(n-1))
4949
@test Matrix(convert(Bidiagonal,A)) == Matrix(A)
@@ -57,7 +57,7 @@ Random.seed!(1)
5757
@test Matrix(convert(newtype, A)) == Matrix(A)
5858
end
5959
for newtype in [Diagonal, Bidiagonal]
60-
@test_throws ArgumentError convert(newtype,A)
60+
@test_throws Union{ArgumentError,InexactError} convert(newtype,A)
6161
end
6262
A = Tridiagonal(zeros(n-1), [1.0:n;], fill(1., n-1)) #not morally Diagonal
6363
@test Matrix(convert(Bidiagonal, A)) == Matrix(A)
@@ -79,7 +79,7 @@ Random.seed!(1)
7979
end
8080
A = UpperTriangular(triu(rand(n,n)))
8181
for newtype in [Diagonal, Bidiagonal, Tridiagonal, SymTridiagonal]
82-
@test_throws ArgumentError convert(newtype,A)
82+
@test_throws Union{ArgumentError,InexactError} convert(newtype,A)
8383
end
8484

8585

test/tridiag.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1099,4 +1099,38 @@ end
10991099
@test opnorm(S, Inf) == opnorm(Matrix(S), Inf)
11001100
end
11011101

1102+
@testset "convert to Tridiagonal/SymTridiagonal" begin
1103+
@testset "Tridiagonal" begin
1104+
for M in [diagm(0 => [1,2,3], 1=>[4,5]),
1105+
diagm(0 => [1,2,3], 1=>[4,5], -1=>[6,7]),
1106+
diagm(-1 => [1,2], 1=>[4,5])]
1107+
B = convert(Tridiagonal, M)
1108+
@test B == Tridiagonal(M)
1109+
B = convert(Tridiagonal{Int8}, M)
1110+
@test B == M
1111+
@test B isa Tridiagonal{Int8}
1112+
B = convert(Tridiagonal{Int8, OffsetVector{Int8, Vector{Int8}}}, M)
1113+
@test B == M
1114+
@test B isa Tridiagonal{Int8, OffsetVector{Int8, Vector{Int8}}}
1115+
end
1116+
@test_throws InexactError convert(Tridiagonal, fill(5, 4, 4))
1117+
end
1118+
@testset "SymTridiagonal" begin
1119+
for M in [diagm(0 => [1,2,3], 1=>[4,5], -1=>[4,5]),
1120+
diagm(0 => [1,2,3]),
1121+
diagm(-1 => [1,2], 1=>[1,2])]
1122+
B = convert(SymTridiagonal, M)
1123+
@test B == SymTridiagonal(M)
1124+
B = convert(SymTridiagonal{Int8}, M)
1125+
@test B == M
1126+
@test B isa SymTridiagonal{Int8}
1127+
B = convert(SymTridiagonal{Int8, OffsetVector{Int8, Vector{Int8}}}, M)
1128+
@test B == M
1129+
@test B isa SymTridiagonal{Int8, OffsetVector{Int8, Vector{Int8}}}
1130+
end
1131+
@test_throws InexactError convert(SymTridiagonal, fill(5, 4, 4))
1132+
@test_throws InexactError convert(SymTridiagonal, diagm(0=>fill(NaN,4)))
1133+
end
1134+
end
1135+
11021136
end # module TestTridiagonal

0 commit comments

Comments
 (0)