Skip to content

Commit a3cd5b0

Browse files
Sahiti004lazarusA
authored andcommitted
Implement reverse for more Matrix variants (JuliaLang#54231)
1 parent 1e1af76 commit a3cd5b0

File tree

12 files changed

+131
-0
lines changed

12 files changed

+131
-0
lines changed

stdlib/LinearAlgebra/src/bidiag.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,9 @@ end
184184
return x
185185
end
186186

187+
Base._reverse(A::Bidiagonal, dims) = reverse!(Matrix(A); dims)
188+
Base._reverse(A::Bidiagonal, ::Colon) = Bidiagonal(reverse(A.dv), reverse(A.ev), A.uplo == 'U' ? :L : :U)
189+
187190
## structured matrix methods ##
188191
function Base.replace_in_print_matrix(A::Bidiagonal,i::Integer,j::Integer,s::AbstractString)
189192
if A.uplo == 'U'

stdlib/LinearAlgebra/src/diagonal.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ Diagonal{T}(D::Diagonal) where {T} = Diagonal{T}(D.diag)
114114
AbstractMatrix{T}(D::Diagonal) where {T} = Diagonal{T}(D)
115115
AbstractMatrix{T}(D::Diagonal{T}) where {T} = copy(D)
116116
Matrix(D::Diagonal{T}) where {T} = Matrix{promote_type(T, typeof(zero(T)))}(D)
117+
Matrix(D::Diagonal{Any}) = Matrix{Any}(D)
117118
Array(D::Diagonal{T}) where {T} = Matrix(D)
118119
function Matrix{T}(D::Diagonal) where {T}
119120
B = Matrix{T}(undef, size(D))
@@ -187,6 +188,10 @@ end
187188
diagzero(::Diagonal{T}, i, j) where {T} = zero(T)
188189
diagzero(D::Diagonal{<:AbstractMatrix{T}}, i, j) where {T} = zeros(T, size(D.diag[i], 1), size(D.diag[j], 2))
189190

191+
Base._reverse(A::Diagonal, dims) = reverse!(Matrix(A); dims)
192+
Base._reverse(A::Diagonal, ::Colon) = Diagonal(reverse(A.diag))
193+
Base._reverse!(A::Diagonal, ::Colon) = (reverse!(A.diag); A)
194+
190195
function setindex!(D::Diagonal, v, i::Int, j::Int)
191196
@boundscheck checkbounds(D, i, j)
192197
if i == j

stdlib/LinearAlgebra/src/hessenberg.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,8 @@ Base.isassigned(H::UpperHessenberg, i::Int, j::Int) =
9090
Base.@propagate_inbounds getindex(H::UpperHessenberg{T}, i::Int, j::Int) where {T} =
9191
i <= j+1 ? convert(T, H.data[i,j]) : zero(T)
9292

93+
Base._reverse(A::UpperHessenberg, dims) = reverse!(Matrix(A); dims)
94+
9395
Base.@propagate_inbounds function setindex!(A::UpperHessenberg, x, i::Integer, j::Integer)
9496
if i > j+1
9597
x == 0 || throw(ArgumentError("cannot set index in the lower triangular part " *

stdlib/LinearAlgebra/src/symmetric.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,11 +254,17 @@ end
254254
end
255255
end
256256

257+
Base._reverse(A::Symmetric, dims::Integer) = reverse!(Matrix(A); dims)
258+
Base._reverse(A::Symmetric, ::Colon) = Symmetric(reverse(A.data), A.uplo == 'U' ? :L : :U)
259+
257260
@propagate_inbounds function setindex!(A::Symmetric, v, i::Integer, j::Integer)
258261
i == j || throw(ArgumentError("Cannot set a non-diagonal index in a symmetric matrix"))
259262
setindex!(A.data, v, i, j)
260263
end
261264

265+
Base._reverse(A::Hermitian, dims) = reverse!(Matrix(A); dims)
266+
Base._reverse(A::Hermitian, ::Colon) = Hermitian(reverse(A.data), A.uplo == 'U' ? :L : :U)
267+
262268
@propagate_inbounds function setindex!(A::Hermitian, v, i::Integer, j::Integer)
263269
if i != j
264270
throw(ArgumentError("Cannot set a non-diagonal index in a Hermitian matrix"))

stdlib/LinearAlgebra/src/triangular.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,10 @@ Base.isstored(A::UpperTriangular, i::Int, j::Int) =
246246
@propagate_inbounds getindex(A::UpperTriangular, i::Int, j::Int) =
247247
i <= j ? A.data[i,j] : _zero(A.data,j,i)
248248

249+
Base._reverse(A::UpperOrUnitUpperTriangular, dims::Integer) = reverse!(Matrix(A); dims)
250+
Base._reverse(A::UpperTriangular, ::Colon) = LowerTriangular(reverse(A.data))
251+
Base._reverse(A::UnitUpperTriangular, ::Colon) = UnitLowerTriangular(reverse(A.data))
252+
249253
@propagate_inbounds function setindex!(A::UpperTriangular, x, i::Integer, j::Integer)
250254
if i > j
251255
iszero(x) || throw(ArgumentError("cannot set index in the lower triangular part " *
@@ -269,6 +273,10 @@ end
269273
return A
270274
end
271275

276+
Base._reverse(A::LowerOrUnitLowerTriangular, dims) = reverse!(Matrix(A); dims)
277+
Base._reverse(A::LowerTriangular, ::Colon) = UpperTriangular(reverse(A.data))
278+
Base._reverse(A::UnitLowerTriangular, ::Colon) = UnitUpperTriangular(reverse(A.data))
279+
272280
@propagate_inbounds function setindex!(A::LowerTriangular, x, i::Integer, j::Integer)
273281
if i < j
274282
iszero(x) || throw(ArgumentError("cannot set index in the upper triangular part " *

stdlib/LinearAlgebra/src/tridiag.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -461,6 +461,10 @@ end
461461
end
462462
end
463463

464+
Base._reverse(A::SymTridiagonal, dims) = reverse!(Matrix(A); dims)
465+
Base._reverse(A::SymTridiagonal, dims::Colon) = SymTridiagonal(reverse(A.dv), reverse(A.ev))
466+
Base._reverse!(A::SymTridiagonal, dims::Colon) = (reverse!(A.dv); reverse!(A.ev); A)
467+
464468
@inline function setindex!(A::SymTridiagonal, x, i::Integer, j::Integer)
465469
@boundscheck checkbounds(A, i, j)
466470
if i == j
@@ -698,6 +702,18 @@ end
698702
end
699703
end
700704

705+
Base._reverse(A::Tridiagonal, dims) = reverse!(Matrix(A); dims)
706+
Base._reverse(A::Tridiagonal, dims::Colon) = Tridiagonal(reverse(A.du), reverse(A.d), reverse(A.dl))
707+
function Base._reverse!(A::Tridiagonal, dims::Colon)
708+
n = length(A.du) # == length(A.dl), & always 1-based
709+
# reverse and swap A.dl and A.du:
710+
@inbounds for i in 1:n
711+
A.dl[i], A.du[n+1-i] = A.du[n+1-i], A.dl[i]
712+
end
713+
reverse!(A.d)
714+
return A
715+
end
716+
701717
@inline function setindex!(A::Tridiagonal, x, i::Integer, j::Integer)
702718
@boundscheck checkbounds(A, i, j)
703719
if i == j

stdlib/LinearAlgebra/test/bidiag.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -907,6 +907,18 @@ end
907907
@test mul!(C1, B, sv, 1, 2) == mul!(C2, B, v, 1 ,2)
908908
end
909909

910+
@testset "Reverse operation on Bidiagonal" begin
911+
n = 5
912+
d = randn(n)
913+
e = randn(n - 1)
914+
for uplo in (:U, :L)
915+
B = Bidiagonal(d, e, uplo)
916+
@test reverse(B, dims=1) == reverse(Matrix(B), dims=1)
917+
@test reverse(B, dims=2) == reverse(Matrix(B), dims=2)
918+
@test reverse(B)::Bidiagonal == reverse(Matrix(B))
919+
end
920+
end
921+
910922
@testset "Matrix conversion for non-numeric" begin
911923
B = Bidiagonal(fill(Diagonal([1,3]), 3), fill(Diagonal([1,3]), 2), :U)
912924
M = Matrix{eltype(B)}(B)

stdlib/LinearAlgebra/test/diagonal.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,16 @@ Random.seed!(1)
6161
@test_throws MethodError convert(Diagonal, [1,2,3,4])
6262
@test_throws DimensionMismatch convert(Diagonal, [1 2 3 4])
6363
@test_throws InexactError convert(Diagonal, ones(2,2))
64+
65+
# Test reversing
66+
# Test reversing along rows
67+
@test reverse(D, dims=1) == reverse(Matrix(D), dims=1)
68+
69+
# Test reversing along columns
70+
@test reverse(D, dims=2) == reverse(Matrix(D), dims=2)
71+
72+
# Test reversing the entire matrix
73+
@test reverse(D)::Diagonal == reverse(Matrix(D)) == reverse!(copy(D))
6474
end
6575

6676
@testset "Basic properties" begin
@@ -609,6 +619,13 @@ end
609619
end
610620
end
611621

622+
@testset "Test reverse" begin
623+
D = Diagonal(randn(5))
624+
@test reverse(D, dims=1) == reverse(Matrix(D), dims=1)
625+
@test reverse(D, dims=2) == reverse(Matrix(D), dims=2)
626+
@test reverse(D)::Diagonal == reverse(Matrix(D))
627+
end
628+
612629
@testset "inverse" begin
613630
for d in Any[randn(n), Int[], [1, 2, 3], [1im, 2im, 3im], [1//1, 2//1, 3//1], [1+1im//1, 2//1, 3im//1]]
614631
D = Diagonal(d)

stdlib/LinearAlgebra/test/hessenberg.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,13 @@ let n = 10
202202
end
203203
end
204204

205+
@testset "Reverse operation on UpperHessenberg" begin
206+
A = UpperHessenberg(randn(5, 5))
207+
@test reverse(A, dims=1) == reverse(Matrix(A), dims=1)
208+
@test reverse(A, dims=2) == reverse(Matrix(A), dims=2)
209+
@test reverse(A) == reverse(Matrix(A))
210+
end
211+
205212
@testset "hessenberg(::AbstractMatrix)" begin
206213
n = 10
207214
A = Tridiagonal(rand(n-1), rand(n), rand(n-1))

stdlib/LinearAlgebra/test/symmetric.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -559,6 +559,25 @@ end
559559
end
560560
end
561561

562+
@testset "Reverse operation on Symmetric" begin
563+
for uplo in (:U, :L)
564+
A = Symmetric(randn(5, 5), uplo)
565+
@test reverse(A, dims=1) == reverse(Matrix(A), dims=1)
566+
@test reverse(A, dims=2) == reverse(Matrix(A), dims=2)
567+
@test reverse(A)::Symmetric == reverse(Matrix(A))
568+
end
569+
end
570+
571+
@testset "Reverse operation on Hermitian" begin
572+
for uplo in (:U, :L)
573+
A = Hermitian(randn(ComplexF64, 5, 5), uplo)
574+
@test reverse(A, dims=1) == reverse(Matrix(A), dims=1)
575+
@test reverse(A, dims=2) == reverse(Matrix(A), dims=2)
576+
@test reverse(A)::Hermitian == reverse(Matrix(A))
577+
end
578+
end
579+
580+
562581
# bug identified in PR #52318: dot products of quaternionic Hermitian matrices,
563582
# or any number type where conj(a)*conj(b) ≠ conj(a*b):
564583
@testset "dot Hermitian quaternion #52318" begin

0 commit comments

Comments
 (0)