Skip to content

Commit e1217c9

Browse files
authored
Specialize LinearAlgebra.mul! on Diagonal for GPUArrays (#416)
1 parent ec12883 commit e1217c9

File tree

2 files changed

+47
-0
lines changed

2 files changed

+47
-0
lines changed

src/host/linalg.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,36 @@ if VERSION < v"1.8-"
173173
return B
174174
end
175175
else
176+
function LinearAlgebra.mul!(B::AbstractGPUVecOrMat,
177+
D::Diagonal{<:Any, <:AbstractGPUArray},
178+
A::AbstractGPUVecOrMat)
179+
dd = D.diag
180+
d = length(dd)
181+
m, n = size(A, 1), size(A, 2)
182+
m′, n′ = size(B, 1), size(B, 2)
183+
m == d || throw(DimensionMismatch("right hand side has $m rows but D is $d by $d"))
184+
(m, n) == (m′, n′) || throw(DimensionMismatch("expect output to be $m by $n, but got $m′ by $n′"))
185+
@. B = dd * A
186+
187+
B
188+
end
189+
190+
function LinearAlgebra.mul!(B::AbstractGPUVecOrMat,
191+
D::Diagonal{<:Any, <:AbstractGPUArray},
192+
A::AbstractGPUVecOrMat,
193+
α::Number,
194+
β::Number)
195+
dd = D.diag
196+
d = length(dd)
197+
m, n = size(A, 1), size(A, 2)
198+
m′, n′ = size(B, 1), size(B, 2)
199+
m == d || throw(DimensionMismatch("right hand side has $m rows but D is $d by $d"))
200+
(m, n) == (m′, n′) || throw(DimensionMismatch("expect output to be $m by $n, but got $m′ by $n′"))
201+
@. B = α * dd* A + β * B
202+
203+
B
204+
end
205+
176206
function LinearAlgebra.ldiv!(B::AbstractGPUVecOrMat,
177207
D::Diagonal{<:Any, <:AbstractGPUArray},
178208
A::AbstractGPUVecOrMat)

test/testsuite/linalg.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,23 @@
153153
@test_throws SingularException D \ B
154154
end
155155

156+
@testset "mul! + Diagonal" begin
157+
n = 128
158+
d = AT(rand(Float32, n))
159+
D = Diagonal(d)
160+
B = AT(rand(Float32, n, n))
161+
X = AT(zeros(Float32, n, n))
162+
Y = zeros(Float32, n, n)
163+
α = rand(Float32)
164+
β = rand(Float32)
165+
mul!(X, D, B)
166+
mul!(Y, Diagonal(collect(d)), collect(B))
167+
@test collect(X) Y
168+
mul!(X, D, B, α, β)
169+
mul!(Y, Diagonal(collect(d)), collect(B), α, β)
170+
@test collect(X) Y
171+
end
172+
156173
@testset "ldiv! + Diagonal" begin
157174
n = 128
158175
d = AT(rand(Float32, n))

0 commit comments

Comments
 (0)