Skip to content

Commit de85dea

Browse files
dkarraschmcabbott
andauthored
Optimize permutedims on *diagonals (#43792)
Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
1 parent f15b4c3 commit de85dea

File tree

6 files changed

+31
-3
lines changed

6 files changed

+31
-3
lines changed

stdlib/LinearAlgebra/src/LinearAlgebra.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ import Base: USE_BLAS64, abs, acos, acosh, acot, acoth, acsc, acsch, adjoint, as
1212
asin, asinh, atan, atanh, axes, big, broadcast, ceil, cis, conj, convert, copy, copyto!, cos,
1313
cosh, cot, coth, csc, csch, eltype, exp, fill!, floor, getindex, hcat,
1414
getproperty, imag, inv, isapprox, isequal, isone, iszero, IndexStyle, kron, kron!, length, log, map, ndims,
15-
one, oneunit, parent, power_by_squaring, print_matrix, promote_rule, real, round, sec, sech,
15+
one, oneunit, parent, permutedims, power_by_squaring, print_matrix, promote_rule, real, round, sec, sech,
1616
setindex!, show, similar, sin, sincos, sinh, size, sqrt,
1717
strides, stride, tan, tanh, transpose, trunc, typed_hcat, vec, zero
1818
using Base: IndexLinear, promote_eltype, promote_op, promote_typeof,

stdlib/LinearAlgebra/src/bidiag.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,11 @@ adjoint(B::Bidiagonal) = Adjoint(B)
255255
transpose(B::Bidiagonal) = Transpose(B)
256256
adjoint(B::Bidiagonal{<:Real}) = Bidiagonal(B.dv, B.ev, B.uplo == 'U' ? :L : :U)
257257
transpose(B::Bidiagonal{<:Number}) = Bidiagonal(B.dv, B.ev, B.uplo == 'U' ? :L : :U)
258+
permutedims(B::Bidiagonal) = Bidiagonal(B.dv, B.ev, B.uplo == 'U' ? 'L' : 'U')
259+
function permutedims(B::Bidiagonal, perm)
260+
Base.checkdims_perm(B, B, perm)
261+
NTuple{2}(perm) == (2, 1) ? permutedims(B) : B
262+
end
258263
function Base.copy(aB::Adjoint{<:Any,<:Bidiagonal})
259264
B = aB.parent
260265
return Bidiagonal(map(x -> copy.(adjoint.(x)), (B.dv, B.ev))..., B.uplo == 'U' ? :L : :U)

stdlib/LinearAlgebra/src/diagonal.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -595,8 +595,8 @@ transpose(D::Diagonal{<:Number}) = D
595595
transpose(D::Diagonal) = Diagonal(transpose.(D.diag))
596596
adjoint(D::Diagonal{<:Number}) = conj(D)
597597
adjoint(D::Diagonal) = Diagonal(adjoint.(D.diag))
598-
Base.permutedims(D::Diagonal) = D
599-
Base.permutedims(D::Diagonal, perm) = (Base.checkdims_perm(D, D, perm); D)
598+
permutedims(D::Diagonal) = D
599+
permutedims(D::Diagonal, perm) = (Base.checkdims_perm(D, D, perm); D)
600600

601601
function diag(D::Diagonal{T}, k::Integer=0) where T
602602
# every branch call similar(..., ::Int) to make sure the

stdlib/LinearAlgebra/src/tridiag.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,11 @@ end
162162
transpose(S::SymTridiagonal) = S
163163
adjoint(S::SymTridiagonal{<:Real}) = S
164164
adjoint(S::SymTridiagonal) = Adjoint(S)
165+
permutedims(S::SymTridiagonal) = S
166+
function permutedims(S::SymTridiagonal, perm)
167+
Base.checkdims_perm(S, S, perm)
168+
NTuple{2}(perm) == (2, 1) ? permutedims(S) : S
169+
end
165170
Base.copy(S::Adjoint{<:Any,<:SymTridiagonal}) = SymTridiagonal(map(x -> copy.(adjoint.(x)), (S.parent.dv, S.parent.ev))...)
166171

167172
ishermitian(S::SymTridiagonal) = isreal(S.dv) && isreal(_evview(S))
@@ -597,6 +602,11 @@ adjoint(S::Tridiagonal) = Adjoint(S)
597602
transpose(S::Tridiagonal) = Transpose(S)
598603
adjoint(S::Tridiagonal{<:Real}) = Tridiagonal(S.du, S.d, S.dl)
599604
transpose(S::Tridiagonal{<:Number}) = Tridiagonal(S.du, S.d, S.dl)
605+
permutedims(T::Tridiagonal) = Tridiagonal(T.du, T.d, T.dl)
606+
function permutedims(T::Tridiagonal, perm)
607+
Base.checkdims_perm(T, T, perm)
608+
NTuple{2}(perm) == (2, 1) ? permutedims(T) : T
609+
end
600610
Base.copy(aS::Adjoint{<:Any,<:Tridiagonal}) = (S = aS.parent; Tridiagonal(map(x -> copy.(adjoint.(x)), (S.du, S.d, S.dl))...))
601611
Base.copy(tS::Transpose{<:Any,<:Tridiagonal}) = (S = tS.parent; Tridiagonal(map(x -> copy.(transpose.(x)), (S.du, S.d, S.dl))...))
602612

stdlib/LinearAlgebra/test/bidiag.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,13 @@ Random.seed!(1)
135135
@test func(func(T)) == T
136136
end
137137

138+
@testset "permutedims(::Bidiagonal)" begin
139+
@test permutedims(permutedims(T)) === T
140+
@test permutedims(T) == transpose.(transpose(T))
141+
@test permutedims(T, [1, 2]) === T
142+
@test permutedims(T, (2, 1)) == permutedims(T)
143+
end
144+
138145
@testset "triu and tril" begin
139146
zerosdv = zeros(elty, length(dv))
140147
zerosev = zeros(elty, length(ev))

stdlib/LinearAlgebra/test/tridiag.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,12 @@ end
266266
@test func(func(A)) == A
267267
end
268268
end
269+
@testset "permutedims(::[Sym]Tridiagonal)" begin
270+
@test permutedims(permutedims(A)) === A
271+
@test permutedims(A) == transpose.(transpose(A))
272+
@test permutedims(A, [1, 2]) === A
273+
@test permutedims(A, (2, 1)) == permutedims(A)
274+
end
269275
if elty != Int
270276
@testset "Simple unary functions" begin
271277
for func in (det, inv)

0 commit comments

Comments
 (0)