Skip to content

Commit 504a495

Browse files
authored
Add symmetric support for mul! (#217)
1 parent c6f35aa commit 504a495

File tree

2 files changed

+79
-0
lines changed

2 files changed

+79
-0
lines changed

lib/cublas/linalg.jl

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,3 +363,64 @@ for (t, uploc, isunitc) in ((:LowerTriangular, 'U', 'N'),
363363
CUBLAS.trsm!('R', $uploc, 'C', $isunitc, one(T), parent(parent(B)), A)
364364
end
365365
end
366+
367+
# symmetric mul!
368+
# level 2
369+
@inline function LinearAlgebra.mul!(y::CuVector{T}, A::Hermitian{T,<:CuMatrix}, x::CuVector{T},
370+
α::Number, β::Number) where {T<:CublasReal}
371+
alpha, beta = promote(α, β, zero(T))
372+
if alpha isa Union{Bool,T} && beta isa Union{Bool,T}
373+
return CUBLAS.symv!(A.uplo, alpha, A.data, x, beta, y)
374+
else
375+
error("only supports BLAS type, got $T")
376+
end
377+
end
378+
379+
@inline function LinearAlgebra.mul!(y::CuVector{T}, A::Hermitian{T,<:CuMatrix}, x::CuVector{T},
380+
α::Number, β::Number) where {T<:CublasComplex}
381+
alpha, beta = promote(α, β, zero(T))
382+
if alpha isa Union{Bool,T} && beta isa Union{Bool,T}
383+
return CUBLAS.hemv!(A.uplo, alpha, A.data, x, beta, y)
384+
else
385+
error("only supports BLAS type, got $T")
386+
end
387+
end
388+
389+
# level 3
390+
391+
@inline function LinearAlgebra.mul!(C::CuMatrix{T}, A::Hermitian{T,<:CuMatrix}, B::CuMatrix{T},
392+
α::Number, β::Number) where {T<:CublasReal}
393+
alpha, beta = promote(α, β, zero(T))
394+
if alpha isa Union{Bool,T} && beta isa Union{Bool,T}
395+
return CUBLAS.symm!('L', A.uplo, alpha, A.data, B, beta, C)
396+
else
397+
error("only supports BLAS type, got $T")
398+
end
399+
end
400+
@inline function LinearAlgebra.mul!(C::CuMatrix{T}, A::CuMatrix{T}, B::Hermitian{T,<:CuMatrix},
401+
α::Number, β::Number) where {T<:CublasReal}
402+
alpha, beta = promote(α, β, zero(T))
403+
if alpha isa Union{Bool,T} && beta isa Union{Bool,T}
404+
return CUBLAS.symm!('R', B.uplo, alpha, B.data, A, beta, C)
405+
else
406+
error("only supports BLAS type, got $T")
407+
end
408+
end
409+
@inline function LinearAlgebra.mul!(C::CuMatrix{T}, A::Hermitian{T,<:CuMatrix}, B::CuMatrix{T},
410+
α::Number, β::Number) where {T<:CublasComplex}
411+
alpha, beta = promote(α, β, zero(T))
412+
if alpha isa Union{Bool,T} && beta isa Union{Bool,T}
413+
return CUBLAS.hemm!('L', A.uplo, alpha, A.data, B, beta, C)
414+
else
415+
error("only supports BLAS type, got $T")
416+
end
417+
end
418+
@inline function LinearAlgebra.mul!(C::CuMatrix{T}, A::CuMatrix{T}, B::Hermitian{T,<:CuMatrix},
419+
α::Number, β::Number) where {T<:CublasComplex}
420+
alpha, beta = promote(α, β, zero(T))
421+
if alpha isa Union{Bool,T} && beta isa Union{Bool,T}
422+
return CUBLAS.hemm!('R', B.uplo, alpha, B.data, A, beta, C)
423+
else
424+
error("only supports BLAS type, got $T")
425+
end
426+
end

test/cublas.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,15 @@ end
135135
mul!(y, f(A), x, Ts(1), Ts(2))
136136
@test Array(dy) y
137137
end
138+
139+
@testset "hermitian" begin
140+
y, A, x = rand(elty, 5), Hermitian(rand(elty, 5, 5)), rand(elty, 5)
141+
dy, dA, dx = CuArray(y), Hermitian(CuArray(A)), CuArray(x)
142+
mul!(dy, dA, dx)
143+
mul!(y, A, x)
144+
@test Array(dy) y
145+
end
146+
138147
@testset "banded methods" begin
139148
# bands
140149
ku = 2
@@ -553,6 +562,15 @@ end
553562
mul!(C, f(A), g(B), Ts(1), Ts(2))
554563
@test Array(dC) C
555564
end
565+
566+
@testset "hermitian" begin
567+
C, A, B = rand(elty, 5, 5), Hermitian(rand(elty, 5, 5)), rand(elty, 5, 5)
568+
dC, dA, dB = CuArray(C), Hermitian(CuArray(A)), CuArray(B)
569+
mul!(dC, dA, dB)
570+
mul!(C, A, B)
571+
@test Array(dC) C
572+
end
573+
556574
A = rand(elty,m,k)
557575
B = rand(elty,k,n)
558576
Bbad = rand(elty,k+1,n+1)

0 commit comments

Comments
 (0)