Skip to content

Commit eb064b8

Browse files
authored
Support indexing Tridiagonal/SymTridiagonal block matrices (#1218)
This uses `diagzero` instead of `zero` to obtain the zero elements, which works for matrix-valued elements as well. The following works after this: ```julia julia> m = ones(2,2); julia> T = Tridiagonal(fill(m,3), fill(2m,4), fill(3m,3)) 4×4 Tridiagonal{Matrix{Float64}, Vector{Matrix{Float64}}}: [2.0 2.0; 2.0 2.0] [3.0 3.0; 3.0 3.0] … ⋅ [1.0 1.0; 1.0 1.0] [2.0 2.0; 2.0 2.0] ⋅ ⋅ [1.0 1.0; 1.0 1.0] [3.0 3.0; 3.0 3.0] ⋅ ⋅ [2.0 2.0; 2.0 2.0] julia> S = SymTridiagonal(fill(2m,4), fill(m,3)) 4×4 SymTridiagonal{Matrix{Float64}, Vector{Matrix{Float64}}}: [2.0 2.0; 2.0 2.0] [1.0 1.0; 1.0 1.0] … ⋅ [1.0 1.0; 1.0 1.0] [2.0 2.0; 2.0 2.0] ⋅ ⋅ [1.0 1.0; 1.0 1.0] [1.0 1.0; 1.0 1.0] ⋅ ⋅ [2.0 2.0; 2.0 2.0] ``` This was already supported for `Diagonal`/`Bidiagonal`, and this PR extends the feature to tridiagonal matrices as well.
2 parents 62c8100 + ec09767 commit eb064b8

File tree

3 files changed

+32
-6
lines changed

3 files changed

+32
-6
lines changed

src/tridiag.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -478,7 +478,7 @@ end
478478
elseif i + 1 == j
479479
return @inbounds A.ev[i]
480480
else
481-
return zero(T)
481+
return diagzero(A, i, j)
482482
end
483483
end
484484

@@ -740,7 +740,7 @@ end
740740
elseif i + 1 == j
741741
return @inbounds A.du[i]
742742
else
743-
return zero(T)
743+
return diagzero(A, i, j)
744744
end
745745
end
746746

@@ -753,7 +753,7 @@ end
753753
elseif b.band == 1
754754
return @inbounds A.du[b.index]
755755
else
756-
return zero(T)
756+
return diagzero(A, b)
757757
end
758758
end
759759

test/bidiag.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -835,9 +835,6 @@ end
835835
end
836836

837837
@testset "non-standard axes" begin
838-
LinearAlgebra.diagzero(T::Type, ax::Tuple{SizedArrays.SOneTo, Vararg{SizedArrays.SOneTo}}) =
839-
zeros(T, ax)
840-
841838
s = SizedArrays.SizedArray{(2,2)}([1 2; 3 4])
842839
B = Bidiagonal(fill(s,4), fill(s,3), :U)
843840
@test @inferred(B[2,1]) isa typeof(s)

test/tridiag.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1099,6 +1099,35 @@ end
10991099
@test opnorm(S, Inf) == opnorm(Matrix(S), Inf)
11001100
end
11011101

1102+
@testset "block-bidiagonal matrix indexing" begin
1103+
dv = [ones(4,3), ones(2,2).*2, ones(2,3).*3, ones(4,4).*4]
1104+
evu = [ones(4,2), ones(2,3).*2, ones(2,4).*3]
1105+
evl = [ones(2,3), ones(2,2).*2, ones(4,3).*3]
1106+
T = Tridiagonal(evl, dv, evu)
1107+
# check that all the matrices along a column have the same number of columns,
1108+
# and the matrices along a row have the same number of rows
1109+
for j in axes(T, 2), i in 2:size(T, 1)
1110+
@test size(T[i,j], 2) == size(T[1,j], 2)
1111+
@test size(T[i,j], 1) == size(T[i,1], 1)
1112+
if j < i-1 || j > i + 1
1113+
@test iszero(T[i,j])
1114+
end
1115+
end
1116+
1117+
@testset "non-standard axes" begin
1118+
s = SizedArrays.SizedArray{(2,2)}([1 2; 3 4])
1119+
T = Tridiagonal(fill(s,3), fill(s,4), fill(s,3))
1120+
@test @inferred(T[3,1]) isa typeof(s)
1121+
@test all(iszero, T[3,1])
1122+
end
1123+
1124+
# SymTridiagonal requires square diagonal blocks
1125+
dv = [fill(i, i, i) for i in 1:3]
1126+
ev = [ones(Int,1,2), ones(Int,2,3)]
1127+
S = SymTridiagonal(dv, ev)
1128+
@test S == Array{Matrix{Int}}(S)
1129+
end
1130+
11021131
@testset "convert to Tridiagonal/SymTridiagonal" begin
11031132
@testset "Tridiagonal" begin
11041133
for M in [diagm(0 => [1,2,3], 1=>[4,5]),

0 commit comments

Comments
 (0)