Skip to content

Commit 6dacd70

Browse files
lcwmasonamccallum
andauthored
Allow StaticArray eltype in matmat{vec,mul} (#1954)
Here we avoid promotion since it is not defined between scalar numbers and static array types. Co-authored-by: Mason McCallum <masonamccallum@gmail.com>
1 parent b47ecf1 commit 6dacd70

File tree

2 files changed

+19
-2
lines changed

2 files changed

+19
-2
lines changed

lib/cublas/linalg.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ function LinearAlgebra.generic_matvecmul!(Y::CuVector, tA::AbstractChar, A::Stri
190190
end
191191

192192
T = eltype(Y)
193-
alpha, beta = promote(_add.alpha, _add.beta, zero(T))
193+
alpha, beta = _add.alpha, _add.beta
194194
if alpha isa Union{Bool,T} && beta isa Union{Bool,T}
195195
if T <: CublasFloat && eltype(A) == eltype(B) == T
196196
if tA in ('N', 'T', 'C')
@@ -273,7 +273,7 @@ end
273273

274274
function LinearAlgebra.generic_matmatmul!(C::CuVecOrMat, tA, tB, A::StridedCuVecOrMat, B::StridedCuVecOrMat, _add::MulAddMul)
275275
T = eltype(C)
276-
alpha, beta = promote(_add.alpha, _add.beta, zero(T))
276+
alpha, beta = _add.alpha, _add.beta
277277
mA, nA = size(A, tA == 'N' ? 1 : 2), size(A, tA == 'N' ? 2 : 1)
278278
mB, nB = size(B, tB == 'N' ? 1 : 2), size(B, tB == 'N' ? 2 : 1)
279279

test/libraries/cublas.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using CUDA.CUBLAS: band, bandex
44
using LinearAlgebra
55

66
using BFloat16s
7+
using StaticArrays
78

89
@test CUBLAS.version() isa VersionNumber
910
@test CUBLAS.version().major == CUBLAS.cublasGetProperty(CUDA.MAJOR_VERSION)
@@ -606,6 +607,14 @@ end
606607
W*b
607608
end
608609
end
610+
611+
@testset "StaticArray eltype" begin
612+
A = CuArray(rand(SVector{2, Float64}, 3, 3))
613+
B = CuArray(rand(Float64, 3, 1))
614+
C = A * B
615+
hC = Array(A) * Array(B)
616+
@test Array(C) hC
617+
end
609618
end
610619

611620
############################################################################################
@@ -2015,4 +2024,12 @@ end
20152024
@test C h_C
20162025
end
20172026
end # extensions
2027+
2028+
@testset "StaticArray eltype" begin
2029+
A = CuArray(rand(SVector{2, Float32}, 3, 3))
2030+
B = CuArray(rand(Float32, 3, 3))
2031+
C = A * B
2032+
hC = Array(A) * Array(B)
2033+
@test Array(C) hC
2034+
end
20182035
end # elty

0 commit comments

Comments
 (0)