Skip to content

Commit 321a45b

Browse files
dkarraschKristofferC
authored andcommitted
symmetrize/transpose in getindex of SymTridiagonal (#31327)
* symmetrize/transpose in SymTridiagonal added block matrix docstring more consistency in SymTridiagonal(A) and diag fix x-ref fix and test diag remove x-ref for Transpose fix and test for type inference simplify constructor, fix use of symmetric update docstrings, add comment * don't symmetrize diagonal elements at construction, but at getindex * remove symmetry check at construction * don't symmetrize in SymTridiagonal(::Matrix) * fix doctests
1 parent 01c3a2e commit 321a45b

File tree

2 files changed

+66
-12
lines changed

2 files changed

+66
-12
lines changed

stdlib/LinearAlgebra/src/tridiag.jl

Lines changed: 56 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,15 @@
33
#### Specialized matrix types ####
44

55
## (complex) symmetric tridiagonal matrices
6-
struct SymTridiagonal{T,V<:AbstractVector{T}} <: AbstractMatrix{T}
6+
struct SymTridiagonal{T, V<:AbstractVector{T}} <: AbstractMatrix{T}
77
dv::V # diagonal
8-
ev::V # subdiagonal
9-
function SymTridiagonal{T,V}(dv, ev) where {T,V<:AbstractVector{T}}
8+
ev::V # superdiagonal
9+
function SymTridiagonal{T, V}(dv, ev) where {T, V<:AbstractVector{T}}
1010
require_one_based_indexing(dv, ev)
1111
if !(length(dv) - 1 <= length(ev) <= length(dv))
1212
throw(DimensionMismatch("subdiagonal has wrong length. Has length $(length(ev)), but should be either $(length(dv) - 1) or $(length(dv))."))
1313
end
14-
new{T,V}(dv,ev)
14+
new{T, V}(dv, ev)
1515
end
1616
end
1717

@@ -23,6 +23,10 @@ sub/super-diagonal (`ev`), respectively. The result is of type `SymTridiagonal`
2323
and provides efficient specialized eigensolvers, but may be converted into a
2424
regular matrix with [`convert(Array, _)`](@ref) (or `Array(_)` for short).
2525
26+
For `SymTridiagonal` block matrices, the elements of `dv` are symmetrized.
27+
The argument `ev` is interpreted as the superdiagonal. Blocks from the
28+
subdiagonal are (materialized) transpose of the corresponding superdiagonal blocks.
29+
2630
# Examples
2731
```jldoctest
2832
julia> dv = [1, 2, 3, 4]
@@ -44,6 +48,23 @@ julia> SymTridiagonal(dv, ev)
4448
7 2 8 ⋅
4549
⋅ 8 3 9
4650
⋅ ⋅ 9 4
51+
52+
julia> A = SymTridiagonal(fill([1 2; 3 4], 3), fill([1 2; 3 4], 2));
53+
54+
julia> A[1,1]
55+
2×2 Symmetric{Int64,Array{Int64,2}}:
56+
1 2
57+
2 4
58+
59+
julia> A[1,2]
60+
2×2 Array{Int64,2}:
61+
1 2
62+
3 4
63+
64+
julia> A[2,1]
65+
2×2 Array{Int64,2}:
66+
1 3
67+
2 4
4768
```
4869
"""
4970
SymTridiagonal(dv::V, ev::V) where {T,V<:AbstractVector{T}} = SymTridiagonal{T}(dv, ev)
@@ -56,8 +77,8 @@ end
5677
"""
5778
SymTridiagonal(A::AbstractMatrix)
5879
59-
Construct a symmetric tridiagonal matrix from the diagonal and
60-
first sub/super-diagonal, of the symmetric matrix `A`.
80+
Construct a symmetric tridiagonal matrix from the diagonal and first superdiagonal
81+
of the symmetric matrix `A`.
6182
6283
# Examples
6384
```jldoctest
@@ -72,11 +93,18 @@ julia> SymTridiagonal(A)
7293
1 2 ⋅
7394
2 4 5
7495
⋅ 5 6
96+
97+
julia> B = reshape([[1 2; 2 3], [1 2; 3 4], [1 3; 2 4], [1 2; 2 3]], 2, 2);
98+
99+
julia> SymTridiagonal(B)
100+
2×2 SymTridiagonal{Array{Int64,2},Array{Array{Int64,2},1}}:
101+
[1 2; 2 3] [1 3; 2 4]
102+
[1 2; 3 4] [1 2; 2 3]
75103
```
76104
"""
77105
function SymTridiagonal(A::AbstractMatrix)
78-
if diag(A,1) == diag(A,-1)
79-
SymTridiagonal(diag(A,0), diag(A,1))
106+
if (diag(A, 1) == transpose.(diag(A, -1))) && all(issymmetric.(diag(A, 0)))
107+
SymTridiagonal(diag(A, 0), diag(A, 1))
80108
else
81109
throw(ArgumentError("matrix is not symmetric; cannot convert to SymTridiagonal"))
82110
end
@@ -139,13 +167,13 @@ adjoint(S::SymTridiagonal) = Adjoint(S)
139167
Base.copy(S::Adjoint{<:Any,<:SymTridiagonal}) = SymTridiagonal(map(x -> copy.(adjoint.(x)), (S.parent.dv, S.parent.ev))...)
140168
Base.copy(S::Transpose{<:Any,<:SymTridiagonal}) = SymTridiagonal(map(x -> copy.(transpose.(x)), (S.parent.dv, S.parent.ev))...)
141169

142-
function diag(M::SymTridiagonal, n::Integer=0)
170+
function diag(M::SymTridiagonal{<:Number}, n::Integer=0)
143171
# every branch call similar(..., ::Int) to make sure the
144172
# same vector type is returned independent of n
145173
absn = abs(n)
146174
if absn == 0
147175
return copyto!(similar(M.dv, length(M.dv)), M.dv)
148-
elseif absn==1
176+
elseif absn == 1
149177
return copyto!(similar(M.ev, length(M.ev)), M.ev)
150178
elseif absn <= size(M,1)
151179
return fill!(similar(M.dv, size(M,1)-absn), 0)
@@ -154,6 +182,22 @@ function diag(M::SymTridiagonal, n::Integer=0)
154182
"and at most $(size(M, 2)) for an $(size(M, 1))-by-$(size(M, 2)) matrix")))
155183
end
156184
end
185+
function diag(M::SymTridiagonal, n::Integer=0)
186+
# every branch call similar(..., ::Int) to make sure the
187+
# same vector type is returned independent of n
188+
if n == 0
189+
return copyto!(similar(M.dv, length(M.dv)), symmetric.(M.dv, :U))
190+
elseif n == 1
191+
return copyto!(similar(M.ev, length(M.ev)), M.ev)
192+
elseif n == -1
193+
return copyto!(similar(M.ev, length(M.ev)), transpose.(M.ev))
194+
elseif n <= size(M,1)
195+
throw(ArgumentError("requested diagonal contains undefined zeros of an array type"))
196+
else
197+
throw(ArgumentError(string("requested diagonal, $n, must be at least $(-size(M, 1)) ",
198+
"and at most $(size(M, 2)) for an $(size(M, 1))-by-$(size(M, 2)) matrix")))
199+
end
200+
end
157201

158202
+(A::SymTridiagonal, B::SymTridiagonal) = SymTridiagonal(A.dv+B.dv, A.ev+B.ev)
159203
-(A::SymTridiagonal, B::SymTridiagonal) = SymTridiagonal(A.dv-B.dv, A.ev-B.ev)
@@ -390,9 +434,9 @@ function getindex(A::SymTridiagonal{T}, i::Integer, j::Integer) where T
390434
throw(BoundsError(A, (i,j)))
391435
end
392436
if i == j
393-
return A.dv[i]
437+
return symmetric(A.dv[i], :U)::symmetric_type(eltype(A.dv))
394438
elseif i == j + 1
395-
return A.ev[j]
439+
return copy(transpose(A.ev[j])) # materialized for type stability
396440
elseif i + 1 == j
397441
return A.ev[i]
398442
else

stdlib/LinearAlgebra/test/tridiag.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,16 @@ end
398398
end
399399
end
400400

401+
@testset "SymTridiagonal block matrix" begin
402+
M = [1 2; 2 4]
403+
A = SymTridiagonal(fill(M, 3), fill(M, 2))
404+
@test @inferred A[1,1] == Symmetric(M)
405+
@test @inferred A[1,2] == M
406+
@test @inferred A[2,1] == transpose(M)
407+
@test @inferred diag(A, 1) == fill(M, 2)
408+
@test @inferred diag(A, 0) == fill(Symmetric(M), 3)
409+
@test @inferred diag(A, -1) == fill(transpose(M), 2)
410+
end
401411

402412
@testset "Issue 12068" begin
403413
@test SymTridiagonal([1, 2], [0])^3 == [1 0; 0 8]

0 commit comments

Comments
 (0)