@@ -494,6 +494,120 @@ for (bname, fname, elty, relty) in ((:cusolverDnSgesvdj_bufferSize, :cusolverDnS
494
494
end
495
495
end
496
496
497
for (bname, fname, elty, relty) in ((:cusolverDnSgesvdjBatched_bufferSize, :cusolverDnSgesvdjBatched, :Float32, :Float32),
                                    (:cusolverDnDgesvdjBatched_bufferSize, :cusolverDnDgesvdjBatched, :Float64, :Float64),
                                    (:cusolverDnCgesvdjBatched_bufferSize, :cusolverDnCgesvdjBatched, :ComplexF32, :Float32),
                                    (:cusolverDnZgesvdjBatched_bufferSize, :cusolverDnZgesvdjBatched, :ComplexF64, :Float64))
    @eval begin
        # gesvdj!(jobz, A; tol, max_sweeps) -> (U, S, V)
        #
        # Batched SVD wrapper around cuSOLVER's `gesvdjBatched` for a 3-D array `A`
        # whose size is interpreted as (m, n, batchSize).
        #
        # Arguments
        #   jobz        forwarded unchanged to the solver (presumably 'N' or 'V',
        #               selecting whether singular vectors are computed -- TODO confirm
        #               against the cuSOLVER reference).
        #   A           m x n x batchSize input; both m and n must be <= 32.
        #               NOTE(review): the solver may overwrite `A` -- confirm.
        #   tol         tolerance stored in the gesvdj parameter object.
        #   max_sweeps  sweep limit stored in the gesvdj parameter object.
        #
        # Returns freshly allocated `U` (m x m x batchSize), `S` (min(m,n) x batchSize)
        # and `V` (n x n x batchSize).
        #
        # Throws `ArgumentError` when m > 32 or n > 32, and whatever `chkargsok`
        # raises for a per-matrix status reported by the solver.
        function gesvdj!(jobz::Char,
                         A::StridedCuArray{$elty,3};
                         tol::$relty=eps($relty),
                         max_sweeps::Int=100)
            m, n, batchSize = size(A)
            if m > 32 || n > 32
                throw(ArgumentError("CUSOLVER's gesvdjBatched currently requires m <=32 and n <= 32"))
            end
            lda = max(1, stride(A, 2))

            U = CuArray{$elty}(undef, m, m, batchSize)
            ldu = max(1, stride(U, 2))

            S = CuArray{$relty}(undef, min(m, n), batchSize)

            V = CuArray{$elty}(undef, n, n, batchSize)
            ldv = max(1, stride(V, 2))

            params = Ref{gesvdjInfo_t}(C_NULL)
            cusolverDnCreateGesvdjInfo(params)
            # Everything between creation and destruction of the parameter object
            # can throw (setters, workspace query, solver call, status check), so
            # the handle is released in a `finally` block instead of only on the
            # success path.
            try
                cusolverDnXgesvdjSetTolerance(params[], tol)
                cusolverDnXgesvdjSetMaxSweeps(params[], max_sweeps)

                # Workspace query; the result is handed to `with_workspace`.
                function bufferSize()
                    out = Ref{Cint}(0)
                    $bname(dense_handle(), jobz, m, n, A, lda, S, U, ldu, V, ldv,
                           out, params[], batchSize)
                    return out[]
                end

                # One status word per matrix in the batch.
                devinfo = CuArray{Cint}(undef, batchSize)
                with_workspace($elty, bufferSize) do work
                    $fname(dense_handle(), jobz, m, n, A, lda, S, U, ldu, V, ldv,
                           work, length(work), devinfo, params[], batchSize)
                end

                # Bulk copy of the status vector to the host; no scalar indexing
                # of device memory is involved.
                info = Array(devinfo)
                unsafe_free!(devinfo)

                # Double check the solver's exit status
                for i in 1:batchSize
                    chkargsok(BlasInt(info[i]))
                end
            finally
                cusolverDnDestroyGesvdjInfo(params[])
            end

            U, S, V
        end
    end
end
552
+
553
for (bname, fname, elty, relty) in ((:cusolverDnSgesvdaStridedBatched_bufferSize, :cusolverDnSgesvdaStridedBatched, :Float32, :Float32),
                                    (:cusolverDnDgesvdaStridedBatched_bufferSize, :cusolverDnDgesvdaStridedBatched, :Float64, :Float64),
                                    (:cusolverDnCgesvdaStridedBatched_bufferSize, :cusolverDnCgesvdaStridedBatched, :ComplexF32, :Float32),
                                    (:cusolverDnZgesvdaStridedBatched_bufferSize, :cusolverDnZgesvdaStridedBatched, :ComplexF64, :Float64))
    @eval begin
        # gesvda!(jobz, A; rank) -> (U, S, V)
        #
        # Strided-batched SVD wrapper around cuSOLVER's `gesvdaStridedBatched`
        # for a 3-D array `A` whose size is interpreted as (m, n, batchSize).
        #
        # Arguments
        #   jobz  forwarded unchanged to the solver (presumably 'N' or 'V' --
        #         TODO confirm against the cuSOLVER reference).
        #   A     m x n x batchSize input with m >= n.
        #         NOTE(review): the solver may overwrite `A` -- confirm.
        #   rank  second dimension of the returned `U` and `V` and first dimension
        #         of `S`; defaults to min(m, n). It is not range-checked here.
        #
        # Returns freshly allocated `U` (m x rank x batchSize), `S` (rank x batchSize)
        # and `V` (n x rank x batchSize). The per-matrix residual norms written by
        # the solver into a host buffer are not returned.
        #
        # Throws `ArgumentError` when m < n, and whatever `chkargsok` raises for a
        # per-matrix status reported by the solver.
        function gesvda!(jobz::Char,
                         A::StridedCuArray{$elty,3};
                         rank::Int=min(size(A,1), size(A,2)))
            m, n, batchSize = size(A)
            # NOTE(nikopj): the m >= n restriction has not been located in the
            # cuSOLVER documentation -- TODO confirm.
            if m < n
                throw(ArgumentError("CUSOLVER's gesvda currently requires m >= n"))
            end
            lda = max(1, stride(A, 2))
            strideA = stride(A, 3)

            U = CuArray{$elty}(undef, m, rank, batchSize)
            ldu = max(1, stride(U, 2))
            strideU = stride(U, 3)

            S = CuArray{$relty}(undef, rank, batchSize)
            strideS = stride(S, 2)

            V = CuArray{$elty}(undef, n, rank, batchSize)
            ldv = max(1, stride(V, 2))
            strideV = stride(V, 3)

            # Workspace query; the result is handed to `with_workspace`.
            function bufferSize()
                out = Ref{Cint}(0)
                $bname(dense_handle(), jobz, rank, m, n, A, lda, strideA,
                       S, strideS, U, ldu, strideU, V, ldv, strideV,
                       out, batchSize)
                return out[]
            end

            # One status word per matrix in the batch.
            devinfo = CuArray{Cint}(undef, batchSize)
            # residual storage (host-side, one entry per matrix; filled by the
            # solver and currently discarded)
            h_RnrmF = Array{Cdouble}(undef, batchSize)

            with_workspace($elty, bufferSize) do work
                $fname(dense_handle(), jobz, rank, m, n, A, lda, strideA,
                       S, strideS, U, ldu, strideU, V, ldv, strideV,
                       work, length(work), devinfo, h_RnrmF, batchSize)
            end

            # Bulk copy of the status vector to the host; no scalar indexing of
            # device memory is involved.
            info = Array(devinfo)
            unsafe_free!(devinfo)

            # Double check the solver's exit status
            for i in 1:batchSize
                chkargsok(BlasInt(info[i]))
            end

            U, S, V
        end
    end
end
610
+
497
611
for (jname, bname, fname, elty, relty) in ((:syevd! , :cusolverDnSsyevd_bufferSize , :cusolverDnSsyevd , :Float32 , :Float32 ),
498
612
(:syevd! , :cusolverDnDsyevd_bufferSize , :cusolverDnDsyevd , :Float64 , :Float64 ),
499
613
(:heevd! , :cusolverDnCheevd_bufferSize , :cusolverDnCheevd , :ComplexF32 , :Float32 ),
0 commit comments