67
67
68
68
# patch JuliaLang/julia#40899 to create a CuArray
69
69
# (see https://github.com/JuliaLang/julia/pull/41331#issuecomment-868374522)
70
- if VERSION >= v " 1.7-"
71
70
_zeros (:: Type{T} , b:: AbstractVector , n:: Integer ) where {T} = CUDA. zeros (T, max (length (b), n))
72
71
_zeros (:: Type{T} , B:: AbstractMatrix , n:: Integer ) where {T} = CUDA. zeros (T, max (size (B, 1 ), n), size (B, 2 ))
73
72
function Base.:\ (F:: Union {LinearAlgebra. LAPACKFactorizations{<: Any ,<: CuArray },
@@ -99,7 +98,6 @@ function Base.:\(F::Union{LinearAlgebra.LAPACKFactorizations{<:Any,<:CuArray},
99
98
# the complete rhs
100
99
return LinearAlgebra. _cut_B (BB, 1 : n)
101
100
end
102
- end
103
101
104
102
# eigenvalues
105
103
@@ -131,8 +129,6 @@ using LinearAlgebra: Factorization, AbstractQ, QRCompactWY, QRCompactWYQ, QRPack
131
129
132
130
# # QR
133
131
134
- if VERSION >= v " 1.8-"
135
-
136
132
LinearAlgebra. qr! (A:: CuMatrix{T} ) where T = QR (geqrf! (A:: CuMatrix{T} )... )
137
133
138
134
# conversions
@@ -229,144 +225,13 @@ LinearAlgebra.rmul!(A::CuVecOrMat{T},
229
225
trA:: Transpose{<:Any,<:QRPackedQ{T,<:CuArray,<:CuArray}} ) where {T<: BlasFloat } =
230
226
ormqr! (' R' , ' T' , parent (trA). factors, parent (adjB). τ, A)
231
227
232
- else
233
-
234
- struct CuQR{T} <: Factorization{T}
235
- factors:: CuMatrix
236
- τ:: CuVector{T}
237
- CuQR {T} (factors:: CuMatrix{T} , τ:: CuVector{T} ) where {T} = new (factors, τ)
238
- end
239
-
240
- struct CuQRPackedQ{T} <: AbstractQ{T}
241
- factors:: CuMatrix{T}
242
- τ:: CuVector{T}
243
- CuQRPackedQ {T} (factors:: CuMatrix{T} , τ:: CuVector{T} ) where {T} = new (factors, τ)
244
- end
245
-
246
- CuQR (factors:: CuMatrix{T} , τ:: CuVector{T} ) where {T} =
247
- CuQR {T} (factors, τ)
248
- CuQRPackedQ (factors:: CuMatrix{T} , τ:: CuVector{T} ) where {T} =
249
- CuQRPackedQ {T} (factors, τ)
250
-
251
- # AbstractQ's `size` is the size of the full matrix,
252
- # while `Matrix(Q)` only gives the compact Q.
253
- # See JuliaLang/julia#26591 and JuliaGPU/CUDA.jl#969.
254
- CuMatrix {T} (Q:: AbstractQ{S} ) where {T,S} = convert (CuArray{T}, Matrix (Q))
255
- CuMatrix {T, B} (Q:: AbstractQ{S} ) where {T, B, S} = CuMatrix {T} (Q)
256
- CuMatrix (Q:: AbstractQ{T} ) where {T} = CuMatrix {T} (Q)
257
- CuArray {T} (Q:: AbstractQ ) where {T} = CuMatrix {T} (Q)
258
- CuArray (Q:: AbstractQ ) = CuMatrix (Q)
259
-
260
- # extracting the full matrix can be done with `collect` (which defaults to `Array`)
261
- function Base. collect (src:: CuQRPackedQ )
262
- dest = similar (src)
263
- copyto! (dest, I)
264
- lmul! (src, dest)
265
- collect (dest)
266
- end
267
-
268
- # avoid the generic similar fallback that returns a CPU array
269
- Base. similar (Q:: CuQRPackedQ , :: Type{T} , dims:: Dims{N} ) where {T,N} =
270
- CuArray {T,N} (undef, dims)
271
-
272
- LinearAlgebra. qr! (A:: CuMatrix{T} ) where T = CuQR (geqrf! (A:: CuMatrix{T} )... )
273
- Base. size (A:: CuQR ) = size (A. factors)
274
- Base. size (A:: CuQRPackedQ , dim:: Integer ) = 0 < dim ? (dim <= 2 ? size (A. factors, 1 ) : 1 ) : throw (BoundsError ())
275
- CUDA. CuMatrix (A:: CuQRPackedQ ) = orgqr! (copy (A. factors), A. τ)
276
- CUDA. CuArray (A:: CuQRPackedQ ) = CuMatrix (A)
277
- Base. Matrix (A:: CuQRPackedQ ) = Matrix (CuMatrix (A))
278
-
279
- function Base. getproperty (A:: CuQR , d:: Symbol )
280
- m, n = size (getfield (A, :factors ))
281
- if d == :R
282
- return triu! (A. factors[1 : min (m, n), 1 : n])
283
- elseif d == :Q
284
- return CuQRPackedQ (A. factors, A. τ)
285
- else
286
- getfield (A, d)
287
- end
288
- end
289
-
290
- # iteration for destructuring into components
291
- Base. iterate (S:: CuQR ) = (S. Q, Val (:R ))
292
- Base. iterate (S:: CuQR , :: Val{:R} ) = (S. R, Val (:done ))
293
- Base. iterate (S:: CuQR , :: Val{:done} ) = nothing
294
-
295
- # Apply changes Q from the left
296
- LinearAlgebra. lmul! (A:: CuQRPackedQ{T} , B:: CuVecOrMat{T} ) where {T<: BlasFloat } =
297
- ormqr! (' L' , ' N' , A. factors, A. τ, B)
298
- LinearAlgebra. lmul! (adjA:: Adjoint{T,<:CuQRPackedQ{T}} , B:: CuVecOrMat{T} ) where {T<: BlasReal } =
299
- ormqr! (' L' , ' T' , parent (adjA). factors, parent (adjA). τ, B)
300
- LinearAlgebra. lmul! (adjA:: Adjoint{T,<:CuQRPackedQ{T}} , B:: CuVecOrMat{T} ) where {T<: BlasComplex } =
301
- ormqr! (' L' , ' C' , parent (adjA). factors, parent (adjA). τ, B)
302
- LinearAlgebra. lmul! (trA:: Transpose{T,<:CuQRPackedQ{T}} , B:: CuVecOrMat{T} ) where {T<: BlasFloat } =
303
- ormqr! (' L' , ' T' , parent (trA). factors, parent (trA). τ, B)
304
-
305
- # Apply changes Q from the right
306
- LinearAlgebra. rmul! (A:: CuVecOrMat{T} , B:: CuQRPackedQ{T} ) where {T<: BlasFloat } =
307
- ormqr! (' R' , ' N' , B. factors, B. τ, A)
308
- LinearAlgebra. rmul! (A:: CuVecOrMat{T} ,
309
- adjB:: Adjoint{<:Any,<:CuQRPackedQ{T}} ) where {T<: BlasReal } =
310
- ormqr! (' R' , ' T' , parent (adjB). factors, parent (adjB). τ, A)
311
- LinearAlgebra. rmul! (A:: CuVecOrMat{T} ,
312
- adjB:: Adjoint{<:Any,<:CuQRPackedQ{T}} ) where {T<: BlasComplex } =
313
- ormqr! (' R' , ' C' , parent (adjB). factors, parent (adjB). τ, A)
314
- LinearAlgebra. rmul! (A:: CuVecOrMat{T} ,
315
- trA:: Transpose{<:Any,<:CuQRPackedQ{T}} ) where {T<: BlasFloat } =
316
- ormqr! (' R' , ' T' , parent (trA). factors, parent (adjB). τ, A)
317
-
318
- function Base. getindex (A:: CuQRPackedQ{T} , i:: Int , j:: Int ) where {T}
319
- assertscalar (" CuQRPackedQ getindex" )
320
- x = CUDA. zeros (T, size (A, 2 ))
321
- x[j] = 1
322
- lmul! (A, x)
323
- return x[i]
324
- end
325
-
326
- function Base. show (io:: IO , F:: CuQR )
327
- println (io, " $(typeof (F)) with factors Q and R:" )
328
- show (io, F. Q)
329
- println (io)
330
- show (io, F. R)
331
- end
332
-
333
- # https://github.com/JuliaLang/julia/pull/32887
334
- LinearAlgebra. det (Q:: CuQRPackedQ{<:Real} ) = isodd (count (! iszero, Q. τ)) ? - 1 : 1
335
- LinearAlgebra. det (Q:: CuQRPackedQ ) = prod (τ -> iszero (τ) ? one (τ) : - sign (τ)^ 2 , Q. τ)
336
-
337
- function LinearAlgebra. ldiv! (_qr:: CuQR , b:: CuVector )
338
- m,n = size (_qr)
339
- _x = UpperTriangular (_qr. R[1 : min (m,n), 1 : n]) \ ((_qr. Q' * b)[1 : n])
340
- b[1 : n] .= _x
341
- unsafe_free! (_x)
342
- return b[1 : n]
343
- end
344
-
345
- function LinearAlgebra. ldiv! (_qr:: CuQR , B:: CuMatrix )
346
- m,n = size (_qr)
347
- _x = UpperTriangular (_qr. R[1 : min (m,n), 1 : n]) \ ((_qr. Q' * B)[1 : n, 1 : size (B, 2 )])
348
- B[1 : n, 1 : size (B, 2 )] .= _x
349
- unsafe_free! (_x)
350
- return B[1 : n, 1 : size (B, 2 )]
351
- end
352
-
353
- function LinearAlgebra. ldiv! (x:: CuArray ,_qr:: CuQR , b:: CuArray )
354
- _x = ldiv! (_qr, b)
355
- x .= vec (_x)
356
- unsafe_free! (_x)
357
- return x
358
- end
359
-
360
- end
361
228
362
229
# # SVD
363
230
364
231
abstract type SVDAlgorithm end
365
232
struct QRAlgorithm <: SVDAlgorithm end
366
233
struct JacobiAlgorithm <: SVDAlgorithm end
367
234
368
- if VERSION >= v " 1.8-"
369
-
370
235
LinearAlgebra. svd! (A:: CuMatrix{T} ; full:: Bool = false ,
371
236
alg:: SVDAlgorithm = JacobiAlgorithm ()) where {T} =
372
237
_svd! (A, full, alg)
@@ -384,47 +249,6 @@ function _svd!(A::CuMatrix{T}, full::Bool, alg::JacobiAlgorithm) where T
384
249
return SVD (U, S, V' )
385
250
end
386
251
387
- else
388
-
389
-
390
- struct CuSVD{T,Tr,A<: AbstractMatrix{T} } <: LinearAlgebra.Factorization{T}
391
- U:: CuMatrix{T}
392
- S:: CuVector{Tr}
393
- V:: A
394
- end
395
-
396
- # iteration for destructuring into components
397
- Base. iterate (S:: CuSVD ) = (S. U, Val (:S ))
398
- Base. iterate (S:: CuSVD , :: Val{:S} ) = (S. S, Val (:V ))
399
- Base. iterate (S:: CuSVD , :: Val{:V} ) = (S. V, Val (:done ))
400
- Base. iterate (S:: CuSVD , :: Val{:done} ) = nothing
401
-
402
- @inline function Base. getproperty (S:: CuSVD , s:: Symbol )
403
- if s === :Vt
404
- return getfield (S, :V )'
405
- else
406
- return getfield (S, s)
407
- end
408
- end
409
-
410
- LinearAlgebra. svd! (A:: CuMatrix{T} ; full:: Bool = false ,
411
- alg:: SVDAlgorithm = JacobiAlgorithm ()) where {T} =
412
- _svd! (A, full, alg)
413
- LinearAlgebra. svd (A:: CuMatrix ; full= false , alg:: SVDAlgorithm = JacobiAlgorithm ()) =
414
- _svd! (copy_cublasfloat (A), full, alg)
415
-
416
- _svd! (A:: CuMatrix{T} , full:: Bool , alg:: SVDAlgorithm ) where T =
417
- throw (ArgumentError (" Unsupported value for `alg` keyword." ))
418
- function _svd! (A:: CuMatrix{T} , full:: Bool , alg:: QRAlgorithm ) where T
419
- U, s, Vt = gesvd! (full ? ' A' : ' S' , full ? ' A' : ' S' , A:: CuMatrix{T} )
420
- return CuSVD (U, s, Vt' )
421
- end
422
- function _svd! (A:: CuMatrix{T} , full:: Bool , alg:: JacobiAlgorithm ) where T
423
- return CuSVD (gesvdj! (' V' , Int (! full), A:: CuMatrix{T} )... )
424
- end
425
-
426
- end
427
-
428
252
LinearAlgebra. svdvals! (A:: CuMatrix{T} ; alg:: SVDAlgorithm = JacobiAlgorithm ()) where {T} =
429
253
_svdvals! (A, alg)
430
254
LinearAlgebra. svdvals (A:: CuMatrix ; alg:: SVDAlgorithm = JacobiAlgorithm ()) =
@@ -443,9 +267,8 @@ function LinearAlgebra.opnorm2(A::CuMatrix{T}) where {T}
443
267
return @allowscalar invoke (LinearAlgebra. opnorm2, Tuple{AbstractMatrix{T}}, A)
444
268
end
445
269
446
- # # LU
447
270
448
- if VERSION >= v " 1.8- "
271
+ # # LU
449
272
450
273
function LinearAlgebra. lu! (A:: StridedCuMatrix{T} , :: RowMaximum ; check:: Bool = true ) where {T}
451
274
lpt = getrf! (A)
472
295
LinearAlgebra. ipiv2perm (v:: CuVector{T} , maxi:: Integer ) where T =
473
296
LinearAlgebra. ipiv2perm (Array (v), maxi)
474
297
475
- end
476
-
477
298
function LinearAlgebra. ldiv! (F:: LU{T,<:StridedCuMatrix{T}} , B:: CuVecOrMat{T} ) where {T}
478
299
return getrs! (' N' , F. factors, F. ipiv, B)
479
300
end
@@ -484,15 +305,14 @@ function LinearAlgebra.ldiv!(F::LU{T,<:StridedCuMatrix{T}}, B::CuVecOrMat{T}) wh
484
305
return getrs! (' N' , F. factors, F. ipiv, B)
485
306
end
486
307
487
- # # cholesky
488
308
489
- if VERSION >= v " 1.8-"
490
- function LinearAlgebra. cholesky (A:: LinearAlgebra.RealHermSymComplexHerm{<:Real,<:CuMatrix} ,
491
- :: Val{false} = Val (false ); check:: Bool = true )
492
- C, info = LinearAlgebra. _chol! (copy (parent (A)), A. uplo == ' U' ? UpperTriangular : LowerTriangular)
493
- return Cholesky (C. data, A. uplo, info)
494
- end
309
+ # # cholesky
495
310
496
- LinearAlgebra. cholcopy (A:: LinearAlgebra.RealHermSymComplexHerm{<:Any,<:CuArray} ) =
497
- copyto! (similar (A, LinearAlgebra. choltype (A)), A)
311
+ function LinearAlgebra. cholesky (A:: LinearAlgebra.RealHermSymComplexHerm{<:Real,<:CuMatrix} ,
312
+ :: Val{false} = Val (false ); check:: Bool = true )
313
+ C, info = LinearAlgebra. _chol! (copy (parent (A)), A. uplo == ' U' ? UpperTriangular : LowerTriangular)
314
+ return Cholesky (C. data, A. uplo, info)
498
315
end
316
+
317
+ LinearAlgebra. cholcopy (A:: LinearAlgebra.RealHermSymComplexHerm{<:Any,<:CuArray} ) =
318
+ copyto! (similar (A, LinearAlgebra. choltype (A)), A)
0 commit comments