1
1
rocblas_size (t:: Char , M:: ROCVecOrMat ) = (size (M, t== ' N' ? 1 : 2 ), size (M, t== ' N' ? 2 : 1 ))
2
2
3
- const ROCBLASArray{T<: ROCBLASFloat } = ROCArray{T}
4
-
5
3
#
6
4
# BLAS 1
7
5
#
@@ -19,30 +17,36 @@ LinearAlgebra.rmul!(x::ROCArray{<:ROCBLASComplex}, k::Number) =
19
17
LinearAlgebra. rmul! (x:: ROCArray{<:ROCBLASFloat} , k:: Real ) =
20
18
invoke (rmul!, Tuple{typeof (x), Number}, x, k)
21
19
22
- function LinearAlgebra. BLAS. dot (DX:: ROCArray{T} , DY:: ROCArray{T} ) where T <: Union{Float16, Float32, Float64}
20
+ function LinearAlgebra. dot (
21
+ DX:: StridedROCArray{T} , DY:: StridedROCArray{T} ,
22
+ ) where T <: Union{Float16, Float32, Float64}
23
23
n = length (DX)
24
- n== length (DY) || throw (DimensionMismatch (" dot product arguments have lengths $(length (DX)) and $(length (DY)) " ))
25
- dot (n, DX, 1 , DY, 1 )
24
+ n == length (DY) || throw (DimensionMismatch (
25
+ " dot product arguments have lengths $(length (DX)) and $(length (DY)) " ))
26
+ dot (n, DX, stride (DX, 1 ), DY, stride (DY, 1 ))
26
27
end
27
28
28
- function LinearAlgebra. BLAS. dotc (DX:: ROCArray{T} , DY:: ROCArray{T} ) where T <: ROCBLASComplex
29
+ function LinearAlgebra. dot (
30
+ DX:: StridedROCArray{T} , DY:: StridedROCArray{T} ,
31
+ ) where T <: ROCBLASComplex
29
32
n = length (DX)
30
- n== length (DY) || throw (DimensionMismatch (" dot product arguments have lengths $(length (DX)) and $(length (DY)) " ))
31
- dotc (n, DX, 1 , DY, 1 )
32
- end
33
-
34
- function LinearAlgebra. BLAS. dot (DX:: ROCArray{T} , DY:: ROCArray{T} ) where T <: ROCBLASComplex
35
- dotc (DX, DY)
33
+ n == length (DY) || throw (DimensionMismatch (
34
+ " dot product arguments have lengths $(length (DX)) and $(length (DY)) " ))
35
+ dotc (n, DX, stride (DX, 1 ), DY, stride (DY, 1 ))
36
36
end
37
37
38
- function LinearAlgebra. BLAS. dotu (DX:: ROCArray{T} , DY:: ROCArray{T} ) where T <: ROCBLASComplex
39
- n = length (DX)
40
- n== length (DY) || throw (DimensionMismatch (" dot product arguments have lengths $(length (DX)) and $(length (DY)) " ))
41
- dotu (n, DX, 1 , DY, 1 )
38
+ function LinearAlgebra.:(* )(
39
+ transx:: Transpose{<:Any, <:StridedROCVector{T}} , y:: StridedROCVector{T} ,
40
+ ) where T <: Union{ComplexF16, ROCBLASComplex}
41
+ x = transx. parent
42
+ n = length (x)
43
+ n == length (y) || throw (DimensionMismatch (
44
+ " dot product arguments have lengths $(length (x)) and $(length (y)) " ))
45
+ dotu (n, x, stride (x, 1 ), y, stride (y, 1 ))
42
46
end
43
47
44
- LinearAlgebra. norm (x:: ROCArray{T} ) where T <: ROCBLASFloat = nrm2 (length (x), x, 1 )
45
- LinearAlgebra. BLAS. asum (x:: ROCBLASArray ) = asum (length (x), x, 1 )
48
+ LinearAlgebra. norm (x:: ROCArray{T} ) where T <: ROCBLASFloat = nrm2 (length (x), x, stride (x, 1 ) )
49
+ LinearAlgebra. BLAS. asum (x:: ROCArray{T} ) where T <: ROCBLASFloat = asum (length (x), x, stride (x, 1 ) )
46
50
47
51
function LinearAlgebra. axpy! (
48
52
alpha:: Number , x:: ROCArray{T} , y:: ROCArray{T} ,
0 commit comments