Skip to content

Commit 9ba26b8

Browse files
authored
Update LA.dot wrappers (#647)
1 parent c5f6339 commit 9ba26b8

File tree

2 files changed

+24
-19
lines changed

2 files changed

+24
-19
lines changed

src/blas/highlevel.jl

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
rocblas_size(t::Char, M::ROCVecOrMat) = (size(M, t=='N' ? 1 : 2), size(M, t=='N' ? 2 : 1))
22

3-
const ROCBLASArray{T<:ROCBLASFloat} = ROCArray{T}
4-
53
#
64
# BLAS 1
75
#
@@ -19,30 +17,36 @@ LinearAlgebra.rmul!(x::ROCArray{<:ROCBLASComplex}, k::Number) =
1917
LinearAlgebra.rmul!(x::ROCArray{<:ROCBLASFloat}, k::Real) =
2018
invoke(rmul!, Tuple{typeof(x), Number}, x, k)
2119

22-
function LinearAlgebra.BLAS.dot(DX::ROCArray{T}, DY::ROCArray{T}) where T <: Union{Float16, Float32, Float64}
20+
function LinearAlgebra.dot(
21+
DX::StridedROCArray{T}, DY::StridedROCArray{T},
22+
) where T <: Union{Float16, Float32, Float64}
2323
n = length(DX)
24-
n==length(DY) || throw(DimensionMismatch("dot product arguments have lengths $(length(DX)) and $(length(DY))"))
25-
dot(n, DX, 1, DY, 1)
24+
n == length(DY) || throw(DimensionMismatch(
25+
"dot product arguments have lengths $(length(DX)) and $(length(DY))"))
26+
dot(n, DX, stride(DX, 1), DY, stride(DY, 1))
2627
end
2728

28-
function LinearAlgebra.BLAS.dotc(DX::ROCArray{T}, DY::ROCArray{T}) where T <: ROCBLASComplex
29+
function LinearAlgebra.dot(
30+
DX::StridedROCArray{T}, DY::StridedROCArray{T},
31+
) where T <: ROCBLASComplex
2932
n = length(DX)
30-
n==length(DY) || throw(DimensionMismatch("dot product arguments have lengths $(length(DX)) and $(length(DY))"))
31-
dotc(n, DX, 1, DY, 1)
32-
end
33-
34-
function LinearAlgebra.BLAS.dot(DX::ROCArray{T}, DY::ROCArray{T}) where T <: ROCBLASComplex
35-
dotc(DX, DY)
33+
n == length(DY) || throw(DimensionMismatch(
34+
"dot product arguments have lengths $(length(DX)) and $(length(DY))"))
35+
dotc(n, DX, stride(DX, 1), DY, stride(DY, 1))
3636
end
3737

38-
function LinearAlgebra.BLAS.dotu(DX::ROCArray{T}, DY::ROCArray{T}) where T <: ROCBLASComplex
39-
n = length(DX)
40-
n==length(DY) || throw(DimensionMismatch("dot product arguments have lengths $(length(DX)) and $(length(DY))"))
41-
dotu(n, DX, 1, DY, 1)
38+
function LinearAlgebra.:(*)(
39+
transx::Transpose{<:Any, <:StridedROCVector{T}}, y::StridedROCVector{T},
40+
) where T <: Union{ComplexF16, ROCBLASComplex}
41+
x = transx.parent
42+
n = length(x)
43+
n == length(y) || throw(DimensionMismatch(
44+
"dot product arguments have lengths $(length(x)) and $(length(y))"))
45+
dotu(n, x, stride(x, 1), y, stride(y, 1))
4246
end
4347

44-
LinearAlgebra.norm(x::ROCArray{T}) where T <: ROCBLASFloat = nrm2(length(x), x, 1)
45-
LinearAlgebra.BLAS.asum(x::ROCBLASArray) = asum(length(x), x, 1)
48+
LinearAlgebra.norm(x::ROCArray{T}) where T <: ROCBLASFloat = nrm2(length(x), x, stride(x, 1))
49+
LinearAlgebra.BLAS.asum(x::ROCArray{T}) where T <: ROCBLASFloat = asum(length(x), x, stride(x, 1))
4650

4751
function LinearAlgebra.axpy!(
4852
alpha::Number, x::ROCArray{T}, y::ROCArray{T},

src/memory.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,8 @@ function maybe_collect(; blocking::Bool = false)
177177
# And even more if the pressure is high.
178178
pressure > 0.6 && (max_gc_rate *= 2;)
179179
pressure > 0.8 && (max_gc_rate *= 2;)
180-
gc_rate > max_gc_rate && return
180+
# Always try to collect if pressure ≥ 0.9.
181+
gc_rate > max_gc_rate && pressure < 0.9 && return
181182

182183
Base.@atomic stats.last_time = current_time
183184

0 commit comments

Comments
 (0)