Skip to content

Commit 8e5ab39

Browse files
Support for adjoint and transpose of strided inputs (non-contiguous views) (#452)
Co-authored-by: Tim Besard <tim.besard@gmail.com>
1 parent 86fb4f2 commit 8e5ab39

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

lib/GPUArraysCore/src/GPUArraysCore.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ using Adapt
66
## essential types
77

88
export AbstractGPUArray, AbstractGPUVector, AbstractGPUMatrix, AbstractGPUVecOrMat,
9-
WrappedGPUArray, AnyGPUArray, AbstractGPUArrayStyle
9+
WrappedGPUArray, AnyGPUArray, AbstractGPUArrayStyle,
10+
AnyGPUArray, AnyGPUVector, AnyGPUMatrix
1011

1112
"""
1213
AbstractGPUArray{T, N} <: DenseArray{T, N}
@@ -24,6 +25,8 @@ const AbstractGPUVecOrMat{T} = Union{AbstractGPUArray{T, 1}, AbstractGPUArray{T,
2425
# convenience aliases for working with wrapped arrays
2526
const WrappedGPUArray{T,N} = WrappedArray{T,N,AbstractGPUArray,AbstractGPUArray{T,N}}
2627
const AnyGPUArray{T,N} = Union{AbstractGPUArray{T,N}, WrappedGPUArray{T,N}}
28+
const AnyGPUVector{T} = AnyGPUArray{T, 1}
29+
const AnyGPUMatrix{T} = AnyGPUArray{T, 2}
2730

2831
## broadcasting
2932

src/host/linalg.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@ function LinearAlgebra.adjoint!(B::AbstractGPUMatrix, A::AbstractGPUVector)
2929
B
3030
end
3131

32-
LinearAlgebra.transpose!(B::AbstractGPUArray, A::AbstractGPUArray) = transpose_f!(transpose, B, A)
33-
LinearAlgebra.adjoint!(B::AbstractGPUArray, A::AbstractGPUArray) = transpose_f!(adjoint, B, A)
34-
function transpose_f!(f, B::AbstractGPUMatrix{T}, A::AbstractGPUMatrix{T}) where T
32+
LinearAlgebra.transpose!(B::AnyGPUArray, A::AnyGPUArray) = transpose_f!(transpose, B, A)
33+
LinearAlgebra.adjoint!(B::AnyGPUArray, A::AnyGPUArray) = transpose_f!(adjoint, B, A)
34+
function transpose_f!(f, B::AnyGPUMatrix{T}, A::AnyGPUMatrix{T}) where T
3535
axes(B,1) == axes(A,2) && axes(B,2) == axes(A,1) || throw(DimensionMismatch(string(f)))
3636
gpu_call(B, A) do ctx, B, A
3737
idx = @cartesianidx A

0 commit comments

Comments
 (0)