Skip to content

Commit b6a4efb

Browse files
authored
[CUSOLVER] Add a structure CuSolverParameters fro the generic API (#2188)
1 parent 3122ba8 commit b6a4efb

File tree

1 file changed

+42
-48
lines changed

1 file changed

+42
-48
lines changed

lib/cusolver/dense_generic.jl

Lines changed: 42 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,38 @@
1+
mutable struct CuSolverParameters
2+
parameters::cusolverDnParams_t
3+
4+
function CuSolverParameters()
5+
parameters_ref = Ref{cusolverDnParams_t}()
6+
cusolverDnCreateParams(parameters_ref)
7+
obj = new(parameters_ref[])
8+
finalizer(cusolverDnDestroyParams, obj)
9+
obj
10+
end
11+
end
12+
13+
Base.unsafe_convert(::Type{cusolverDnParams_t}, params::CuSolverParameters) = params.parameters
14+
115
# Xpotrf
216
function Xpotrf!(uplo::Char, A::StridedCuMatrix{T}) where {T <: BlasFloat}
317
chkuplo(uplo)
418
n = checksquare(A)
519
lda = max(1, stride(A, 2))
620
info = CuVector{Cint}(undef, 1)
7-
params = Ref{cusolverDnParams_t}(C_NULL)
8-
cusolverDnCreateParams(params)
21+
params = CuSolverParameters()
922

1023
function bufferSize()
1124
out_cpu = Ref{Csize_t}(0)
1225
out_gpu = Ref{Csize_t}(0)
13-
cusolverDnXpotrf_bufferSize(dense_handle(), params[], uplo, n,
26+
cusolverDnXpotrf_bufferSize(dense_handle(), params, uplo, n,
1427
T, A, lda, T, out_gpu, out_cpu)
1528
out_gpu[], out_cpu[]
1629
end
1730
with_workspaces(bufferSize()...) do buffer_gpu, buffer_cpu
18-
cusolverDnXpotrf(dense_handle(), params[], uplo, n, T, A, lda, T,
31+
cusolverDnXpotrf(dense_handle(), params, uplo, n, T, A, lda, T,
1932
buffer_gpu, sizeof(buffer_gpu), buffer_cpu,
2033
sizeof(buffer_cpu), info)
2134
end
2235

23-
cusolverDnDestroyParams(params[])
2436
flag = @allowscalar info[1]
2537
unsafe_free!(info)
2638
chkargsok(BlasInt(flag))
@@ -36,12 +48,10 @@ function Xpotrs!(uplo::Char, A::StridedCuMatrix{T}, B::StridedCuVecOrMat{T}) whe
3648
lda = max(1, stride(A, 2))
3749
ldb = max(1, stride(B, 2))
3850
info = CuVector{Cint}(undef, 1)
39-
params = Ref{cusolverDnParams_t}(C_NULL)
40-
cusolverDnCreateParams(params)
51+
params = CuSolverParameters()
4152

42-
cusolverDnXpotrs(dense_handle(), params[], uplo, n, nrhs, T, A, lda, T, B, ldb, info)
53+
cusolverDnXpotrs(dense_handle(), params, uplo, n, nrhs, T, A, lda, T, B, ldb, info)
4354

44-
cusolverDnDestroyParams(params[])
4555
flag = @allowscalar info[1]
4656
unsafe_free!(info)
4757
chkargsok(BlasInt(flag))
@@ -53,23 +63,21 @@ function Xgetrf!(A::StridedCuMatrix{T}, ipiv::CuVector{Int64}) where {T <: BlasF
5363
m, n = size(A)
5464
lda = max(1, stride(A, 2))
5565
info = CuVector{Cint}(undef, 1)
56-
params = Ref{cusolverDnParams_t}(C_NULL)
57-
cusolverDnCreateParams(params)
66+
params = CuSolverParameters()
5867

5968
function bufferSize()
6069
out_cpu = Ref{Csize_t}(0)
6170
out_gpu = Ref{Csize_t}(0)
62-
cusolverDnXgetrf_bufferSize(dense_handle(), params[], m, n, T,
71+
cusolverDnXgetrf_bufferSize(dense_handle(), params, m, n, T,
6372
A, lda, T, out_gpu, out_cpu)
6473
out_gpu[], out_cpu[]
6574
end
6675
with_workspaces(bufferSize()...) do buffer_gpu, buffer_cpu
67-
cusolverDnXgetrf(dense_handle(), params[], m, n, T, A, lda, ipiv,
76+
cusolverDnXgetrf(dense_handle(), params, m, n, T, A, lda, ipiv,
6877
T, buffer_gpu, sizeof(buffer_gpu), buffer_cpu,
6978
sizeof(buffer_cpu), info)
7079
end
7180

72-
cusolverDnDestroyParams(params[])
7381
flag = @allowscalar info[1]
7482
unsafe_free!(info)
7583
chkargsok(BlasInt(flag))
@@ -90,12 +98,10 @@ function Xgetrs!(trans::Char, A::StridedCuMatrix{T}, ipiv::CuVector{Int64}, B::S
9098
lda = max(1, stride(A, 2))
9199
ldb = max(1, stride(B, 2))
92100
info = CuVector{Cint}(undef, 1)
93-
params = Ref{cusolverDnParams_t}(C_NULL)
94-
cusolverDnCreateParams(params)
101+
params = CuSolverParameters()
95102

96-
cusolverDnXgetrs(dense_handle(), params[], trans, n, nrhs, T, A, lda, ipiv, T, B, ldb, info)
103+
cusolverDnXgetrs(dense_handle(), params, trans, n, nrhs, T, A, lda, ipiv, T, B, ldb, info)
97104

98-
cusolverDnDestroyParams(params[])
99105
flag = @allowscalar info[1]
100106
unsafe_free!(info)
101107
chkargsok(BlasInt(flag))
@@ -107,23 +113,21 @@ function Xgeqrf!(A::StridedCuMatrix{T}, tau::CuVector{T}) where {T <: BlasFloat}
107113
m, n = size(A)
108114
lda = max(1, stride(A, 2))
109115
info = CuVector{Cint}(undef, 1)
110-
params = Ref{cusolverDnParams_t}(C_NULL)
111-
cusolverDnCreateParams(params)
116+
params = CuSolverParameters()
112117

113118
function bufferSize()
114119
out_cpu = Ref{Csize_t}(0)
115120
out_gpu = Ref{Csize_t}(0)
116-
cusolverDnXgeqrf_bufferSize(dense_handle(), params[], m, n, T, A,
121+
cusolverDnXgeqrf_bufferSize(dense_handle(), params, m, n, T, A,
117122
lda, T, tau, T, out_gpu, out_cpu)
118123
out_gpu[], out_cpu[]
119124
end
120125
with_workspaces(bufferSize()...) do buffer_gpu, buffer_cpu
121-
cusolverDnXgeqrf(dense_handle(), params[], m, n, T, A,
126+
cusolverDnXgeqrf(dense_handle(), params, m, n, T, A,
122127
lda, T, tau, T, buffer_gpu, sizeof(buffer_gpu),
123128
buffer_cpu, sizeof(buffer_cpu), info)
124129
end
125130

126-
cusolverDnDestroyParams(params[])
127131
flag = @allowscalar info[1]
128132
unsafe_free!(info)
129133
chkargsok(BlasInt(flag))
@@ -216,24 +220,22 @@ function Xgesvd!(jobu::Char, jobvt::Char, A::StridedCuMatrix{T}) where {T <: Bla
216220
ldu = U == CU_NULL ? 1 : max(1, stride(U, 2))
217221
ldvt = Vt == CU_NULL ? 1 : max(1, stride(Vt, 2))
218222
info = CuVector{Cint}(undef, 1)
219-
params = Ref{cusolverDnParams_t}(C_NULL)
220-
cusolverDnCreateParams(params)
223+
params = CuSolverParameters()
221224

222225
function bufferSize()
223226
out_cpu = Ref{Csize_t}(0)
224227
out_gpu = Ref{Csize_t}(0)
225-
cusolverDnXgesvd_bufferSize(dense_handle(), params[], jobu, jobvt,
228+
cusolverDnXgesvd_bufferSize(dense_handle(), params, jobu, jobvt,
226229
m, n, T, A, lda, R, Σ, T, U, ldu,
227230
T, Vt, ldvt, T, out_gpu, out_cpu)
228231
out_gpu[], out_cpu[]
229232
end
230233
with_workspaces(bufferSize()...) do buffer_gpu, buffer_cpu
231-
cusolverDnXgesvd(dense_handle(), params[], jobu, jobvt, m, n, T, A,
234+
cusolverDnXgesvd(dense_handle(), params, jobu, jobvt, m, n, T, A,
232235
lda, R, Σ, T, U, ldu, T, Vt, ldvt, T, buffer_gpu,
233236
sizeof(buffer_gpu), buffer_cpu, sizeof(buffer_cpu), info)
234237
end
235238

236-
cusolverDnDestroyParams(params[])
237239
flag = @allowscalar info[1]
238240
unsafe_free!(info)
239241
chklapackerror(BlasInt(flag))
@@ -270,25 +272,23 @@ function Xgesvdp!(jobz::Char, econ::Int, A::StridedCuMatrix{T}) where {T <: Blas
270272
ldv = V == CU_NULL ? 1 : max(1, stride(V, 2))
271273
info = CuVector{Cint}(undef, 1)
272274
h_err_sigma = Ref{Cdouble}(0)
273-
params = Ref{cusolverDnParams_t}(C_NULL)
274-
cusolverDnCreateParams(params)
275+
params = CuSolverParameters()
275276

276277
function bufferSize()
277278
out_cpu = Ref{Csize_t}(0)
278279
out_gpu = Ref{Csize_t}(0)
279-
cusolverDnXgesvdp_bufferSize(dense_handle(), params[], jobz, econ, m,
280+
cusolverDnXgesvdp_bufferSize(dense_handle(), params, jobz, econ, m,
280281
n, T, A, lda, R, Σ, T, U, ldu, T, V,
281282
ldv, T, out_gpu, out_cpu)
282283

283284
out_gpu[], out_cpu[]
284285
end
285286
with_workspaces(bufferSize()...) do buffer_gpu, buffer_cpu
286-
cusolverDnXgesvdp(dense_handle(), params[], jobz, econ, m, n, T, A, lda, R,
287+
cusolverDnXgesvdp(dense_handle(), params, jobz, econ, m, n, T, A, lda, R,
287288
Σ, T, U, ldu, T, V, ldv, T, buffer_gpu, sizeof(buffer_gpu),
288289
buffer_cpu, sizeof(buffer_cpu), info, h_err_sigma)
289290
end
290291

291-
cusolverDnDestroyParams(params[])
292292
flag = @allowscalar info[1]
293293
unsafe_free!(info)
294294
chklapackerror(BlasInt(flag))
@@ -326,25 +326,23 @@ function Xgesvdr!(jobu::Char, jobv::Char, A::StridedCuMatrix{T}, k::Integer;
326326
ldu = U == CU_NULL ? 1 : max(1, stride(U, 2))
327327
ldv = V == CU_NULL ? 1 : max(1, stride(V, 2))
328328
info = CuVector{Cint}(undef, 1)
329-
params = Ref{cusolverDnParams_t}(C_NULL)
330-
cusolverDnCreateParams(params)
329+
params = CuSolverParameters()
331330

332331
function bufferSize()
333332
out_cpu = Ref{Csize_t}(0)
334333
out_gpu = Ref{Csize_t}(0)
335-
cusolverDnXgesvdr_bufferSize(dense_handle(), params[], jobu, jobv,
334+
cusolverDnXgesvdr_bufferSize(dense_handle(), params, jobu, jobv,
336335
m, n, k, p, niters, T, A, lda, R, Σ, T,
337336
U, ldu, T, V, ldv, T, out_gpu, out_cpu)
338337
out_gpu[], out_cpu[]
339338
end
340339
with_workspaces(bufferSize()...) do buffer_gpu, buffer_cpu
341-
cusolverDnXgesvdr(dense_handle(), params[], jobu, jobv, m, n,
340+
cusolverDnXgesvdr(dense_handle(), params, jobu, jobv, m, n,
342341
k, p, niters, T, A, lda, R, Σ, T, U, ldu, T,
343342
V, ldv, T, buffer_gpu, sizeof(buffer_gpu),
344343
buffer_cpu, sizeof(buffer_cpu), info)
345344
end
346345

347-
cusolverDnDestroyParams(params[])
348346
flag = @allowscalar info[1]
349347
unsafe_free!(info)
350348
chklapackerror(BlasInt(flag))
@@ -359,23 +357,21 @@ function Xsyevd!(jobz::Char, uplo::Char, A::StridedCuMatrix{T}) where {T <: Blas
359357
lda = max(1, stride(A, 2))
360358
info = CuVector{Cint}(undef, 1)
361359
W = CuVector{R}(undef, n)
362-
params = Ref{cusolverDnParams_t}(C_NULL)
363-
cusolverDnCreateParams(params)
360+
params = CuSolverParameters()
364361

365362
function bufferSize()
366363
out_cpu = Ref{Csize_t}(0)
367364
out_gpu = Ref{Csize_t}(0)
368-
cusolverDnXsyevd_bufferSize(dense_handle(), params[], jobz, uplo, n,
365+
cusolverDnXsyevd_bufferSize(dense_handle(), params, jobz, uplo, n,
369366
T, A, lda, R, W, T, out_gpu, out_cpu)
370367
out_gpu[], out_cpu[]
371368
end
372369
with_workspaces(bufferSize()...) do buffer_gpu, buffer_cpu
373-
cusolverDnXsyevd(dense_handle(), params[], jobz, uplo, n, T, A,
370+
cusolverDnXsyevd(dense_handle(), params, jobz, uplo, n, T, A,
374371
lda, R, W, T, buffer_gpu, sizeof(buffer_gpu),
375372
buffer_cpu, sizeof(buffer_cpu), info)
376373
end
377374

378-
cusolverDnDestroyParams(params[])
379375
flag = @allowscalar info[1]
380376
unsafe_free!(info)
381377
chkargsok(BlasInt(flag))
@@ -402,24 +398,22 @@ function Xsyevdx!(jobz::Char, range::Char, uplo::Char, A::StridedCuMatrix{T};
402398
vl = Ref{R}(vl)
403399
vu = Ref{R}(vu)
404400
h_meig = Ref{Int64}(0)
405-
params = Ref{cusolverDnParams_t}(C_NULL)
406-
cusolverDnCreateParams(params)
401+
params = CuSolverParameters()
407402

408403
function bufferSize()
409404
out_cpu = Ref{Csize_t}(0)
410405
out_gpu = Ref{Csize_t}(0)
411-
cusolverDnXsyevdx_bufferSize(dense_handle(), params[], jobz, range, uplo, n,
406+
cusolverDnXsyevdx_bufferSize(dense_handle(), params, jobz, range, uplo, n,
412407
T, A, lda, vl, vu, il, iu, h_meig,
413408
R, W, T, out_gpu, out_cpu)
414409
out_gpu[], out_cpu[]
415410
end
416411
with_workspaces(bufferSize()...) do buffer_gpu, buffer_cpu
417-
cusolverDnXsyevdx(dense_handle(), params[], jobz, range, uplo, n, T, A,
412+
cusolverDnXsyevdx(dense_handle(), params, jobz, range, uplo, n, T, A,
418413
lda, vl, vu, il, iu, h_meig, R, W, T, buffer_gpu,
419414
sizeof(buffer_gpu), buffer_cpu, sizeof(buffer_cpu), info)
420415
end
421416

422-
cusolverDnDestroyParams(params[])
423417
flag = @allowscalar info[1]
424418
unsafe_free!(info)
425419
chkargsok(BlasInt(flag))

0 commit comments

Comments
 (0)