Skip to content

Commit 29c9ea0

Browse files
authored
Support negative strides in BLAS.gemv! (#41513)
* Support negative strides in `BLAS.gemv!` * Preserve X and Y during ccall
1 parent 12d364e commit 29c9ea0

File tree

2 files changed

+45
-3
lines changed

2 files changed

+45
-3
lines changed

stdlib/LinearAlgebra/src/blas.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -664,13 +664,19 @@ for (fname, elty) in ((:dgemv_,:Float64),
664664
throw(DimensionMismatch("the transpose of A has dimensions $n, $m, X has length $(length(X)) and Y has length $(length(Y))"))
665665
end
666666
chkstride1(A)
667-
ccall((@blasfunc($fname), libblastrampoline), Cvoid,
667+
lda = stride(A,2)
668+
lda >= max(1, size(A,1)) || error("`stride(A,2)` must be at least `max(1, size(A,1))`")
669+
sX = stride(X,1)
670+
pX = pointer(X, sX > 0 ? firstindex(X) : lastindex(X))
671+
sY = stride(Y,1)
672+
pY = pointer(Y, sY > 0 ? firstindex(X) : lastindex(X))
673+
GC.@preserve X Y ccall((@blasfunc($fname), libblastrampoline), Cvoid,
668674
(Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ref{$elty},
669675
Ptr{$elty}, Ref{BlasInt}, Ptr{$elty}, Ref{BlasInt},
670676
Ref{$elty}, Ptr{$elty}, Ref{BlasInt}, Clong),
671677
trans, size(A,1), size(A,2), alpha,
672-
A, max(1,stride(A,2)), X, stride(X,1),
673-
beta, Y, stride(Y,1), 1)
678+
A, lda, pX, sX,
679+
beta, pY, sY, 1)
674680
Y
675681
end
676682
function gemv(trans::AbstractChar, alpha::($elty), A::AbstractMatrix{$elty}, X::AbstractVector{$elty})

stdlib/LinearAlgebra/test/blas.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,41 @@ Random.seed!(100)
370370
@test all(o4cp .== z4)
371371
@test all(BLAS.gemv('N', U4, o4) .== v41)
372372
@test all(BLAS.gemv('N', U4, o4) .== v41)
373+
@testset "non-standard strides" begin
374+
if elty <: Complex
375+
A = elty[1+2im 3+4im 5+6im 7+8im; 2+3im 4+5im 6+7im 8+9im; 3+4im 5+6im 7+8im 9+10im]
376+
v = elty[1+2im, 2+3im, 3+4im, 4+5im]
377+
dest = view(ones(elty, 5), 4:-2:2)
378+
@test BLAS.gemv!('N', elty(2), view(A, 2:3, 2:2:4), view(v, 1:3:4), elty(3), dest) == elty[-35+178im, -39+202im]
379+
@test BLAS.gemv('N', elty(-1), view(A, 2:3, 2:3), view(v, 2:-1:1)) == elty[15-41im, 17-49im]
380+
@test BLAS.gemv('N', view(A, 1:0, 1:2), view(v, 1:2)) == elty[]
381+
dest = view(ones(elty, 5), 4:-2:2)
382+
@test BLAS.gemv!('T', elty(2), view(A, 2:3, 2:2:4), view(v, 1:3:4), elty(3), dest) == elty[-29+124im, -45+220im]
383+
@test BLAS.gemv('T', elty(-1), view(A, 2:3, 2:3), view(v, 2:-1:1)) == elty[14-38im, 18-54im]
384+
@test BLAS.gemv('T', view(A, 2:3, 2:1), view(v, 1:2)) == elty[]
385+
dest = view(ones(elty, 5), 4:-2:2)
386+
@test BLAS.gemv!('C', elty(2), view(A, 2:3, 2:2:4), view(v, 1:3:4), elty(3), dest) == elty[131+8im, 227+24im]
387+
@test BLAS.gemv('C', elty(-1), view(A, 2:3, 2:3), view(v, 2:-1:1)) == elty[-40-6im, -56-10im]
388+
@test BLAS.gemv('C', view(A, 2:3, 2:1), view(v, 1:2)) == elty[]
389+
else
390+
A = elty[1 2 3 4; 5 6 7 8; 9 10 11 12]
391+
v = elty[1, 2, 3, 4]
392+
dest = view(ones(elty, 5), 4:-2:2)
393+
@test BLAS.gemv!('N', elty(2), view(A, 2:3, 2:2:4), view(v, 1:3:4), elty(3), dest) == elty[79, 119]
394+
@test BLAS.gemv('N', elty(-1), view(A, 2:3, 2:3), view(v, 2:-1:1)) == elty[-19, -31]
395+
@test BLAS.gemv('N', view(A, 1:0, 1:2), view(v, 1:2)) == elty[]
396+
for trans = ('T', 'C')
397+
dest = view(ones(elty, 5), 4:-2:2)
398+
@test BLAS.gemv!(trans, elty(2), view(A, 2:3, 2:2:4), view(v, 1:3:4), elty(3), dest) == elty[95, 115]
399+
@test BLAS.gemv(trans, elty(-1), view(A, 2:3, 2:3), view(v, 2:-1:1)) == elty[-22, -25]
400+
@test BLAS.gemv(trans, view(A, 2:3, 2:1), view(v, 1:2)) == elty[]
401+
end
402+
end
403+
for trans = ('N', 'T', 'C')
404+
@test_throws ErrorException BLAS.gemv(trans, view(A, 1:2:3, 1:2), view(v, 1:2))
405+
@test_throws ErrorException BLAS.gemv(trans, view(A, 1:2, 2:-1:1), view(v, 1:2))
406+
end
407+
end
373408
end
374409
@testset "gemm" begin
375410
@test all(BLAS.gemm('N', 'N', I4, I4) .== I4)
@@ -459,6 +494,7 @@ Base.setindex!(A::WrappedArray{T, N}, v, I::Vararg{Int, N}) where {T, N} = setin
459494
Base.unsafe_convert(::Type{Ptr{T}}, A::WrappedArray{T}) where T = Base.unsafe_convert(Ptr{T}, A.A)
460495

461496
Base.strides(A::WrappedArray) = strides(A.A)
497+
Base.elsize(::Type{WrappedArray{T,N}}) where {T,N} = Base.elsize(Array{T,N})
462498

463499
@testset "strided interface adjtrans" begin
464500
x = WrappedArray([1, 2, 3, 4])

0 commit comments

Comments
 (0)