Skip to content

Commit 44e5818

Browse files
authored
Add more BLAS tests (#601)
1 parent 5491a02 commit 44e5818

18 files changed

+281
-333
lines changed

.codecov.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
coverage:
22
ignore:
3+
- "src/*/lib*.jl"
34
- "src/device"
5+
- "docs/"
46
status:
57
patch:
68
default:

src/blas/highlevel.jl

Lines changed: 41 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,9 @@ rocblas_size(t::Char, M::ROCVecOrMat) = (size(M, t=='N' ? 1 : 2), size(M, t=='N'
22

33
const ROCBLASArray{T<:ROCBLASFloat} = ROCArray{T}
44

5-
###########
65
#
76
# BLAS 1
87
#
9-
###########
108

119
LinearAlgebra.rmul!(x::ROCArray{<:ROCBLASFloat}, k::Number) =
1210
scal!(length(x), convert(eltype(x), k), x, 1)
@@ -84,11 +82,9 @@ function LinearAlgebra.reflect!(
8482
x, y
8583
end
8684

87-
############
8885
#
8986
# BLAS 2
9087
#
91-
############
9288

9389
if VERSION ≥ v"1.10-"
9490
# multiplication
@@ -145,36 +141,56 @@ else
145141
end
146142
end
147143

148-
#########
149144
# GEMV
150-
##########
151145

152-
function gemv_wrapper!(
153-
y::ROCVector{T}, tA::Char, A::ROCMatrix{T}, x::ROCVector{T},
154-
alpha = one(T), beta = zero(T),
155-
) where T <: ROCBLASFloat
156-
mA, nA = rocblas_size(tA, A)
157-
if nA != length(x)
158-
throw(DimensionMismatch("second dimension of A, $nA, does not match length of x, $(length(x))"))
159-
end
160-
if mA != length(y)
161-
throw(DimensionMismatch("first dimension of A, $mA, does not match length of y, $(length(y))"))
146+
function LinearAlgebra.generic_matvecmul!(
147+
Y::ROCVector, tA::AbstractChar, A::StridedROCMatrix, B::StridedROCVector,
148+
_add::MulAddMul,
149+
)
150+
mA, nA = tA == 'N' ? size(A) : reverse(size(A))
151+
152+
nA != length(B) && throw(DimensionMismatch(
153+
"second dimension of A, $nA, does not match length of B, $(length(B))"))
154+
mA != length(Y) && throw(DimensionMismatch(
155+
"first dimension of A, $mA, does not match length of Y, $(length(Y))"))
156+
157+
mA == 0 && return Y
158+
nA == 0 && return rmul!(Y, 0)
159+
160+
T = eltype(Y)
161+
alpha, beta = _add.alpha, _add.beta
162+
if alpha isa Union{Bool,T} && beta isa Union{Bool,T}
163+
α, β = T(alpha), T(beta)
164+
if T <: ROCBLASFloat && eltype(A) == eltype(B) == T
165+
if tA in ('N', 'T', 'C')
166+
return gemv!(tA, α, A, B, β, Y)
167+
elseif tA in ('S', 's')
168+
return symv!(tA == 'S' ? 'U' : 'L', α, A, B, β, Y)
169+
elseif tA in ('H', 'h')
170+
return hemv!(tA == 'H' ? 'U' : 'L', α, A, B, β, Y)
171+
end
172+
end
162173
end
163-
mA == 0 && return y
164-
nA == 0 && return rmul!(y, 0)
165-
gemv!(tA, alpha, A, x, beta, y)
174+
LinearAlgebra.generic_matmatmul!(Y, tA, 'N', A, B, MulAddMul(alpha, beta))
166175
end
167176

168-
LinearAlgebra.mul!(Y::ROCVector{T}, A::ROCMatrix{T}, B::ROCVector{T}) where T<:ROCBLASFloat = gemv_wrapper!(Y, 'N', A, B)
169-
LinearAlgebra.lmul!(Y::ROCVector{T}, A::LinearAlgebra.Transpose{<:Any, ROCMatrix{T}}, B::ROCVector{T}) where T<:ROCBLASFloat = gemv_wrapper!(Y, 'T', A.parent, B)
170-
LinearAlgebra.lmul!(Y::ROCVector{T}, A::LinearAlgebra.Adjoint{<:Any, ROCMatrix{T}}, B::ROCVector{T}) where T<:ROCBLASFloat = gemv_wrapper!(Y, 'T', A.parent, B)
171-
LinearAlgebra.lmul!(Y::ROCVector{T}, A::LinearAlgebra.Adjoint{<:Any, ROCMatrix{T}}, B::ROCVector{T}) where T<:ROCBLASComplex = gemv_wrapper!(Y, 'C', A.parent, B)
177+
if VERSION < v"1.10.0-DEV.1365"
178+
@inline LinearAlgebra.gemv!(
179+
Y::ROCVector, tA::AbstractChar, A::StridedROCMatrix,
180+
B::StridedROCVector, a::Number, b::Number,
181+
) = LinearAlgebra.generic_matvecmul!(Y, tA, A, B, MulAddMul(a, b))
182+
183+
# disambiguation with LinearAlgebra.jl
184+
@inline LinearAlgebra.gemv!(
185+
Y::ROCVector{T}, tA::AbstractChar, A::StridedROCMatrix{T},
186+
B::StridedROCVector{T}, a::Number, b::Number,
187+
) where T <: ROCBLASFloat =
188+
LinearAlgebra.generic_matvecmul!(Y, tA, A, B, MulAddMul(a, b))
189+
end
172190

173-
############
174191
#
175192
# BLAS 3
176193
#
177-
############
178194

179195
########
180196
# GEMM

src/blas/rocBLAS.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
module rocBLAS
22

33
using ..AMDGPU
4-
import AMDGPU: librocblas, AnyROCArray
4+
import AMDGPU: librocblas, AnyROCArray, StridedROCVector, StridedROCMatrix
55
import AMDGPU: HandleCache, HIP, library_state
66
import .HIP: HIPContext, HIPStream, hipStream_t, hipEvent_t
77

88
using LinearAlgebra
9-
using LinearAlgebra: AdjOrTrans
9+
using LinearAlgebra: AdjOrTrans, MulAddMul
1010
if VERSION ≥ v"1.10-"
1111
using LinearAlgebra: wrap, UpperOrLowerTriangular
1212
end

src/blas/util.jl

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,6 @@ const ROCBLASComplex = Union{ComplexF32, ComplexF64}
33
const ROCBLASFloat = Union{ROCBLASReal, ROCBLASComplex}
44
const ROCBLASFloatWithHalf = Union{Float16, ROCBLASFloat}
55

6-
# Utility functions
7-
8-
# convert Char {N,T,C} to rocblas_operation
9-
function rocblasop(trans::Char)
10-
trans == 'N' && return rocblas_operation_none
11-
trans == 'T' && return rocblas_operation_transpose
12-
trans == 'C' && return rocblas_operation_conjugate_transpose
13-
throw(ArgumentError("unknown rocblas operation $trans"))
14-
end
15-
166
function Base.convert(::Type{rocblas_operation}, trans::Char)
177
if trans == 'N'
188
return rocblas_operation_none
@@ -25,13 +15,6 @@ function Base.convert(::Type{rocblas_operation}, trans::Char)
2515
end
2616
end
2717

28-
# convert Char {U,L} to rocblas_fill
29-
function rocblasfill(uplo::Char)
30-
uplo == 'U' && return rocblas_fill_upper
31-
uplo == 'L' && return rocblas_fill_lower
32-
throw(ArgumentError("unknown rocblas fill mode $uplo"))
33-
end
34-
3518
function Base.convert(::Type{rocblas_fill}, uplo::Char)
3619
if uplo == 'U'
3720
return rocblas_fill_upper
@@ -42,13 +25,6 @@ function Base.convert(::Type{rocblas_fill}, uplo::Char)
4225
end
4326
end
4427

45-
# convert Char {U,N} to rocblas_diagonal
46-
function rocblasdiag(diag::Char)
47-
diag == 'U' && return rocblas_diagonal_unit
48-
diag == 'N' && return rocblas_diagonal_non_unit
49-
throw(ArgumentError("unknown rocblas diag mode $diag"))
50-
end
51-
5228
function Base.convert(::Type{rocblas_diagonal}, diag::Char)
5329
if diag == 'U'
5430
return rocblas_diagonal_unit
@@ -59,13 +35,6 @@ function Base.convert(::Type{rocblas_diagonal}, diag::Char)
5935
end
6036
end
6137

62-
# convert Char {L,R} to rocblas_side
63-
function rocblasside(side::Char)
64-
side == 'L' && return rocblas_side_left
65-
side == 'R' && return rocblas_side_right
66-
throw(ArgumentError("unknown rocblas side mode $side"))
67-
end
68-
6938
function Base.convert(::Type{rocblas_side}, side::Char)
7039
if side == 'L'
7140
return rocblas_side_left

src/dnn/MIOpen.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import AMDGPU: libMIOpen_path
88
import AMDGPU.Runtime.Mem # TODO remove?
99
import .HIP: hipStream_t
1010

11-
include("low_level.jl")
11+
include("libMIOpen.jl")
1212

1313
const STATUS_DESCRIPTORS = Dict(
1414
miopenStatusSuccess => "Success",
File renamed without changes.

src/hsa/HSA.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module HSA
22

3-
include("LibHSARuntime.jl")
3+
include("libHSA.jl")
44

55
# Forward prefixed names
66
hsa_names = map(string, names(LibHSARuntime))
File renamed without changes.

test/core_tests.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,8 @@ include("rocarray/base.jl")
8686
include("rocarray/broadcast.jl")
8787

8888
const IS_NAVI3 = AMDGPU.device().gcn_arch in ("gfx1100", "gfx1101", "gfx1102", "gfx1103")
89-
90-
# TODO rework, hangs on Navi 3
9189
if !IS_NAVI3
92-
include("tls.jl")
90+
include("tls.jl") # TODO hangs on Navi 3
9391
end
9492

9593
end

test/hip_core_tests.jl

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
@testitem "hip - core" begin
1+
@testitem "hip - core" setup=[TSCore] begin
22

33
using Test
44
using LinearAlgebra
@@ -32,14 +32,4 @@ if length(AMDGPU.devices()) > 1
3232
end
3333
end
3434

35-
if AMDGPU.functional(:rocblas)
36-
include("rocarray/blas.jl")
37-
end
38-
if AMDGPU.functional(:MIOpen)
39-
include("dnn/miopen.jl")
40-
end
41-
if AMDGPU.functional(:rocrand)
42-
include("rocarray/random.jl")
43-
end
44-
4535
end

0 commit comments

Comments
 (0)