Skip to content

Commit 44f3bb3

Browse files
authored
Make transposes of StridedArrays strided (#29135)
* Make transposes of StridedArrays strided but only in cases where the transpose is actually stored in memory.
1 parent 39d7a18 commit 44f3bb3

File tree

2 files changed

+35
-0
lines changed

2 files changed

+35
-0
lines changed

stdlib/LinearAlgebra/src/adjtrans.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,16 @@ IndexStyle(::Type{<:AdjOrTransAbsMat}) = IndexCartesian()
198198
convert(::Type{Adjoint{T,S}}, A::Adjoint) where {T,S} = Adjoint{T,S}(convert(S, A.parent))
199199
convert(::Type{Transpose{T,S}}, A::Transpose) where {T,S} = Transpose{T,S}(convert(S, A.parent))
200200

201+
# Strides and pointer for transposed strided arrays — but only if the elements are actually stored in memory
202+
Base.strides(A::Adjoint{<:Real, <:StridedVector}) = (stride(A.parent, 2), stride(A.parent, 1))
203+
Base.strides(A::Transpose{<:Any, <:StridedVector}) = (stride(A.parent, 2), stride(A.parent, 1))
204+
# For matrices it's slightly faster to use reverse and avoid calling stride twice
205+
Base.strides(A::Adjoint{<:Real, <:StridedMatrix}) = reverse(strides(A.parent))
206+
Base.strides(A::Transpose{<:Any, <:StridedMatrix}) = reverse(strides(A.parent))
207+
208+
Base.unsafe_convert(::Type{Ptr{T}}, A::Adjoint{<:Real, <:StridedVecOrMat}) where {T} = Base.unsafe_convert(Ptr{T}, A.parent)
209+
Base.unsafe_convert(::Type{Ptr{T}}, A::Transpose{<:Any, <:StridedVecOrMat}) where {T} = Base.unsafe_convert(Ptr{T}, A.parent)
210+
201211
# for vectors, the semantics of the wrapped and unwrapped types differ
202212
# so attempt to maintain both the parent and wrapper type insofar as possible
203213
similar(A::AdjOrTransAbsVec) = wrapperop(A)(similar(A.parent))

stdlib/LinearAlgebra/test/adjtrans.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,31 @@ end
483483
"$t of "*sprint((io, t) -> show(io, MIME"text/plain"(), t), parent(Fop))
484484
end
485485

486+
@testset "strided transposes" begin
487+
for t in (Adjoint, Transpose)
488+
@test strides(t(rand(3))) == (3, 1)
489+
@test strides(t(rand(3,2))) == (3, 1)
490+
@test strides(t(view(rand(3, 2), :))) == (6, 1)
491+
@test strides(t(view(rand(3, 2), :, 1:2))) == (3, 1)
492+
493+
A = rand(3)
494+
@test pointer(t(A)) === pointer(A)
495+
B = rand(3,1)
496+
@test pointer(t(B)) === pointer(B)
497+
end
498+
@test_throws MethodError strides(Adjoint(rand(3) .+ rand(3).*im))
499+
@test_throws MethodError strides(Adjoint(rand(3, 2) .+ rand(3, 2).*im))
500+
@test strides(Transpose(rand(3) .+ rand(3).*im)) == (3, 1)
501+
@test strides(Transpose(rand(3, 2) .+ rand(3, 2).*im)) == (3, 1)
502+
503+
C = rand(3) .+ rand(3).*im
504+
@test_throws ErrorException pointer(Adjoint(C))
505+
@test pointer(Transpose(C)) === pointer(C)
506+
D = rand(3,2) .+ rand(3,2).*im
507+
@test_throws ErrorException pointer(Adjoint(D))
508+
@test pointer(Transpose(D)) === pointer(D)
509+
end
510+
486511
const BASE_TEST_PATH = joinpath(Sys.BINDIR, "..", "share", "julia", "test")
487512
isdefined(Main, :OffsetArrays) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "OffsetArrays.jl"))
488513
using .Main.OffsetArrays

0 commit comments

Comments
 (0)