Skip to content

Commit 39ee3af

Browse files
authored
Make adjoint/transpose for Diagonal/Bidiagonal lazy (#1368)
Currently, `adjoint` and `transpose` for a block `Diagonal` and a block `Bidiagonal` matrix allocates the diagonals. This was not intentional, and we may use lazy wrappers instead that ensures that the result shares memory with the original array, which is what the documentation states. After this PR, mutating the `adjoint` will automatically mutate the parent array.
1 parent 5f449f5 commit 39ee3af

File tree

5 files changed

+49
-11
lines changed

5 files changed

+49
-11
lines changed

src/adjtrans.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -552,3 +552,20 @@ conj(A::Adjoint) = transpose(A.parent)
552552
function Base.replace_in_print_matrix(A::AdjOrTrans,i::Integer,j::Integer,s::AbstractString)
553553
Base.replace_in_print_matrix(parent(A), j, i, s)
554554
end
555+
556+
# Special adjoint/transpose methods for Adjoint/Transpose that have been reshaped as a vector
557+
# these are used in transposing banded matrices by forwarding the operation to the bands
558+
"""
559+
_vectranspose(A::AbstractVector)::AbstractVector
560+
561+
Compute `vec(transpose(A))`, but avoid an allocating reshape if possible
562+
"""
563+
_vectranspose(A::AbstractVector) = vec(transpose(A))
564+
_vectranspose(A::Base.ReshapedArray{<:Any,1,<:TransposeAbsVec}) = transpose(parent(A))
565+
"""
566+
_vecadjoint(A::AbstractVector)::AbstractVector
567+
568+
Compute `vec(adjoint(A))`, but avoid an allocating reshape if possible
569+
"""
570+
_vecadjoint(A::AbstractVector) = vec(adjoint(A))
571+
_vecadjoint(A::Base.ReshapedArray{<:Any,1,<:AdjointAbsVec}) = adjoint(parent(A))

src/bidiag.jl

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -308,12 +308,8 @@ for func in (:conj, :copy, :real, :imag)
308308
end
309309
isreal(M::Bidiagonal) = isreal(M.dv) && isreal(M.ev)
310310

311-
adjoint(B::Bidiagonal{<:Number}) = Bidiagonal(vec(adjoint(B.dv)), vec(adjoint(B.ev)), B.uplo == 'U' ? :L : :U)
312-
adjoint(B::Bidiagonal{<:Number, <:Base.ReshapedArray{<:Number,1,<:Adjoint}}) =
313-
Bidiagonal(adjoint(parent(B.dv)), adjoint(parent(B.ev)), B.uplo == 'U' ? :L : :U)
314-
adjoint(B::Bidiagonal) = Bidiagonal(adjoint.(B.dv), adjoint.(B.ev), B.uplo == 'U' ? :L : :U)
315-
transpose(B::Bidiagonal{<:Number}) = Bidiagonal(B.dv, B.ev, B.uplo == 'U' ? :L : :U)
316-
transpose(B::Bidiagonal) = Bidiagonal(transpose.(B.dv), transpose.(B.ev), B.uplo == 'U' ? :L : :U)
311+
adjoint(B::Bidiagonal) = Bidiagonal(_vecadjoint(B.dv), _vecadjoint(B.ev), B.uplo == 'U' ? :L : :U)
312+
transpose(B::Bidiagonal) = Bidiagonal(_vectranspose(B.dv), _vectranspose(B.ev), B.uplo == 'U' ? :L : :U)
317313
permutedims(B::Bidiagonal) = Bidiagonal(B.dv, B.ev, B.uplo == 'U' ? 'L' : 'U')
318314
function permutedims(B::Bidiagonal, perm)
319315
Base.checkdims_perm(axes(B), axes(B), perm)

src/diagonal.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -906,11 +906,8 @@ end
906906
end
907907

908908
conj(D::Diagonal) = Diagonal(conj(D.diag))
909-
transpose(D::Diagonal{<:Number}) = D
910-
transpose(D::Diagonal) = Diagonal(transpose.(D.diag))
911-
adjoint(D::Diagonal{<:Number}) = Diagonal(vec(adjoint(D.diag)))
912-
adjoint(D::Diagonal{<:Number,<:Base.ReshapedArray{<:Number,1,<:Adjoint}}) = Diagonal(adjoint(parent(D.diag)))
913-
adjoint(D::Diagonal) = Diagonal(adjoint.(D.diag))
909+
transpose(D::Diagonal) = Diagonal(_vectranspose(D.diag))
910+
adjoint(D::Diagonal) = Diagonal(_vecadjoint(D.diag))
914911
permutedims(D::Diagonal) = D
915912
permutedims(D::Diagonal, perm) = (Base.checkdims_perm(axes(D), axes(D), perm); D)
916913

test/bidiag.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1240,4 +1240,19 @@ end
12401240
@test_throws BoundsError B[LinearAlgebra.BandIndex(0,size(B,1)+1)]
12411241
end
12421242

1243+
@testset "lazy adjtrans" begin
1244+
B = Bidiagonal(fill([1 2; 3 4], 3), fill([5 6; 7 8], 2), :U)
1245+
m = [2 4; 6 8]
1246+
for op in (transpose, adjoint)
1247+
C = op(B)
1248+
el = op(m)
1249+
C[1,1] = el
1250+
@test B[1,1] == m
1251+
C[2,1] = el
1252+
@test B[1,2] == m
1253+
@test (@allocated op(B)) == 0
1254+
@test (@allocated op(op(B))) == 0
1255+
end
1256+
end
1257+
12431258
end # module TestBidiagonal

test/diagonal.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1515,4 +1515,17 @@ end
15151515
@test_throws BoundsError D[LinearAlgebra.BandIndex(0,size(D,1)+1)]
15161516
end
15171517

1518+
@testset "lazy adjtrans" begin
1519+
D = Diagonal(fill([1 2; 3 4], 3))
1520+
m = [2 4; 6 8]
1521+
for op in (transpose, adjoint)
1522+
C = op(D)
1523+
el = op(m)
1524+
C[1,1] = el
1525+
@test D[1,1] == m
1526+
@test (@allocated op(D)) == 0
1527+
@test (@allocated op(op(D))) == 0
1528+
end
1529+
end
1530+
15181531
end # module TestDiagonal

0 commit comments

Comments
 (0)