@@ -110,7 +110,7 @@ function gather!(X::CuSparseVector, Y::CuVector, index::SparseChar)
110
110
end
111
111
112
112
function mv! (transa:: SparseChar , alpha:: Number , A:: Union{CuSparseMatrixBSR{T},CuSparseMatrixCSR{T}} ,
113
- X:: DenseCuVector{T} , beta:: Number , Y:: DenseCuVector{T} , index:: SparseChar ) where {T}
113
+ X:: DenseCuVector{T} , beta:: Number , Y:: DenseCuVector{T} , index:: SparseChar , algo :: cusparseSpMVAlg_t = CUSPARSE_MV_ALG_DEFAULT ) where {T}
114
114
m,n = size (A)
115
115
116
116
if transa == ' N'
@@ -127,18 +127,18 @@ function mv!(transa::SparseChar, alpha::Number, A::Union{CuSparseMatrixBSR{T},Cu
127
127
function bufferSize ()
128
128
out = Ref {Csize_t} ()
129
129
cusparseSpMV_bufferSize (handle (), transa, Ref {compute_type} (alpha), descA, descX, Ref {compute_type} (beta),
130
- descY, compute_type, CUSPARSE_SPMV_ALG_DEFAULT , out)
130
+ descY, compute_type, algo , out)
131
131
return out[]
132
132
end
133
133
with_workspace (bufferSize) do buffer
134
134
cusparseSpMV (handle (), transa, Ref {compute_type} (alpha), descA, descX, Ref {compute_type} (beta),
135
- descY, compute_type, CUSPARSE_SPMV_ALG_DEFAULT , buffer)
135
+ descY, compute_type, algo , buffer)
136
136
end
137
137
Y
138
138
end
139
139
140
140
function mv! (transa:: SparseChar , alpha:: Number , A:: CuSparseMatrixCSC{T} , X:: DenseCuVector{T} ,
141
- beta:: Number , Y:: DenseCuVector{T} , index:: SparseChar ) where {T}
141
+ beta:: Number , Y:: DenseCuVector{T} , index:: SparseChar , algo :: cusparseSpMVAlg_t = CUSPARSE_MV_ALG_DEFAULT ) where {T}
142
142
ctransa = ' N'
143
143
if transa == ' N'
144
144
ctransa = ' T'
@@ -163,19 +163,19 @@ function mv!(transa::SparseChar, alpha::Number, A::CuSparseMatrixCSC{T}, X::Dens
163
163
function bufferSize ()
164
164
out = Ref {Csize_t} ()
165
165
cusparseSpMV_bufferSize (handle (), ctransa, Ref {compute_type} (alpha), descA, descX, Ref {compute_type} (beta),
166
- descY, compute_type, CUSPARSE_SPMV_ALG_DEFAULT , out)
166
+ descY, compute_type, algo , out)
167
167
return out[]
168
168
end
169
169
with_workspace (bufferSize) do buffer
170
170
cusparseSpMV (handle (), ctransa, Ref {compute_type} (alpha), descA, descX, Ref {compute_type} (beta),
171
- descY, compute_type, CUSPARSE_SPMV_ALG_DEFAULT , buffer)
171
+ descY, compute_type, algo , buffer)
172
172
end
173
173
174
174
return Y
175
175
end
176
176
177
177
function mm! (transa:: SparseChar , transb:: SparseChar , alpha:: Number , A:: CuSparseMatrixCSR{T} ,
178
- B:: DenseCuMatrix{T} , beta:: Number , C:: DenseCuMatrix{T} , index:: SparseChar ) where {T}
178
+ B:: DenseCuMatrix{T} , beta:: Number , C:: DenseCuMatrix{T} , index:: SparseChar , algo :: cusparseSpMMAlg_t = CUSPARSE_MM_ALG_DEFAULT ) where {T}
179
179
m,k = size (A)
180
180
n = size (C)[2 ]
181
181
@@ -197,20 +197,20 @@ function mm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSparseM
197
197
out = Ref {Csize_t} ()
198
198
cusparseSpMM_bufferSize (
199
199
handle (), transa, transb, Ref {T} (alpha), descA, descB, Ref {T} (beta),
200
- descC, T, CUSPARSE_SPMM_ALG_DEFAULT , out)
200
+ descC, T, algo , out)
201
201
return out[]
202
202
end
203
203
with_workspace (bufferSize) do buffer
204
204
cusparseSpMM (
205
205
handle (), transa, transb, Ref {T} (alpha), descA, descB, Ref {T} (beta),
206
- descC, T, CUSPARSE_SPMM_ALG_DEFAULT , buffer)
206
+ descC, T, algo , buffer)
207
207
end
208
208
209
209
return C
210
210
end
211
211
212
212
function mm! (transa:: SparseChar , transb:: SparseChar , alpha:: Number , A:: CuSparseMatrixCSC{T} ,
213
- B:: DenseCuMatrix{T} , beta:: Number , C:: DenseCuMatrix{T} , index:: SparseChar ) where {T}
213
+ B:: DenseCuMatrix{T} , beta:: Number , C:: DenseCuMatrix{T} , index:: SparseChar , algo :: cusparseSpMMAlg_t = CUSPARSE_MM_ALG_DEFAULT ) where {T}
214
214
ctransa = ' N'
215
215
if transa == ' N'
216
216
ctransa = ' T'
@@ -240,13 +240,13 @@ function mm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSparseM
240
240
out = Ref {Csize_t} ()
241
241
cusparseSpMM_bufferSize (
242
242
handle (), ctransa, transb, Ref {T} (alpha), descA, descB, Ref {T} (beta),
243
- descC, T, CUSPARSE_SPMM_ALG_DEFAULT , out)
243
+ descC, T, algo , out)
244
244
return out[]
245
245
end
246
246
with_workspace (bufferSize) do buffer
247
247
cusparseSpMM (
248
248
handle (), ctransa, transb, Ref {T} (alpha), descA, descB, Ref {T} (beta),
249
- descC, T, CUSPARSE_SPMM_ALG_DEFAULT , buffer)
249
+ descC, T, algo , buffer)
250
250
end
251
251
252
252
return C
0 commit comments