1
"""
    CuSolverParameters()

Mutable wrapper around a `cusolverDnParams_t` handle.

The constructor obtains the handle through `cusolverDnCreateParams` and registers
`cusolverDnDestroyParams` as a finalizer on the wrapper, so the handle does not
have to be destroyed by hand at each call site.
"""
mutable struct CuSolverParameters
    parameters::cusolverDnParams_t

    function CuSolverParameters()
        # Out-parameter slot handed to `cusolverDnCreateParams`; read back below.
        handle_slot = Ref{cusolverDnParams_t}()
        cusolverDnCreateParams(handle_slot)

        wrapper = new(handle_slot[])
        # The finalizer is invoked with the wrapper itself as its argument.
        finalizer(cusolverDnDestroyParams, wrapper)
        return wrapper
    end
end
12
+
13
# Hand out the wrapped handle when a `CuSolverParameters` is converted to a
# `cusolverDnParams_t` via `Base.unsafe_convert`.
function Base.unsafe_convert(::Type{cusolverDnParams_t}, params::CuSolverParameters)
    return params.parameters
end
14
+
1
15
# Xpotrf
2
16
function Xpotrf! (uplo:: Char , A:: StridedCuMatrix{T} ) where {T <: BlasFloat }
3
17
chkuplo (uplo)
4
18
n = checksquare (A)
5
19
lda = max (1 , stride (A, 2 ))
6
20
info = CuVector {Cint} (undef, 1 )
7
- params = Ref {cusolverDnParams_t} (C_NULL )
8
- cusolverDnCreateParams (params)
21
+ params = CuSolverParameters ()
9
22
10
23
function bufferSize ()
11
24
out_cpu = Ref {Csize_t} (0 )
12
25
out_gpu = Ref {Csize_t} (0 )
13
- cusolverDnXpotrf_bufferSize (dense_handle (), params[] , uplo, n,
26
+ cusolverDnXpotrf_bufferSize (dense_handle (), params, uplo, n,
14
27
T, A, lda, T, out_gpu, out_cpu)
15
28
out_gpu[], out_cpu[]
16
29
end
17
30
with_workspaces (bufferSize ()... ) do buffer_gpu, buffer_cpu
18
- cusolverDnXpotrf (dense_handle (), params[] , uplo, n, T, A, lda, T,
31
+ cusolverDnXpotrf (dense_handle (), params, uplo, n, T, A, lda, T,
19
32
buffer_gpu, sizeof (buffer_gpu), buffer_cpu,
20
33
sizeof (buffer_cpu), info)
21
34
end
22
35
23
- cusolverDnDestroyParams (params[])
24
36
flag = @allowscalar info[1 ]
25
37
unsafe_free! (info)
26
38
chkargsok (BlasInt (flag))
@@ -36,12 +48,10 @@ function Xpotrs!(uplo::Char, A::StridedCuMatrix{T}, B::StridedCuVecOrMat{T}) whe
36
48
lda = max (1 , stride (A, 2 ))
37
49
ldb = max (1 , stride (B, 2 ))
38
50
info = CuVector {Cint} (undef, 1 )
39
- params = Ref {cusolverDnParams_t} (C_NULL )
40
- cusolverDnCreateParams (params)
51
+ params = CuSolverParameters ()
41
52
42
- cusolverDnXpotrs (dense_handle (), params[] , uplo, n, nrhs, T, A, lda, T, B, ldb, info)
53
+ cusolverDnXpotrs (dense_handle (), params, uplo, n, nrhs, T, A, lda, T, B, ldb, info)
43
54
44
- cusolverDnDestroyParams (params[])
45
55
flag = @allowscalar info[1 ]
46
56
unsafe_free! (info)
47
57
chkargsok (BlasInt (flag))
@@ -53,23 +63,21 @@ function Xgetrf!(A::StridedCuMatrix{T}, ipiv::CuVector{Int64}) where {T <: BlasF
53
63
m, n = size (A)
54
64
lda = max (1 , stride (A, 2 ))
55
65
info = CuVector {Cint} (undef, 1 )
56
- params = Ref {cusolverDnParams_t} (C_NULL )
57
- cusolverDnCreateParams (params)
66
+ params = CuSolverParameters ()
58
67
59
68
function bufferSize ()
60
69
out_cpu = Ref {Csize_t} (0 )
61
70
out_gpu = Ref {Csize_t} (0 )
62
- cusolverDnXgetrf_bufferSize (dense_handle (), params[] , m, n, T,
71
+ cusolverDnXgetrf_bufferSize (dense_handle (), params, m, n, T,
63
72
A, lda, T, out_gpu, out_cpu)
64
73
out_gpu[], out_cpu[]
65
74
end
66
75
with_workspaces (bufferSize ()... ) do buffer_gpu, buffer_cpu
67
- cusolverDnXgetrf (dense_handle (), params[] , m, n, T, A, lda, ipiv,
76
+ cusolverDnXgetrf (dense_handle (), params, m, n, T, A, lda, ipiv,
68
77
T, buffer_gpu, sizeof (buffer_gpu), buffer_cpu,
69
78
sizeof (buffer_cpu), info)
70
79
end
71
80
72
- cusolverDnDestroyParams (params[])
73
81
flag = @allowscalar info[1 ]
74
82
unsafe_free! (info)
75
83
chkargsok (BlasInt (flag))
@@ -90,12 +98,10 @@ function Xgetrs!(trans::Char, A::StridedCuMatrix{T}, ipiv::CuVector{Int64}, B::S
90
98
lda = max (1 , stride (A, 2 ))
91
99
ldb = max (1 , stride (B, 2 ))
92
100
info = CuVector {Cint} (undef, 1 )
93
- params = Ref {cusolverDnParams_t} (C_NULL )
94
- cusolverDnCreateParams (params)
101
+ params = CuSolverParameters ()
95
102
96
- cusolverDnXgetrs (dense_handle (), params[] , trans, n, nrhs, T, A, lda, ipiv, T, B, ldb, info)
103
+ cusolverDnXgetrs (dense_handle (), params, trans, n, nrhs, T, A, lda, ipiv, T, B, ldb, info)
97
104
98
- cusolverDnDestroyParams (params[])
99
105
flag = @allowscalar info[1 ]
100
106
unsafe_free! (info)
101
107
chkargsok (BlasInt (flag))
@@ -107,23 +113,21 @@ function Xgeqrf!(A::StridedCuMatrix{T}, tau::CuVector{T}) where {T <: BlasFloat}
107
113
m, n = size (A)
108
114
lda = max (1 , stride (A, 2 ))
109
115
info = CuVector {Cint} (undef, 1 )
110
- params = Ref {cusolverDnParams_t} (C_NULL )
111
- cusolverDnCreateParams (params)
116
+ params = CuSolverParameters ()
112
117
113
118
function bufferSize ()
114
119
out_cpu = Ref {Csize_t} (0 )
115
120
out_gpu = Ref {Csize_t} (0 )
116
- cusolverDnXgeqrf_bufferSize (dense_handle (), params[] , m, n, T, A,
121
+ cusolverDnXgeqrf_bufferSize (dense_handle (), params, m, n, T, A,
117
122
lda, T, tau, T, out_gpu, out_cpu)
118
123
out_gpu[], out_cpu[]
119
124
end
120
125
with_workspaces (bufferSize ()... ) do buffer_gpu, buffer_cpu
121
- cusolverDnXgeqrf (dense_handle (), params[] , m, n, T, A,
126
+ cusolverDnXgeqrf (dense_handle (), params, m, n, T, A,
122
127
lda, T, tau, T, buffer_gpu, sizeof (buffer_gpu),
123
128
buffer_cpu, sizeof (buffer_cpu), info)
124
129
end
125
130
126
- cusolverDnDestroyParams (params[])
127
131
flag = @allowscalar info[1 ]
128
132
unsafe_free! (info)
129
133
chkargsok (BlasInt (flag))
@@ -216,24 +220,22 @@ function Xgesvd!(jobu::Char, jobvt::Char, A::StridedCuMatrix{T}) where {T <: Bla
216
220
ldu = U == CU_NULL ? 1 : max (1 , stride (U, 2 ))
217
221
ldvt = Vt == CU_NULL ? 1 : max (1 , stride (Vt, 2 ))
218
222
info = CuVector {Cint} (undef, 1 )
219
- params = Ref {cusolverDnParams_t} (C_NULL )
220
- cusolverDnCreateParams (params)
223
+ params = CuSolverParameters ()
221
224
222
225
function bufferSize ()
223
226
out_cpu = Ref {Csize_t} (0 )
224
227
out_gpu = Ref {Csize_t} (0 )
225
- cusolverDnXgesvd_bufferSize (dense_handle (), params[] , jobu, jobvt,
228
+ cusolverDnXgesvd_bufferSize (dense_handle (), params, jobu, jobvt,
226
229
m, n, T, A, lda, R, Σ, T, U, ldu,
227
230
T, Vt, ldvt, T, out_gpu, out_cpu)
228
231
out_gpu[], out_cpu[]
229
232
end
230
233
with_workspaces (bufferSize ()... ) do buffer_gpu, buffer_cpu
231
- cusolverDnXgesvd (dense_handle (), params[] , jobu, jobvt, m, n, T, A,
234
+ cusolverDnXgesvd (dense_handle (), params, jobu, jobvt, m, n, T, A,
232
235
lda, R, Σ, T, U, ldu, T, Vt, ldvt, T, buffer_gpu,
233
236
sizeof (buffer_gpu), buffer_cpu, sizeof (buffer_cpu), info)
234
237
end
235
238
236
- cusolverDnDestroyParams (params[])
237
239
flag = @allowscalar info[1 ]
238
240
unsafe_free! (info)
239
241
chklapackerror (BlasInt (flag))
@@ -270,25 +272,23 @@ function Xgesvdp!(jobz::Char, econ::Int, A::StridedCuMatrix{T}) where {T <: Blas
270
272
ldv = V == CU_NULL ? 1 : max (1 , stride (V, 2 ))
271
273
info = CuVector {Cint} (undef, 1 )
272
274
h_err_sigma = Ref {Cdouble} (0 )
273
- params = Ref {cusolverDnParams_t} (C_NULL )
274
- cusolverDnCreateParams (params)
275
+ params = CuSolverParameters ()
275
276
276
277
function bufferSize ()
277
278
out_cpu = Ref {Csize_t} (0 )
278
279
out_gpu = Ref {Csize_t} (0 )
279
- cusolverDnXgesvdp_bufferSize (dense_handle (), params[] , jobz, econ, m,
280
+ cusolverDnXgesvdp_bufferSize (dense_handle (), params, jobz, econ, m,
280
281
n, T, A, lda, R, Σ, T, U, ldu, T, V,
281
282
ldv, T, out_gpu, out_cpu)
282
283
283
284
out_gpu[], out_cpu[]
284
285
end
285
286
with_workspaces (bufferSize ()... ) do buffer_gpu, buffer_cpu
286
- cusolverDnXgesvdp (dense_handle (), params[] , jobz, econ, m, n, T, A, lda, R,
287
+ cusolverDnXgesvdp (dense_handle (), params, jobz, econ, m, n, T, A, lda, R,
287
288
Σ, T, U, ldu, T, V, ldv, T, buffer_gpu, sizeof (buffer_gpu),
288
289
buffer_cpu, sizeof (buffer_cpu), info, h_err_sigma)
289
290
end
290
291
291
- cusolverDnDestroyParams (params[])
292
292
flag = @allowscalar info[1 ]
293
293
unsafe_free! (info)
294
294
chklapackerror (BlasInt (flag))
@@ -326,25 +326,23 @@ function Xgesvdr!(jobu::Char, jobv::Char, A::StridedCuMatrix{T}, k::Integer;
326
326
ldu = U == CU_NULL ? 1 : max (1 , stride (U, 2 ))
327
327
ldv = V == CU_NULL ? 1 : max (1 , stride (V, 2 ))
328
328
info = CuVector {Cint} (undef, 1 )
329
- params = Ref {cusolverDnParams_t} (C_NULL )
330
- cusolverDnCreateParams (params)
329
+ params = CuSolverParameters ()
331
330
332
331
function bufferSize ()
333
332
out_cpu = Ref {Csize_t} (0 )
334
333
out_gpu = Ref {Csize_t} (0 )
335
- cusolverDnXgesvdr_bufferSize (dense_handle (), params[] , jobu, jobv,
334
+ cusolverDnXgesvdr_bufferSize (dense_handle (), params, jobu, jobv,
336
335
m, n, k, p, niters, T, A, lda, R, Σ, T,
337
336
U, ldu, T, V, ldv, T, out_gpu, out_cpu)
338
337
out_gpu[], out_cpu[]
339
338
end
340
339
with_workspaces (bufferSize ()... ) do buffer_gpu, buffer_cpu
341
- cusolverDnXgesvdr (dense_handle (), params[] , jobu, jobv, m, n,
340
+ cusolverDnXgesvdr (dense_handle (), params, jobu, jobv, m, n,
342
341
k, p, niters, T, A, lda, R, Σ, T, U, ldu, T,
343
342
V, ldv, T, buffer_gpu, sizeof (buffer_gpu),
344
343
buffer_cpu, sizeof (buffer_cpu), info)
345
344
end
346
345
347
- cusolverDnDestroyParams (params[])
348
346
flag = @allowscalar info[1 ]
349
347
unsafe_free! (info)
350
348
chklapackerror (BlasInt (flag))
@@ -359,23 +357,21 @@ function Xsyevd!(jobz::Char, uplo::Char, A::StridedCuMatrix{T}) where {T <: Blas
359
357
lda = max (1 , stride (A, 2 ))
360
358
info = CuVector {Cint} (undef, 1 )
361
359
W = CuVector {R} (undef, n)
362
- params = Ref {cusolverDnParams_t} (C_NULL )
363
- cusolverDnCreateParams (params)
360
+ params = CuSolverParameters ()
364
361
365
362
function bufferSize ()
366
363
out_cpu = Ref {Csize_t} (0 )
367
364
out_gpu = Ref {Csize_t} (0 )
368
- cusolverDnXsyevd_bufferSize (dense_handle (), params[] , jobz, uplo, n,
365
+ cusolverDnXsyevd_bufferSize (dense_handle (), params, jobz, uplo, n,
369
366
T, A, lda, R, W, T, out_gpu, out_cpu)
370
367
out_gpu[], out_cpu[]
371
368
end
372
369
with_workspaces (bufferSize ()... ) do buffer_gpu, buffer_cpu
373
- cusolverDnXsyevd (dense_handle (), params[] , jobz, uplo, n, T, A,
370
+ cusolverDnXsyevd (dense_handle (), params, jobz, uplo, n, T, A,
374
371
lda, R, W, T, buffer_gpu, sizeof (buffer_gpu),
375
372
buffer_cpu, sizeof (buffer_cpu), info)
376
373
end
377
374
378
- cusolverDnDestroyParams (params[])
379
375
flag = @allowscalar info[1 ]
380
376
unsafe_free! (info)
381
377
chkargsok (BlasInt (flag))
@@ -402,24 +398,22 @@ function Xsyevdx!(jobz::Char, range::Char, uplo::Char, A::StridedCuMatrix{T};
402
398
vl = Ref {R} (vl)
403
399
vu = Ref {R} (vu)
404
400
h_meig = Ref {Int64} (0 )
405
- params = Ref {cusolverDnParams_t} (C_NULL )
406
- cusolverDnCreateParams (params)
401
+ params = CuSolverParameters ()
407
402
408
403
function bufferSize ()
409
404
out_cpu = Ref {Csize_t} (0 )
410
405
out_gpu = Ref {Csize_t} (0 )
411
- cusolverDnXsyevdx_bufferSize (dense_handle (), params[] , jobz, range, uplo, n,
406
+ cusolverDnXsyevdx_bufferSize (dense_handle (), params, jobz, range, uplo, n,
412
407
T, A, lda, vl, vu, il, iu, h_meig,
413
408
R, W, T, out_gpu, out_cpu)
414
409
out_gpu[], out_cpu[]
415
410
end
416
411
with_workspaces (bufferSize ()... ) do buffer_gpu, buffer_cpu
417
- cusolverDnXsyevdx (dense_handle (), params[] , jobz, range, uplo, n, T, A,
412
+ cusolverDnXsyevdx (dense_handle (), params, jobz, range, uplo, n, T, A,
418
413
lda, vl, vu, il, iu, h_meig, R, W, T, buffer_gpu,
419
414
sizeof (buffer_gpu), buffer_cpu, sizeof (buffer_cpu), info)
420
415
end
421
416
422
- cusolverDnDestroyParams (params[])
423
417
flag = @allowscalar info[1 ]
424
418
unsafe_free! (info)
425
419
chkargsok (BlasInt (flag))
0 commit comments