Skip to content

Commit effeef9

Browse files
authored
Add a GPU version of copytrito! (#504)
1 parent 6becb4f commit effeef9

File tree

2 files changed

+48
-8
lines changed

2 files changed

+48
-8
lines changed

src/host/linalg.jl

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -80,17 +80,16 @@ function Base.copyto!(A::Array{T,N}, B::Transpose{T, <:AbstractGPUArray{T,N}}) w
8080
copyto!(A, Transpose(Array(parent(B))))
8181
end
8282

83-
8483
## copy upper triangle to lower and vice versa
8584

86-
function LinearAlgebra.copytri!(A::AbstractGPUMatrix{T}, uplo::AbstractChar, conjugate::Bool=false) where T
85+
function LinearAlgebra.copytri!(A::AbstractGPUMatrix, uplo::AbstractChar, conjugate::Bool=false)
8786
n = LinearAlgebra.checksquare(A)
8887
if uplo == 'U' && conjugate
8988
gpu_call(A) do ctx, _A
9089
I = @cartesianidx _A
9190
i, j = Tuple(I)
9291
if j > i
93-
_A[j,i] = conj(_A[i,j])
92+
@inbounds _A[j,i] = conj(_A[i,j])
9493
end
9594
return
9695
end
@@ -99,7 +98,7 @@ function LinearAlgebra.copytri!(A::AbstractGPUMatrix{T}, uplo::AbstractChar, con
9998
I = @cartesianidx _A
10099
i, j = Tuple(I)
101100
if j > i
102-
_A[j,i] = _A[i,j]
101+
@inbounds _A[j,i] = _A[i,j]
103102
end
104103
return
105104
end
@@ -108,7 +107,7 @@ function LinearAlgebra.copytri!(A::AbstractGPUMatrix{T}, uplo::AbstractChar, con
108107
I = @cartesianidx _A
109108
i, j = Tuple(I)
110109
if j > i
111-
_A[i,j] = conj(_A[j,i])
110+
@inbounds _A[i,j] = conj(_A[j,i])
112111
end
113112
return
114113
end
@@ -117,7 +116,7 @@ function LinearAlgebra.copytri!(A::AbstractGPUMatrix{T}, uplo::AbstractChar, con
117116
I = @cartesianidx _A
118117
i, j = Tuple(I)
119118
if j > i
120-
_A[i,j] = _A[j,i]
119+
@inbounds _A[i,j] = _A[j,i]
121120
end
122121
return
123122
end
@@ -127,6 +126,36 @@ function LinearAlgebra.copytri!(A::AbstractGPUMatrix{T}, uplo::AbstractChar, con
127126
A
128127
end
129128

129+
## copy a triangular part of a matrix to another matrix
130+
131+
if isdefined(LinearAlgebra, :copytrito!)
132+
function LinearAlgebra.copytrito!(B::AbstractGPUMatrix, A::AbstractGPUMatrix, uplo::AbstractChar)
133+
LinearAlgebra.BLAS.chkuplo(uplo)
134+
m,n = size(A)
135+
m1,n1 = size(B)
136+
(m1 < m || n1 < n) && throw(DimensionMismatch("B of size ($m1,$n1) should have at least the same number of rows and columns than A of size ($m,$n)"))
137+
if uplo == 'U'
138+
gpu_call(A, B) do ctx, _A, _B
139+
I = @cartesianidx _A
140+
i, j = Tuple(I)
141+
if j >= i
142+
@inbounds _B[i,j] = _A[i,j]
143+
end
144+
return
145+
end
146+
else # uplo == 'L'
147+
gpu_call(A, B) do ctx, _A, _B
148+
I = @cartesianidx _A
149+
i, j = Tuple(I)
150+
if j <= i
151+
@inbounds _B[i,j] = _A[i,j]
152+
end
153+
return
154+
end
155+
end
156+
return B
157+
end
158+
end
130159

131160
## triangular
132161

@@ -146,7 +175,7 @@ function LinearAlgebra.tril!(A::AbstractGPUMatrix{T}, d::Integer = 0) where T
146175
I = @cartesianidx _A
147176
i, j = Tuple(I)
148177
if i < j - _d
149-
_A[i, j] = 0
178+
@inbounds _A[i, j] = zero(T)
150179
end
151180
return
152181
end
@@ -158,7 +187,7 @@ function LinearAlgebra.triu!(A::AbstractGPUMatrix{T}, d::Integer = 0) where T
158187
I = @cartesianidx _A
159188
i, j = Tuple(I)
160189
if j < i + _d
161-
_A[i, j] = 0
190+
@inbounds _A[i, j] = zero(T)
162191
end
163192
return
164193
end

test/testsuite/linalg.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,17 @@
7777
end
7878
end
7979

80+
if isdefined(LinearAlgebra, :copytrito!)
81+
@testset "copytrito!" begin
82+
@testset for T in eltypes, uplo in ('L', 'U')
83+
n = 16
84+
A = rand(T,n,n)
85+
B = zeros(T,n,n)
86+
@test compare(copytrito!, AT, B, A, uplo)
87+
end
88+
end
89+
end
90+
8091
@testset "copyto! for triangular" begin
8192
for TR in (UpperTriangular, LowerTriangular)
8293
@test compare(transpose!, AT, Array{Float32}(undef, 128, 32), rand(Float32, 32, 128))

0 commit comments

Comments
 (0)