Skip to content

Commit 15dbba8

Browse files
authored
[CUSOLVER] Support A \ b for rectangular matrices (#1802)
1 parent a218b9c commit 15dbba8

File tree

2 files changed

+87
-19
lines changed

2 files changed

+87
-19
lines changed

lib/cusolver/linalg.jl

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
using LinearAlgebra
44
using LinearAlgebra: BlasComplex, BlasFloat, BlasReal
5-
using ..CUBLAS: CublasFloat
5+
using ..CUBLAS: CublasFloat, trsm!
66

77
function copy_cublasfloat(As...)
88
eltypes = eltype.(As)
@@ -20,17 +20,49 @@ _copywitheltype(::Type{T}, As...) where {T} = map(A -> copyto!(similar(A, T), A)
2020

2121
# matrix division
2222

23-
# Type aliases used in the signature of `\` further down this diff: `CuMatOrAdj{T}` is a
# CuMatrix or an Adjoint/Transpose wrapper around one, and `CuOrAdj{T}` is the same for
# CuVecOrMat. The lines under the `23+` / `26+` gutters replace the ones under `23-` /
# `26-`; the only change is adding `{T}` to the first union member, so that all three
# members mention the alias parameter `T`.
const CuMatOrAdj{T} = Union{CuMatrix,
23+
const CuMatOrAdj{T} = Union{CuMatrix{T},
2424
LinearAlgebra.Adjoint{T, <:CuMatrix{T}},
2525
LinearAlgebra.Transpose{T, <:CuMatrix{T}}}
26-
const CuOrAdj{T} = Union{CuVecOrMat,
26+
const CuOrAdj{T} = Union{CuVecOrMat{T},
2727
LinearAlgebra.Adjoint{T, <:CuVecOrMat{T}},
2828
LinearAlgebra.Transpose{T, <:CuVecOrMat{T}}}
2929

3030
# Left-division `_A \ _B` for GPU matrices, shown here as a rendered diff: the two lines
# under the `32-` / `33-` gutters are the removed LU-only body, and the lines under the
# `NN+` gutters are the new body. The new body picks a factorization from the shape of
# `A`, where `n` is the number of rows and `m` the number of columns:
#   n < m  : QR-factorize the materialized adjoint of `A`, triangular-solve `B` in place
#            against the leading n×n block with op 'C', copy the result into the first
#            `n` rows of a zero-filled `X` with `m` rows, then apply
#            `ormqr!('L', 'N', F, tau, X)`.
#   n == m : LU with partial pivoting (`getrf!`) followed by `getrs!`.
#   n > m  : QR-factorize `A`, apply `ormqr!` to `B` with op 'T' when `T <: Real` and 'C'
#            otherwise, then triangular-solve the first `m` rows of `B` against the
#            leading m×m block of the factor.
# Both operands first pass through `copy_cublasfloat`, whose body is not fully visible in
# this hunk — presumably it returns copies converted to a common CUBLAS-supported element
# type (TODO confirm against the full file).
# Returns `X`, the solution array produced by the selected branch.
function Base.:\(_A::CuMatOrAdj, _B::CuOrAdj)
3131
A, B = copy_cublasfloat(_A, _B)
32-
A, ipiv = CUSOLVER.getrf!(A)
33-
return CUSOLVER.getrs!('N', A, ipiv, B)
32+
T = eltype(A)
33+
n,m = size(A)
34+
if n < m
35+
# LQ decomposition
36+
At = CuMatrix(A')
37+
F, tau = CUSOLVER.geqrf!(At) # A = RᴴQᴴ
38+
if B isa CuVector{T}
39+
CUBLAS.trsv!('U', 'C', 'N', view(F,1:n,1:n), B)
40+
X = CUDA.zeros(T, m)
41+
# only the first n entries are written; entries n+1:m of X stay zero
view(X, 1:n) .= B
42+
else
43+
CUBLAS.trsm!('L', 'U', 'C', 'N', one(T), view(F,1:n,1:n), B)
44+
p = size(B, 2)
45+
X = CUDA.zeros(T, m, p)
46+
# only the first n rows are written; rows n+1:m of X stay zero
view(X, 1:n, :) .= B
47+
end
48+
CUSOLVER.ormqr!('L', 'N', F, tau, X)
49+
elseif n == m
50+
# LU decomposition with partial pivoting
51+
F, p, info = CUSOLVER.getrf!(A) # PA = LU
52+
X = CUSOLVER.getrs!('N', F, p, B)
53+
else
54+
# QR decomposition
55+
F, tau = CUSOLVER.geqrf!(A) # A = QR
56+
CUSOLVER.ormqr!('L', T <: Real ? 'T' : 'C', F, tau, B)
57+
if B isa CuVector{T}
58+
X = B[1:m]
59+
CUBLAS.trsv!('U', 'N', 'N', view(F,1:m,1:m), X)
60+
else
61+
X = B[1:m,:]
62+
CUBLAS.trsm!('L', 'U', 'N', 'N', one(T), view(F,1:m,1:m), X)
63+
end
64+
end
65+
return X
3466
end
3567

3668
# patch JuliaLang/julia#40899 to create a CuArray

test/cusolver/dense.jl

Lines changed: 50 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -630,18 +630,54 @@ end
630630
], elty2 in [
631631
Float16, Float32, Float64, ComplexF16, ComplexF32, ComplexF64, Int32, Int64, Complex{Int32}, Complex{Int64}
632632
]
633-
A = rand(elty1,n,n)
634-
B = rand(elty2,n,n)
635-
b = rand(elty2,n)
636-
d_A = CuArray(A)
637-
d_B = CuArray(B)
638-
d_b = CuArray(b)
639-
cublasfloat = promote_type(Float32, promote_type(elty1, elty2))
640-
Af = cublasfloat.(A)
641-
Bf = cublasfloat.(B)
642-
bf = cublasfloat.(b)
643-
@test Array(d_A \ d_B) ≈ (Af \ Bf)
644-
@test Array(d_A \ d_b) ≈ (Af \ bf)
645-
@inferred d_A \ d_B
646-
@inferred d_A \ d_b
633+
@testset "Square linear systems" begin
634+
A = rand(elty1,n,n)
635+
B = rand(elty2,n,5)
636+
b = rand(elty2,n)
637+
d_A = CuArray(A)
638+
d_B = CuArray(B)
639+
d_b = CuArray(b)
640+
cublasfloat = promote_type(Float32, promote_type(elty1, elty2))
641+
Af = cublasfloat.(A)
642+
Bf = cublasfloat.(B)
643+
bf = cublasfloat.(b)
644+
@test Array(d_A \ d_B) ≈ (Af \ Bf)
645+
@test Array(d_A \ d_b) ≈ (Af \ bf)
646+
@inferred d_A \ d_B
647+
@inferred d_A \ d_b
648+
end
649+
650+
@testset "Overdetermined linear systems" begin
651+
A = rand(elty1,m,n)
652+
B = rand(elty2,m,5)
653+
b = rand(elty2,m)
654+
d_A = CuArray(A)
655+
d_B = CuArray(B)
656+
d_b = CuArray(b)
657+
cublasfloat = promote_type(Float32, promote_type(elty1, elty2))
658+
Af = cublasfloat.(A)
659+
Bf = cublasfloat.(B)
660+
bf = cublasfloat.(b)
661+
@test Array(d_A \ d_B) ≈ (Af \ Bf)
662+
@test Array(d_A \ d_b) ≈ (Af \ bf)
663+
@inferred d_A \ d_B
664+
@inferred d_A \ d_b
665+
end
666+
667+
@testset "Underdetermined linear systems" begin
668+
A = rand(elty1,n,m)
669+
B = rand(elty2,n,5)
670+
b = rand(elty2,n)
671+
d_A = CuArray(A)
672+
d_B = CuArray(B)
673+
d_b = CuArray(b)
674+
cublasfloat = promote_type(Float32, promote_type(elty1, elty2))
675+
Af = cublasfloat.(A)
676+
Bf = cublasfloat.(B)
677+
bf = cublasfloat.(b)
678+
@test Array(d_A \ d_B) ≈ (Af \ Bf)
679+
@test Array(d_A \ d_b) ≈ (Af \ bf)
680+
@inferred d_A \ d_B
681+
@inferred d_A \ d_b
682+
end
647683
end

0 commit comments

Comments
 (0)