Skip to content

Commit c6c838b

Browse files
authored
Support transpose/adjoint on mixed vector/matrix inputs. (#447)
1 parent b977e8e commit c6c838b

File tree

2 files changed

+38
-7
lines changed

2 files changed

+38
-7
lines changed

src/host/linalg.jl

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,44 @@
22

33
## transpose and adjoint
44

5-
function transpose_f!(f, At::AbstractGPUArray{T, 2}, A::AbstractGPUArray{T, 2}) where T
6-
gpu_call(At, A) do ctx, At, A
7-
idx = @cartesianidx A
8-
@inbounds At[idx[2], idx[1]] = f(A[idx[1], idx[2]])
5+
function LinearAlgebra.transpose!(B::AbstractGPUVector, A::AbstractGPUMatrix)
6+
axes(B,1) == axes(A,2) && axes(A,1) == 1:1 || throw(DimensionMismatch("transpose"))
7+
copyto!(B, A)
8+
end
9+
function LinearAlgebra.transpose!(B::AbstractGPUMatrix, A::AbstractGPUVector)
10+
axes(B,2) == axes(A,1) && axes(B,1) == 1:1 || throw(DimensionMismatch("transpose"))
11+
copyto!(B, A)
12+
end
13+
function LinearAlgebra.adjoint!(B::AbstractGPUVector, A::AbstractGPUMatrix)
14+
axes(B,1) == axes(A,2) && axes(A,1) == 1:1 || throw(DimensionMismatch("adjoint"))
15+
gpu_call(B, A) do ctx, B, A
16+
idx = @linearidx B
17+
@inbounds B[idx] = adjoint(A[1, idx])
18+
return
19+
end
20+
B
21+
end
22+
function LinearAlgebra.adjoint!(B::AbstractGPUMatrix, A::AbstractGPUVector)
23+
axes(B,2) == axes(A,1) && axes(B,1) == 1:1 || throw(DimensionMismatch("adjoint"))
24+
gpu_call(B, A) do ctx, B, A
25+
idx = @linearidx A
26+
@inbounds B[1, idx] = adjoint(A[idx])
927
return
1028
end
11-
At
29+
B
1230
end
1331

14-
LinearAlgebra.transpose!(At::AbstractGPUArray, A::AbstractGPUArray) = transpose_f!(transpose, At, A)
15-
LinearAlgebra.adjoint!(At::AbstractGPUArray, A::AbstractGPUArray) = transpose_f!(adjoint, At, A)
32+
LinearAlgebra.transpose!(B::AbstractGPUArray, A::AbstractGPUArray) = transpose_f!(transpose, B, A)
33+
LinearAlgebra.adjoint!(B::AbstractGPUArray, A::AbstractGPUArray) = transpose_f!(adjoint, B, A)
34+
function transpose_f!(f, B::AbstractGPUMatrix{T}, A::AbstractGPUMatrix{T}) where T
35+
axes(B,1) == axes(A,2) && axes(B,2) == axes(A,1) || throw(DimensionMismatch(string(f)))
36+
gpu_call(B, A) do ctx, B, A
37+
idx = @cartesianidx A
38+
@inbounds B[idx[2], idx[1]] = f(A[idx[1], idx[2]])
39+
return
40+
end
41+
B
42+
end
1643

1744
function Base.copyto!(A::AbstractGPUArray{T,N}, B::Adjoint{T, <: AbstractGPUArray{T,N}}) where {T,N}
1845
adjoint!(A, B.parent)

test/testsuite/linalg.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,12 @@
22
@testset "adjoint and transpose" begin
33
@test compare(adjoint, AT, rand(Float32, 32, 32))
44
@test compare(adjoint!, AT, rand(Float32, 32, 32), rand(Float32, 32, 32))
5+
@test compare(adjoint!, AT, rand(Float32, 1, 32), rand(Float32, 32))
6+
@test compare(adjoint!, AT, rand(Float32, 32), rand(Float32, 1, 32))
57
@test compare(transpose, AT, rand(Float32, 32, 32))
68
@test compare(transpose!, AT, rand(Float32, 32, 32), rand(Float32, 32, 32))
9+
@test compare(transpose!, AT, rand(Float32, 1, 32), rand(Float32, 32))
10+
@test compare(transpose!, AT, rand(Float32, 32), rand(Float32, 1, 32))
711
@test compare((x,y)->copyto!(x, adjoint(y)), AT, rand(Float32, 32, 32), rand(Float32, 32, 32))
812
@test compare((x,y)->copyto!(x, transpose(y)), AT, rand(Float32, 32, 32), rand(Float32, 32, 32))
913
@test compare(transpose!, AT, Array{Float32}(undef, 32, 32), rand(Float32, 32, 32))

0 commit comments

Comments
 (0)