Skip to content

Commit f2d16e4

Browse files
authored
Add diagview to obtain a view along a diagonal (#56175)
A function to obtain a view of a diagonal of a matrix is useful, and this is clearly being used widely within `LinearAlgebra`. The implementation here iterates according to the `IndexStyle` of the array: ```julia julia> using LinearAlgebra julia> A = reshape(1:9, 3, 3) 3×3 reshape(::UnitRange{Int64}, 3, 3) with eltype Int64: 1 4 7 2 5 8 3 6 9 julia> diagview(A,1) 2-element view(::UnitRange{Int64}, 4:4:8) with eltype Int64: 4 8 julia> T = Tridiagonal(1:3, 3:6, 4:6) 4×4 Tridiagonal{Int64, UnitRange{Int64}}: 3 4 ⋅ ⋅ 1 4 5 ⋅ ⋅ 2 5 6 ⋅ ⋅ 3 6 julia> diagview(T,1) 3-element view(::Tridiagonal{Int64, UnitRange{Int64}}, StepRangeLen(CartesianIndex(1, 2), CartesianIndex(1, 1), 3)) with eltype Int64: 4 5 6 ``` Closes JuliaLang/julia#30250
1 parent aceebea commit f2d16e4

File tree

11 files changed

+78
-33
lines changed

11 files changed

+78
-33
lines changed

src/LinearAlgebra.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ export
8787
diag,
8888
diagind,
8989
diagm,
90+
diagview,
9091
dot,
9192
eigen!,
9293
eigen,

src/abstractq.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -456,11 +456,9 @@ end
456456

457457
det(Q::QRPackedQ) = _det_tau(Q.τ)
458458
det(Q::QRCompactWYQ) =
459-
prod(i -> _det_tau(_diagview(Q.T[:, i:min(i + size(Q.T, 1), size(Q.T, 2))])),
459+
prod(i -> _det_tau(diagview(Q.T[:, i:min(i + size(Q.T, 1), size(Q.T, 2))])),
460460
1:size(Q.T, 1):size(Q.T, 2))
461461

462-
_diagview(A) = @view A[diagind(A)]
463-
464462
# Compute `det` from the number of Householder reflections. Handle
465463
# the case `Q.τ` contains zeros.
466464
_det_tau(τs::AbstractVector{<:Real}) =

src/bidiag.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -191,8 +191,8 @@ function Matrix{T}(A::Bidiagonal) where T
191191
B = Matrix{T}(undef, size(A))
192192
if haszero(T) # optimized path for types with zero(T) defined
193193
size(B,1) > 1 && fill!(B, zero(T))
194-
copyto!(view(B, diagind(B)), A.dv)
195-
copyto!(view(B, diagind(B, _offdiagind(A.uplo))), A.ev)
194+
copyto!(diagview(B), A.dv)
195+
copyto!(diagview(B, _offdiagind(A.uplo)), A.ev)
196196
else
197197
copyto!(B, A)
198198
end
@@ -570,7 +570,7 @@ end
570570
# to avoid allocations in _mul! below (#24324, #24578)
571571
_diag(A::Tridiagonal, k) = k == -1 ? A.dl : k == 0 ? A.d : A.du
572572
_diag(A::SymTridiagonal{<:Number}, k) = k == 0 ? A.dv : A.ev
573-
_diag(A::SymTridiagonal, k) = k == 0 ? view(A, diagind(A, IndexStyle(A))) : view(A, diagind(A, 1, IndexStyle(A)))
573+
_diag(A::SymTridiagonal, k) = diagview(A,k)
574574
function _diag(A::Bidiagonal, k)
575575
if k == 0
576576
return A.dv

src/dense.jl

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,35 @@ julia> diag(A,1)
290290
"""
291291
diag(A::AbstractMatrix, k::Integer=0) = A[diagind(A, k, IndexStyle(A))]
292292

293+
"""
294+
diagview(M, k::Integer=0)
295+
296+
Return a view into the `k`th diagonal of the matrix `M`.
297+
298+
See also [`diag`](@ref), [`diagind`](@ref).
299+
300+
# Examples
301+
```jldoctest
302+
julia> A = [1 2 3; 4 5 6; 7 8 9]
303+
3×3 Matrix{Int64}:
304+
1 2 3
305+
4 5 6
306+
7 8 9
307+
308+
julia> diagview(A)
309+
3-element view(::Vector{Int64}, 1:4:9) with eltype Int64:
310+
1
311+
5
312+
9
313+
314+
julia> diagview(A, 1)
315+
2-element view(::Vector{Int64}, 4:4:8) with eltype Int64:
316+
2
317+
6
318+
```
319+
"""
320+
diagview(A::AbstractMatrix, k::Integer=0) = @view A[diagind(A, k, IndexStyle(A))]
321+
293322
"""
294323
diagm(kv::Pair{<:Integer,<:AbstractVector}...)
295324
diagm(m::Integer, n::Integer, kv::Pair{<:Integer,<:AbstractVector}...)
@@ -1636,13 +1665,11 @@ function pinv(A::AbstractMatrix{T}; atol::Real = 0.0, rtol::Real = (eps(real(flo
16361665
return similar(A, Tout, (n, m))
16371666
end
16381667
if isdiag(A)
1639-
indA = diagind(A)
1640-
dA = view(A, indA)
1668+
dA = diagview(A)
16411669
maxabsA = maximum(abs, dA)
16421670
tol = max(rtol * maxabsA, atol)
16431671
B = fill!(similar(A, Tout, (n, m)), 0)
1644-
indB = diagind(B)
1645-
B[indB] .= (x -> abs(x) > tol ? pinv(x) : zero(x)).(dA)
1672+
diagview(B) .= (x -> abs(x) > tol ? pinv(x) : zero(x)).(dA)
16461673
return B
16471674
end
16481675
SVD = svd(A)

src/diagonal.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ function Matrix{T}(D::Diagonal) where {T}
120120
B = Matrix{T}(undef, size(D))
121121
if haszero(T) # optimized path for types with zero(T) defined
122122
size(B,1) > 1 && fill!(B, zero(T))
123-
copyto!(view(B, diagind(B)), D.diag)
123+
copyto!(diagview(B), D.diag)
124124
else
125125
copyto!(B, D)
126126
end
@@ -1041,7 +1041,7 @@ dot(x::AbstractVector, D::Diagonal, y::AbstractVector) = _mapreduce_prod(dot, x,
10411041
dot(A::Diagonal, B::Diagonal) = dot(A.diag, B.diag)
10421042
function dot(D::Diagonal, B::AbstractMatrix)
10431043
size(D) == size(B) || throw(DimensionMismatch(lazy"Matrix sizes $(size(D)) and $(size(B)) differ"))
1044-
return dot(D.diag, view(B, diagind(B, IndexStyle(B))))
1044+
return dot(D.diag, diagview(B))
10451045
end
10461046

10471047
dot(A::AbstractMatrix, B::Diagonal) = conj(dot(B, A))

src/special.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ function Tridiagonal(A::Bidiagonal)
2222
end
2323

2424
_diagview(S::SymTridiagonal{<:Number}) = S.dv
25-
_diagview(S::SymTridiagonal) = view(S, diagind(S, IndexStyle(S)))
25+
_diagview(S::SymTridiagonal) = diagview(S)
2626

2727
# conversions from SymTridiagonal to other special matrix types
2828
Diagonal(A::SymTridiagonal) = Diagonal(_diagview(A))
@@ -370,20 +370,20 @@ function copyto!(dest::BandedMatrix, src::BandedMatrix)
370370
end
371371
function _copyto_banded!(T::Tridiagonal, D::Diagonal)
372372
T.d .= D.diag
373-
T.dl .= view(D, diagind(D, -1, IndexStyle(D)))
374-
T.du .= view(D, diagind(D, 1, IndexStyle(D)))
373+
T.dl .= diagview(D, -1)
374+
T.du .= diagview(D, 1)
375375
return T
376376
end
377377
function _copyto_banded!(SymT::SymTridiagonal, D::Diagonal)
378378
issymmetric(D) || throw(ArgumentError("cannot copy a non-symmetric Diagonal matrix to a SymTridiagonal"))
379379
SymT.dv .= D.diag
380380
_ev = _evview(SymT)
381-
_ev .= view(D, diagind(D, 1, IndexStyle(D)))
381+
_ev .= diagview(D, 1)
382382
return SymT
383383
end
384384
function _copyto_banded!(B::Bidiagonal, D::Diagonal)
385385
B.dv .= D.diag
386-
B.ev .= view(D, diagind(D, B.uplo == 'U' ? 1 : -1, IndexStyle(D)))
386+
B.ev .= diagview(D, _offdiagind(B.uplo))
387387
return B
388388
end
389389
function _copyto_banded!(D::Diagonal, B::Bidiagonal)
@@ -411,10 +411,10 @@ function _copyto_banded!(T::Tridiagonal, B::Bidiagonal)
411411
T.d .= B.dv
412412
if B.uplo == 'U'
413413
T.du .= B.ev
414-
T.dl .= view(B, diagind(B, -1, IndexStyle(B)))
414+
T.dl .= diagview(B,-1)
415415
else
416416
T.dl .= B.ev
417-
T.du .= view(B, diagind(B, 1, IndexStyle(B)))
417+
T.du .= diagview(B, 1)
418418
end
419419
return T
420420
end

src/triangular.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ end
188188
function full(A::UnitUpperOrUnitLowerTriangular)
189189
isupper = A isa UnitUpperTriangular
190190
Ap = _triangularize(A)(parent(A), isupper ? 1 : -1)
191-
Ap[diagind(Ap, IndexStyle(Ap))] = @view A[diagind(A, IndexStyle(A))]
191+
diagview(Ap) .= diagview(A)
192192
return Ap
193193
end
194194

@@ -400,20 +400,20 @@ function tril!(A::UnitUpperTriangular{T}, k::Integer=0) where {T}
400400
return UpperTriangular(A.data)
401401
elseif k == 0
402402
fill!(A.data, zero(T))
403-
for i in diagind(A)
403+
for i in diagind(A.data, IndexStyle(A.data))
404404
A.data[i] = oneunit(T)
405405
end
406406
return UpperTriangular(A.data)
407407
else
408-
for i in diagind(A)
408+
for i in diagind(A.data, IndexStyle(A.data))
409409
A.data[i] = oneunit(T)
410410
end
411411
return UpperTriangular(tril!(A.data,k))
412412
end
413413
end
414414

415415
function triu!(A::UnitUpperTriangular, k::Integer=0)
416-
for i in diagind(A.data)
416+
for i in diagind(A.data, IndexStyle(A.data))
417417
A.data[i] = oneunit(eltype(A))
418418
end
419419
return triu!(UpperTriangular(A.data), k)
@@ -448,20 +448,20 @@ function triu!(A::UnitLowerTriangular{T}, k::Integer=0) where T
448448
return LowerTriangular(A.data)
449449
elseif k == 0
450450
fill!(A.data, zero(T))
451-
for i in diagind(A)
451+
for i in diagind(A.data, IndexStyle(A.data))
452452
A.data[i] = oneunit(T)
453453
end
454454
return LowerTriangular(A.data)
455455
else
456-
for i in diagind(A)
456+
for i in diagind(A.data, IndexStyle(A.data))
457457
A.data[i] = oneunit(T)
458458
end
459459
return LowerTriangular(triu!(A.data, k))
460460
end
461461
end
462462

463463
function tril!(A::UnitLowerTriangular, k::Integer=0)
464-
for i in diagind(A.data)
464+
for i in diagind(A.data, IndexStyle(A.data))
465465
A.data[i] = oneunit(eltype(A))
466466
end
467467
return tril!(LowerTriangular(A.data), k)
@@ -2041,7 +2041,7 @@ function _find_params_log_quasitriu!(A)
20412041

20422042
# Find s0, the smallest s such that the ρ(triu(A)^(1/2^s) - I) ≤ theta[tmax], where ρ(X)
20432043
# is the spectral radius of X
2044-
d = complex.(@view(A[diagind(A)]))
2044+
d = complex.(diagview(A))
20452045
dm1 = d .- 1
20462046
s = 0
20472047
while norm(dm1, Inf) > theta[tmax] && s < maxsqrt

src/tridiag.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -612,9 +612,9 @@ function Matrix{T}(M::Tridiagonal) where {T}
612612
A = Matrix{T}(undef, size(M))
613613
if haszero(T) # optimized path for types with zero(T) defined
614614
size(A,1) > 2 && fill!(A, zero(T))
615-
copyto!(view(A, diagind(A)), M.d)
616-
copyto!(view(A, diagind(A,1)), M.du)
617-
copyto!(view(A, diagind(A,-1)), M.dl)
615+
copyto!(diagview(A), M.d)
616+
copyto!(diagview(A,1), M.du)
617+
copyto!(diagview(A,-1), M.dl)
618618
else
619619
copyto!(A, M)
620620
end
@@ -1092,7 +1092,7 @@ function show(io::IO, T::Tridiagonal)
10921092
end
10931093
function show(io::IO, S::SymTridiagonal)
10941094
print(io, "SymTridiagonal(")
1095-
show(io, eltype(S) <: Number ? S.dv : view(S, diagind(S, IndexStyle(S))))
1095+
show(io, _diagview(S))
10961096
print(io, ", ")
10971097
show(io, S.ev)
10981098
print(io, ")")

src/uniformscaling.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ end
201201
function (+)(A::Hermitian, J::UniformScaling{<:Complex})
202202
TS = Base.promote_op(+, eltype(A), typeof(J))
203203
B = copytri!(copymutable_oftype(parent(A), TS), A.uplo, true)
204-
for i in diagind(B)
204+
for i in diagind(B, IndexStyle(B))
205205
B[i] = A[i] + J
206206
end
207207
return B
@@ -211,7 +211,7 @@ function (-)(J::UniformScaling{<:Complex}, A::Hermitian)
211211
TS = Base.promote_op(+, eltype(A), typeof(J))
212212
B = copytri!(copymutable_oftype(parent(A), TS), A.uplo, true)
213213
B .= .-B
214-
for i in diagind(B)
214+
for i in diagind(B, IndexStyle(B))
215215
B[i] = J - A[i]
216216
end
217217
return B

test/dense.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1029,6 +1029,15 @@ end
10291029
@test diag(zeros(0,1),2) == []
10301030
end
10311031

1032+
@testset "diagview" begin
1033+
for sz in ((3,3), (3,5), (5,3))
1034+
A = rand(sz...)
1035+
for k in -5:5
1036+
@test diagview(A,k) == diag(A,k)
1037+
end
1038+
end
1039+
end
1040+
10321041
@testset "issue #39857" begin
10331042
@test lyap(1.0+2.0im, 3.0+4.0im) == -1.5 - 2.0im
10341043
end

0 commit comments

Comments
 (0)