Skip to content

Commit a3369df

Browse files
authored
[WIP] Speed up dense-sparse matmul (#38876)
* Speed up dense-sparse matmul
* add one at-simd, minor edits
* improve A_mul_Bq for dense-sparse
* revert ineffective changes
* shift at-inbounds annotation
1 parent 3d1598e commit a3369df

File tree

1 file changed

+65
-93
lines changed

1 file changed

+65
-93
lines changed

stdlib/SparseArrays/src/linalg.jl

Lines changed: 65 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,10 @@ function mul!(C::StridedVecOrMat, A::AbstractSparseMatrixCSC, B::Union{StridedVe
3434
if β != 1
3535
β != 0 ? rmul!(C, β) : fill!(C, zero(eltype(C)))
3636
end
37-
for k = 1:size(C, 2)
38-
@inbounds for col = 1:size(A, 2)
37+
for k in 1:size(C, 2)
38+
@inbounds for col in 1:size(A, 2)
3939
αxj = B[col,k] * α
40-
for j = getcolptr(A)[col]:(getcolptr(A)[col + 1] - 1)
40+
for j in nzrange(A, col)
4141
C[rv[j], k] += nzv[j]*αxj
4242
end
4343
end
@@ -49,67 +49,38 @@ end
4949
*(A::SparseMatrixCSCUnion{TA}, B::AdjOrTransStridedOrTriangularMatrix{Tx}) where {TA,Tx} =
5050
(T = promote_op(matprod, TA, Tx); mul!(similar(B, T, (size(A, 1), size(B, 2))), A, B, true, false))
5151

52-
function mul!(C::StridedVecOrMat, adjA::Adjoint{<:Any,<:AbstractSparseMatrixCSC}, B::Union{StridedVector,AdjOrTransStridedOrTriangularMatrix}, α::Number, β::Number)
53-
A = adjA.parent
54-
size(A, 2) == size(C, 1) || throw(DimensionMismatch())
55-
size(A, 1) == size(B, 1) || throw(DimensionMismatch())
56-
size(B, 2) == size(C, 2) || throw(DimensionMismatch())
57-
nzv = nonzeros(A)
58-
rv = rowvals(A)
59-
if β != 1
60-
β != 0 ? rmul!(C, β) : fill!(C, zero(eltype(C)))
61-
end
62-
for k = 1:size(C, 2)
63-
@inbounds for col = 1:size(A, 2)
64-
tmp = zero(eltype(C))
65-
for j = getcolptr(A)[col]:(getcolptr(A)[col + 1] - 1)
66-
tmp += adjoint(nzv[j])*B[rv[j],k]
52+
for (T, t) in ((Adjoint, adjoint), (Transpose, transpose))
53+
@eval function mul!(C::StridedVecOrMat, xA::$T{<:Any,<:AbstractSparseMatrixCSC}, B::Union{StridedVector,AdjOrTransStridedOrTriangularMatrix}, α::Number, β::Number)
54+
A = xA.parent
55+
size(A, 2) == size(C, 1) || throw(DimensionMismatch())
56+
size(A, 1) == size(B, 1) || throw(DimensionMismatch())
57+
size(B, 2) == size(C, 2) || throw(DimensionMismatch())
58+
nzv = nonzeros(A)
59+
rv = rowvals(A)
60+
if β != 1
61+
β != 0 ? rmul!(C, β) : fill!(C, zero(eltype(C)))
62+
end
63+
for k in 1:size(C, 2)
64+
@inbounds for col in 1:size(A, 2)
65+
tmp = zero(eltype(C))
66+
for j in nzrange(A, col)
67+
tmp += $t(nzv[j])*B[rv[j],k]
68+
end
69+
C[col,k] += tmp * α
6770
end
68-
C[col,k] += tmp * α
6971
end
72+
C
7073
end
71-
C
7274
end
7375
*(adjA::Adjoint{<:Any,<:AbstractSparseMatrixCSC}, x::StridedVector{Tx}) where {Tx} =
7476
(T = promote_op(matprod, eltype(adjA), Tx); mul!(similar(x, T, size(adjA, 1)), adjA, x, true, false))
7577
*(adjA::Adjoint{<:Any,<:AbstractSparseMatrixCSC}, B::AdjOrTransStridedOrTriangularMatrix) =
7678
(T = promote_op(matprod, eltype(adjA), eltype(B)); mul!(similar(B, T, (size(adjA, 1), size(B, 2))), adjA, B, true, false))
77-
78-
function mul!(C::StridedVecOrMat, transA::Transpose{<:Any,<:AbstractSparseMatrixCSC}, B::Union{StridedVector,AdjOrTransStridedOrTriangularMatrix}, α::Number, β::Number)
79-
A = transA.parent
80-
size(A, 2) == size(C, 1) || throw(DimensionMismatch())
81-
size(A, 1) == size(B, 1) || throw(DimensionMismatch())
82-
size(B, 2) == size(C, 2) || throw(DimensionMismatch())
83-
nzv = nonzeros(A)
84-
rv = rowvals(A)
85-
if β != 1
86-
β != 0 ? rmul!(C, β) : fill!(C, zero(eltype(C)))
87-
end
88-
for k = 1:size(C, 2)
89-
@inbounds for col = 1:size(A, 2)
90-
tmp = zero(eltype(C))
91-
for j = getcolptr(A)[col]:(getcolptr(A)[col + 1] - 1)
92-
tmp += transpose(nzv[j])*B[rv[j],k]
93-
end
94-
C[col,k] += tmp * α
95-
end
96-
end
97-
C
98-
end
9979
*(transA::Transpose{<:Any,<:AbstractSparseMatrixCSC}, x::StridedVector{Tx}) where {Tx} =
10080
(T = promote_op(matprod, eltype(transA), Tx); mul!(similar(x, T, size(transA, 1)), transA, x, true, false))
10181
*(transA::Transpose{<:Any,<:AbstractSparseMatrixCSC}, B::AdjOrTransStridedOrTriangularMatrix) =
10282
(T = promote_op(matprod, eltype(transA), eltype(B)); mul!(similar(B, T, (size(transA, 1), size(B, 2))), transA, B, true, false))
10383

104-
# For compatibility with dense multiplication API. Should be deleted when dense multiplication
105-
# API is updated to follow BLAS API.
106-
mul!(C::StridedVecOrMat, A::AbstractSparseMatrixCSC, B::Union{StridedVector,AdjOrTransStridedOrTriangularMatrix}) =
107-
mul!(C, A, B, true, false)
108-
mul!(C::StridedVecOrMat, adjA::Adjoint{<:Any,<:AbstractSparseMatrixCSC}, B::Union{StridedVector,AdjOrTransStridedOrTriangularMatrix}) =
109-
mul!(C, adjA, B, true, false)
110-
mul!(C::StridedVecOrMat, transA::Transpose{<:Any,<:AbstractSparseMatrixCSC}, B::Union{StridedVector,AdjOrTransStridedOrTriangularMatrix}) =
111-
mul!(C, transA, B, true, false)
112-
11384
function mul!(C::StridedVecOrMat, X::AdjOrTransStridedOrTriangularMatrix, A::AbstractSparseMatrixCSC, α::Number, β::Number)
11485
mX, nX = size(X)
11586
nX == size(A, 1) || throw(DimensionMismatch())
@@ -120,49 +91,50 @@ function mul!(C::StridedVecOrMat, X::AdjOrTransStridedOrTriangularMatrix, A::Abs
12091
if β != 1
12192
β != 0 ? rmul!(C, β) : fill!(C, zero(eltype(C)))
12293
end
123-
@inbounds for multivec_row=1:mX, col = 1:size(A, 2), k=getcolptr(A)[col]:(getcolptr(A)[col+1]-1)
124-
C[multivec_row, col] += α * X[multivec_row, rv[k]] * nzv[k] # perhaps suboptimal position of α?
94+
if X isa StridedOrTriangularMatrix
95+
@inbounds for col in 1:size(A, 2), k in nzrange(A, col)
96+
Aiα = nzv[k] * α
97+
rvk = rv[k]
98+
@simd for multivec_row in 1:mX
99+
C[multivec_row, col] += X[multivec_row, rvk] * Aiα
100+
end
101+
end
102+
else # X isa Adjoint or Transpose
103+
for multivec_row in 1:mX, col in 1:size(A, 2)
104+
@inbounds for k in nzrange(A, col)
105+
C[multivec_row, col] += X[multivec_row, rv[k]] * nzv[k] * α
106+
end
107+
end
125108
end
126109
C
127110
end
128111
*(X::AdjOrTransStridedOrTriangularMatrix, A::SparseMatrixCSCUnion{TvA}) where {TvA} =
129112
(T = promote_op(matprod, eltype(X), TvA); mul!(similar(X, T, (size(X, 1), size(A, 2))), X, A, true, false))
130113

131-
function mul!(C::StridedVecOrMat, X::AdjOrTransStridedOrTriangularMatrix, adjA::Adjoint{<:Any,<:AbstractSparseMatrixCSC}, α::Number, β::Number)
132-
A = adjA.parent
133-
mX, nX = size(X)
134-
nX == size(A, 2) || throw(DimensionMismatch())
135-
mX == size(C, 1) || throw(DimensionMismatch())
136-
size(A, 1) == size(C, 2) || throw(DimensionMismatch())
137-
rv = rowvals(A)
138-
nzv = nonzeros(A)
139-
if β != 1
140-
β != 0 ? rmul!(C, β) : fill!(C, zero(eltype(C)))
141-
end
142-
@inbounds for col = 1:size(A, 2), k=getcolptr(A)[col]:(getcolptr(A)[col+1]-1), multivec_col=1:mX
143-
C[multivec_col, rv[k]] += α * X[multivec_col, col] * adjoint(nzv[k]) # perhaps suboptimal position of α?
114+
for (T, t) in ((Adjoint, adjoint), (Transpose, transpose))
115+
@eval function mul!(C::StridedVecOrMat, X::AdjOrTransStridedOrTriangularMatrix, xA::$T{<:Any,<:AbstractSparseMatrixCSC}, α::Number, β::Number)
116+
A = xA.parent
117+
mX, nX = size(X)
118+
nX == size(A, 2) || throw(DimensionMismatch())
119+
mX == size(C, 1) || throw(DimensionMismatch())
120+
size(A, 1) == size(C, 2) || throw(DimensionMismatch())
121+
rv = rowvals(A)
122+
nzv = nonzeros(A)
123+
if β != 1
124+
β != 0 ? rmul!(C, β) : fill!(C, zero(eltype(C)))
125+
end
126+
@inbounds for col in 1:size(A, 2), k in nzrange(A, col)
127+
Aiα = $t(nzv[k]) * α
128+
rvk = rv[k]
129+
@simd for multivec_col in 1:mX
130+
C[multivec_col, rvk] += X[multivec_col, col] * Aiα
131+
end
132+
end
133+
C
144134
end
145-
C
146135
end
147136
*(X::AdjOrTransStridedOrTriangularMatrix, adjA::Adjoint{<:Any,<:AbstractSparseMatrixCSC}) =
148137
(T = promote_op(matprod, eltype(X), eltype(adjA)); mul!(similar(X, T, (size(X, 1), size(adjA, 2))), X, adjA, true, false))
149-
150-
function mul!(C::StridedVecOrMat, X::AdjOrTransStridedOrTriangularMatrix, transA::Transpose{<:Any,<:AbstractSparseMatrixCSC}, α::Number, β::Number)
151-
A = transA.parent
152-
mX, nX = size(X)
153-
nX == size(A, 2) || throw(DimensionMismatch())
154-
mX == size(C, 1) || throw(DimensionMismatch())
155-
size(A, 1) == size(C, 2) || throw(DimensionMismatch())
156-
rv = rowvals(A)
157-
nzv = nonzeros(A)
158-
if β != 1
159-
β != 0 ? rmul!(C, β) : fill!(C, zero(eltype(C)))
160-
end
161-
@inbounds for col = 1:size(A, 2), k=getcolptr(A)[col]:(getcolptr(A)[col+1]-1), multivec_col=1:mX
162-
C[multivec_col, rv[k]] += α * X[multivec_col, col] * transpose(nzv[k]) # perhaps suboptimal position of α?
163-
end
164-
C
165-
end
166138
*(X::AdjOrTransStridedOrTriangularMatrix, transA::Transpose{<:Any,<:AbstractSparseMatrixCSC}) =
167139
(T = promote_op(matprod, eltype(X), eltype(transA)); mul!(similar(X, T, (size(X, 1), size(transA, 2))), X, transA, true, false))
168140

@@ -896,7 +868,7 @@ function ldiv!(D::Diagonal{T}, A::AbstractSparseMatrixCSC{T}) where {T}
896868
for i=1:length(b)
897869
iszero(b[i]) && throw(SingularException(i))
898870
end
899-
@inbounds for col = 1:size(A, 2), p = getcolptr(A)[col]:(getcolptr(A)[col + 1] - 1)
871+
@inbounds for col in 1:size(A, 2), p in nzrange(A, col)
900872
nonz[p] = b[Arowval[p]] \ nonz[p]
901873
end
902874
A
@@ -916,7 +888,7 @@ function triu(S::AbstractSparseMatrixCSC{Tv,Ti}, k::Integer=0) where {Tv,Ti}
916888
colptr[col] = 1
917889
end
918890
for col = max(k+1,1) : n
919-
for c1 = getcolptr(S)[col] : getcolptr(S)[col+1]-1
891+
for c1 in nzrange(S, col)
920892
rowvals(S)[c1] > col - k && break
921893
nnz += 1
922894
end
@@ -927,7 +899,7 @@ function triu(S::AbstractSparseMatrixCSC{Tv,Ti}, k::Integer=0) where {Tv,Ti}
927899
A = SparseMatrixCSC(m, n, colptr, rowval, nzval)
928900
for col = max(k+1,1) : n
929901
c1 = getcolptr(S)[col]
930-
for c2 = getcolptr(A)[col] : getcolptr(A)[col+1]-1
902+
for c2 in nzrange(A, col)
931903
rowvals(A)[c2] = rowvals(S)[c1]
932904
nonzeros(A)[c2] = nonzeros(S)[c1]
933905
c1 += 1
@@ -981,7 +953,7 @@ function sparse_diff1(S::AbstractSparseMatrixCSC{Tv,Ti}) where {Tv,Ti}
981953
for col = 1 : n
982954
last_row = 0
983955
last_val = 0
984-
for k = getcolptr(S)[col] : getcolptr(S)[col+1]-1
956+
for k in nzrange(S, col)
985957
row = rowvals(S)[k]
986958
val = nonzeros(S)[k]
987959
if row > 1
@@ -1124,7 +1096,7 @@ function opnorm(A::AbstractSparseMatrixCSC, p::Real=2)
11241096
nA::Tsum = 0
11251097
for j=1:n
11261098
colSum::Tsum = 0
1127-
for i = getcolptr(A)[j]:getcolptr(A)[j+1]-1
1099+
for i in nzrange(A, j)
11281100
colSum += abs(nonzeros(A)[i])
11291101
end
11301102
nA = max(nA, colSum)
@@ -1469,7 +1441,7 @@ function mul!(C::AbstractSparseMatrixCSC, A::AbstractSparseMatrixCSC, D::Diagona
14691441
Cnzval = nonzeros(C)
14701442
Anzval = nonzeros(A)
14711443
resize!(Cnzval, length(Anzval))
1472-
for col = 1:n, p = getcolptr(A)[col]:(getcolptr(A)[col+1]-1)
1444+
for col in 1:n, p in nzrange(A, col)
14731445
@inbounds Cnzval[p] = Anzval[p] * b[col]
14741446
end
14751447
C
@@ -1484,7 +1456,7 @@ function mul!(C::AbstractSparseMatrixCSC, D::Diagonal, A::AbstractSparseMatrixCS
14841456
Anzval = nonzeros(A)
14851457
Arowval = rowvals(A)
14861458
resize!(Cnzval, length(Anzval))
1487-
for col = 1:n, p = getcolptr(A)[col]:(getcolptr(A)[col+1]-1)
1459+
for col in 1:n, p in nzrange(A, col)
14881460
@inbounds Cnzval[p] = b[Arowval[p]] * Anzval[p]
14891461
end
14901462
C
@@ -1520,7 +1492,7 @@ function rmul!(A::AbstractSparseMatrixCSC, D::Diagonal)
15201492
m, n = size(A)
15211493
(n == size(D, 1)) || throw(DimensionMismatch())
15221494
Anzval = nonzeros(A)
1523-
@inbounds for col = 1:n, p = getcolptr(A)[col]:(getcolptr(A)[col + 1] - 1)
1495+
@inbounds for col in 1:n, p in nzrange(A, col)
15241496
Anzval[p] = Anzval[p] * D.diag[col]
15251497
end
15261498
return A
@@ -1531,7 +1503,7 @@ function lmul!(D::Diagonal, A::AbstractSparseMatrixCSC)
15311503
(m == size(D, 2)) || throw(DimensionMismatch())
15321504
Anzval = nonzeros(A)
15331505
Arowval = rowvals(A)
1534-
@inbounds for col = 1:n, p = getcolptr(A)[col]:(getcolptr(A)[col + 1] - 1)
1506+
@inbounds for col in 1:n, p in nzrange(A, col)
15351507
Anzval[p] = D.diag[Arowval[p]] * Anzval[p]
15361508
end
15371509
return A

0 commit comments

Comments
 (0)