Skip to content

Commit 0b16b3a

Browse files
Adding support for in-place qr of views
1 parent 7ebff72 commit 0b16b3a

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

lib/cusolver/linalg.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,8 @@ using LinearAlgebra: Factorization, AbstractQ, QRCompactWY, QRCompactWYQ, QRPack
101101

102102
if VERSION >= v"1.8-"
103103

104-
LinearAlgebra.qr!(A::StridedCuMatrix{T}) where T = QR(geqrf!(A::StridedCuMatrix{T})...)
104+
LinearAlgebra.qr!(A::CuMatrix{T}) where T = CuQR(geqrf!(A::CuMatrix{T})...)
105+
LinearAlgebra.qr!(A::StridedCuMatrix{T}) where T = CuQR(geqrf!(A::StridedCuMatrix{T})...)
105106

106107
# conversions
107108
CuMatrix(F::Union{QR,QRCompactWY}) = CuArray(AbstractArray(F))
@@ -234,7 +235,7 @@ end
234235
Base.similar(Q::CuQRPackedQ, ::Type{T}, dims::Dims{N}) where {T,N} =
235236
CuArray{T,N}(undef, dims)
236237

237-
LinearAlgebra.qr!(A::CuMatrix{T}) where T = CuQR(geqrf!(A::CuMatrix{T})...)
238+
238239
Base.size(A::CuQR) = size(A.factors)
239240
Base.size(A::CuQRPackedQ, dim::Integer) = 0 < dim ? (dim <= 2 ? size(A.factors, 1) : 1) : throw(BoundsError())
240241
CUDA.CuMatrix(A::CuQRPackedQ) = orgqr!(copy(A.factors), A.τ)

0 commit comments

Comments
 (0)