Skip to content

Commit d9ca07a

Browse files
Adding support for in place QR of views in LinearAlgebra
1 parent 455c49b commit d9ca07a

File tree

1 file changed

+51
-42
lines changed

1 file changed

+51
-42
lines changed

lib/cusolver/linalg.jl

Lines changed: 51 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,15 @@ _copywitheltype(::Type{T}, As...) where {T} = map(A -> copyto!(similar(A, T), A)
2020

2121
# matrix division
2222

23-
const CuMatOrAdj{T} = Union{CuMatrix,
24-
LinearAlgebra.Adjoint{T, <:CuMatrix{T}},
25-
LinearAlgebra.Transpose{T, <:CuMatrix{T}}}
26-
const CuOrAdj{T} = Union{CuVecOrMat,
27-
LinearAlgebra.Adjoint{T, <:CuVecOrMat{T}},
28-
LinearAlgebra.Transpose{T, <:CuVecOrMat{T}}}
23+
const CuMatOrAdj{T} = Union{StridedCuMatrix,
24+
LinearAlgebra.Adjoint{T, <:StridedCuMatrix{T}},
25+
LinearAlgebra.Transpose{T, <:StridedCuMatrix{T}}}
26+
const CuOrAdj{T} = Union{StridedCuVector,
27+
LinearAlgebra.Adjoint{T, <:StridedCuVector{T}},
28+
LinearAlgebra.Transpose{T, <:StridedCuVector{T}},
29+
StridedCuMatrix,
30+
LinearAlgebra.Adjoint{T, <:StridedCuMatrix{T}},
31+
LinearAlgebra.Transpose{T, <:StridedCuMatrix{T}}}
2932

3033
function Base.:\(_A::CuMatOrAdj, _B::CuOrAdj)
3134
A, B = copy_cublasfloat(_A, _B)
@@ -101,31 +104,34 @@ using LinearAlgebra: Factorization, AbstractQ, QRCompactWY, QRCompactWYQ, QRPack
101104

102105
if VERSION >= v"1.8-"
103106

107+
108+
104109
LinearAlgebra.qr!(A::StridedCuMatrix{T}) where T = QR(geqrf!(A::StridedCuMatrix{T})...)
105110

111+
106112
# conversions
107113
CuMatrix(F::Union{QR,QRCompactWY}) = CuArray(AbstractArray(F))
108114
CuArray(F::Union{QR,QRCompactWY}) = CuMatrix(F)
109115
CuMatrix(F::QRPivoted) = CuArray(AbstractArray(F))
110116
CuArray(F::QRPivoted) = CuMatrix(F)
111117

112-
function LinearAlgebra.ldiv!(_qr::QR, b::CuVector)
118+
function LinearAlgebra.ldiv!(_qr::QR, b::StridedCuVector)
113119
m,n = size(_qr)
114120
_x = UpperTriangular(_qr.R[1:min(m,n), 1:n]) \ ((_qr.Q' * b)[1:n])
115121
b[1:n] .= _x
116122
unsafe_free!(_x)
117123
return b[1:n]
118124
end
119125

120-
function LinearAlgebra.ldiv!(_qr::QR, B::CuMatrix)
126+
function LinearAlgebra.ldiv!(_qr::QR, B::StridedCuMatrix)
121127
m,n = size(_qr)
122128
_x = UpperTriangular(_qr.R[1:min(m,n), 1:n]) \ ((_qr.Q' * B)[1:n, 1:size(B, 2)])
123129
B[1:n, 1:size(B, 2)] .= _x
124130
unsafe_free!(_x)
125131
return B[1:n, 1:size(B, 2)]
126132
end
127133

128-
function LinearAlgebra.ldiv!(x::CuArray, _qr::QR, b::CuArray)
134+
function LinearAlgebra.ldiv!(x::StridedCuArray, _qr::QR, b::StridedCuArray)
129135
_x = ldiv!(_qr, b)
130136
x .= vec(_x)
131137
unsafe_free!(_x)
@@ -146,71 +152,74 @@ CuMatrix{T}(Q::QRCompactWYQ) where {T} = error("QRCompactWY format is not suppor
146152
Matrix{T}(Q::QRPackedQ{S,<:CuArray,<:CuArray}) where {T,S} = Array(CuMatrix{T}(Q))
147153
Matrix{T}(Q::QRCompactWYQ{S,<:CuArray,<:CuArray}) where {T,S} = Array(CuMatrix{T}(Q))
148154

155+
156+
149157
# extracting the full matrix can be done with `collect` (which defaults to `Array`)
150-
function Base.collect(src::Union{QRPackedQ{<:Any,<:CuArray,<:CuArray},
151-
QRCompactWYQ{<:Any,<:CuArray,<:CuArray}})
158+
function Base.collect(src::Union{QRPackedQ{<:Any,<:StridedCuArray,<:StridedCuArray},
159+
QRCompactWYQ{<:Any,<:StridedCuArray,<:StridedCuArray}})
152160
dest = similar(src)
153161
copyto!(dest, I)
154162
lmul!(src, dest)
155163
collect(dest)
156164
end
157165

158166
# avoid the generic similar fallback that returns a CPU array
159-
Base.similar(Q::Union{QRPackedQ{<:Any,<:CuArray,<:CuArray},
160-
QRCompactWYQ{<:Any,<:CuArray,<:CuArray}},
167+
Base.similar(Q::Union{QRPackedQ{<:Any,<:StridedCuArray,<:StridedCuArray},
168+
QRCompactWYQ{<:Any,<:StridedCuArray,<:StridedCuArray}},
161169
::Type{T}, dims::Dims{N}) where {T,N} =
162170
CuArray{T,N}(undef, dims)
163171

164-
function Base.getindex(Q::QRPackedQ{<:Any, <:CuArray}, ::Colon, j::Int)
172+
function Base.getindex(Q::QRPackedQ{<:Any, <:StridedCuArray}, ::Colon, j::Int)
165173
y = CUDA.zeros(eltype(Q), size(Q, 2))
166174
y[j] = 1
167175
lmul!(Q, y)
168176
end
169177

178+
170179
# multiplication by Q
171-
LinearAlgebra.lmul!(A::QRPackedQ{T,<:CuArray,<:CuArray},
180+
LinearAlgebra.lmul!(A::QRPackedQ{T,<:StridedCuArray,<:StridedCuArray},
172181
B::CuVecOrMat{T}) where {T<:BlasFloat} =
173182
ormqr!('L', 'N', A.factors, A.τ, B)
174-
LinearAlgebra.lmul!(adjA::Adjoint{T,<:QRPackedQ{T,<:CuArray,<:CuArray}},
183+
LinearAlgebra.lmul!(adjA::Adjoint{T,<:QRPackedQ{T,<:StridedCuArray,<:StridedCuArray}},
175184
B::CuVecOrMat{T}) where {T<:BlasReal} =
176185
ormqr!('L', 'T', parent(adjA).factors, parent(adjA).τ, B)
177-
LinearAlgebra.lmul!(adjA::Adjoint{T,<:QRPackedQ{T,<:CuArray,<:CuArray}},
186+
LinearAlgebra.lmul!(adjA::Adjoint{T,<:QRPackedQ{T,<:StridedCuArray,<:StridedCuArray}},
178187
B::CuVecOrMat{T}) where {T<:BlasComplex} =
179188
ormqr!('L', 'C', parent(adjA).factors, parent(adjA).τ, B)
180-
LinearAlgebra.lmul!(trA::Transpose{T,<:QRPackedQ{T,<:CuArray,<:CuArray}},
189+
LinearAlgebra.lmul!(trA::Transpose{T,<:QRPackedQ{T,<:StridedCuArray,<:StridedCuArray}},
181190
B::CuVecOrMat{T}) where {T<:BlasFloat} =
182191
ormqr!('L', 'T', parent(trA).factors, parent(trA).τ, B)
183192

184193
LinearAlgebra.rmul!(A::CuVecOrMat{T},
185-
B::QRPackedQ{T,<:CuArray,<:CuArray}) where {T<:BlasFloat} =
194+
B::QRPackedQ{T,<:StridedCuArray,<:StridedCuArray}) where {T<:BlasFloat} =
186195
ormqr!('R', 'N', B.factors, B.τ, A)
187196
LinearAlgebra.rmul!(A::CuVecOrMat{T},
188-
adjB::Adjoint{<:Any,<:QRPackedQ{T,<:CuArray,<:CuArray}}) where {T<:BlasReal} =
197+
adjB::Adjoint{<:Any,<:QRPackedQ{T,<:StridedCuArray,<:StridedCuArray}}) where {T<:BlasReal} =
189198
ormqr!('R', 'T', parent(adjB).factors, parent(adjB).τ, A)
190199
LinearAlgebra.rmul!(A::CuVecOrMat{T},
191-
adjB::Adjoint{<:Any,<:QRPackedQ{T,<:CuArray,<:CuArray}}) where {T<:BlasComplex} =
200+
adjB::Adjoint{<:Any,<:QRPackedQ{T,<:StridedCuArray,<:StridedCuArray}}) where {T<:BlasComplex} =
192201
ormqr!('R', 'C', parent(adjB).factors, parent(adjB).τ, A)
193202
LinearAlgebra.rmul!(A::CuVecOrMat{T},
194-
trA::Transpose{<:Any,<:QRPackedQ{T,<:CuArray,<:CuArray}}) where {T<:BlasFloat} =
203+
trA::Transpose{<:Any,<:QRPackedQ{T,<:StridedCuArray,<:StridedCuArray}}) where {T<:BlasFloat} =
195204
ormqr!('R', 'T', parent(trA).factors, parent(adjB).τ, A)
196205

197206
else
198207

199208
struct CuQR{T} <: Factorization{T}
200-
factors::CuMatrix
201-
τ::CuVector{T}
202-
CuQR{T}(factors::CuMatrix{T}, τ::CuVector{T}) where {T} = new(factors, τ)
209+
factors::StridedCuMatrix
210+
τ::StridedCuVector{T}
211+
CuQR{T}(factors::StridedCuMatrix{T}, τ::StridedCuVector{T}) where {T} = new(factors, τ)
203212
end
204213

205214
struct CuQRPackedQ{T} <: AbstractQ{T}
206-
factors::CuMatrix{T}
207-
τ::CuVector{T}
208-
CuQRPackedQ{T}(factors::CuMatrix{T}, τ::CuVector{T}) where {T} = new(factors, τ)
215+
factors::StridedCuMatrix{T}
216+
τ::StridedCuVector{T}
217+
CuQRPackedQ{T}(factors::StridedCuMatrix{T}, τ::StridedCuVector{T}) where {T} = new(factors, τ)
209218
end
210219

211-
CuQR(factors::CuMatrix{T}, τ::CuVector{T}) where {T} =
220+
CuQR(factors::StridedCuMatrix{T}, τ::StridedCuVector{T}) where {T} =
212221
CuQR{T}(factors, τ)
213-
CuQRPackedQ(factors::CuMatrix{T}, τ::CuVector{T}) where {T} =
222+
CuQRPackedQ(factors::StridedCuMatrix{T}, τ::StridedCuVector{T}) where {T} =
214223
CuQRPackedQ{T}(factors, τ)
215224

216225
# AbstractQ's `size` is the size of the full matrix,
@@ -245,7 +254,7 @@ Base.Matrix(A::CuQRPackedQ) = Matrix(CuMatrix(A))
245254
function Base.getproperty(A::CuQR, d::Symbol)
246255
m, n = size(getfield(A, :factors))
247256
if d == :R
248-
return triu!(A.factors[1:min(m, n), 1:n])
257+
return triu!(view(A.factors,1:min(m, n), 1:n))
249258
elseif d == :Q
250259
return CuQRPackedQ(A.factors, A.τ)
251260
else
@@ -259,25 +268,25 @@ Base.iterate(S::CuQR, ::Val{:R}) = (S.R, Val(:done))
259268
Base.iterate(S::CuQR, ::Val{:done}) = nothing
260269

261270
# Apply changes Q from the left
262-
LinearAlgebra.lmul!(A::CuQRPackedQ{T}, B::CuVecOrMat{T}) where {T<:BlasFloat} =
271+
LinearAlgebra.lmul!(A::CuQRPackedQ{T}, B::StridedCuVecOrMat{T}) where {T<:BlasFloat} =
263272
ormqr!('L', 'N', A.factors, A.τ, B)
264-
LinearAlgebra.lmul!(adjA::Adjoint{T,<:CuQRPackedQ{T}}, B::CuVecOrMat{T}) where {T<:BlasReal} =
273+
LinearAlgebra.lmul!(adjA::Adjoint{T,<:CuQRPackedQ{T}}, B::StridedCuVecOrMat{T}) where {T<:BlasReal} =
265274
ormqr!('L', 'T', parent(adjA).factors, parent(adjA).τ, B)
266-
LinearAlgebra.lmul!(adjA::Adjoint{T,<:CuQRPackedQ{T}}, B::CuVecOrMat{T}) where {T<:BlasComplex} =
275+
LinearAlgebra.lmul!(adjA::Adjoint{T,<:CuQRPackedQ{T}}, B::StridedCuVecOrMat{T}) where {T<:BlasComplex} =
267276
ormqr!('L', 'C', parent(adjA).factors, parent(adjA).τ, B)
268-
LinearAlgebra.lmul!(trA::Transpose{T,<:CuQRPackedQ{T}}, B::CuVecOrMat{T}) where {T<:BlasFloat} =
277+
LinearAlgebra.lmul!(trA::Transpose{T,<:CuQRPackedQ{T}}, B::StridedCuVecOrMat{T}) where {T<:BlasFloat} =
269278
ormqr!('L', 'T', parent(trA).factors, parent(trA).τ, B)
270279

271280
# Apply changes Q from the right
272-
LinearAlgebra.rmul!(A::CuVecOrMat{T}, B::CuQRPackedQ{T}) where {T<:BlasFloat} =
281+
LinearAlgebra.rmul!(A::StridedCuVecOrMat{T}, B::CuQRPackedQ{T}) where {T<:BlasFloat} =
273282
ormqr!('R', 'N', B.factors, B.τ, A)
274-
LinearAlgebra.rmul!(A::CuVecOrMat{T},
283+
LinearAlgebra.rmul!(A::StridedCuVecOrMat{T},
275284
adjB::Adjoint{<:Any,<:CuQRPackedQ{T}}) where {T<:BlasReal} =
276285
ormqr!('R', 'T', parent(adjB).factors, parent(adjB).τ, A)
277-
LinearAlgebra.rmul!(A::CuVecOrMat{T},
286+
LinearAlgebra.rmul!(A::StridedCuVecOrMat{T},
278287
adjB::Adjoint{<:Any,<:CuQRPackedQ{T}}) where {T<:BlasComplex} =
279288
ormqr!('R', 'C', parent(adjB).factors, parent(adjB).τ, A)
280-
LinearAlgebra.rmul!(A::CuVecOrMat{T},
289+
LinearAlgebra.rmul!(A::StridedCuVecOrMat{T},
281290
trA::Transpose{<:Any,<:CuQRPackedQ{T}}) where {T<:BlasFloat} =
282291
ormqr!('R', 'T', parent(trA).factors, parent(adjB).τ, A)
283292

@@ -300,23 +309,23 @@ end
300309
LinearAlgebra.det(Q::CuQRPackedQ{<:Real}) = isodd(count(!iszero, Q.τ)) ? -1 : 1
301310
LinearAlgebra.det(Q::CuQRPackedQ) = prod-> iszero(τ) ? one(τ) : -sign(τ)^2, Q.τ)
302311

303-
function LinearAlgebra.ldiv!(_qr::CuQR, b::CuVector)
312+
function LinearAlgebra.ldiv!(_qr::CuQR, b::StridedCuVector)
304313
m,n = size(_qr)
305314
_x = UpperTriangular(_qr.R[1:min(m,n), 1:n]) \ ((_qr.Q' * b)[1:n])
306315
b[1:n] .= _x
307316
unsafe_free!(_x)
308317
return b[1:n]
309318
end
310319

311-
function LinearAlgebra.ldiv!(_qr::CuQR, B::CuMatrix)
320+
function LinearAlgebra.ldiv!(_qr::CuQR, B::StridedCuMatrix)
312321
m,n = size(_qr)
313322
_x = UpperTriangular(_qr.R[1:min(m,n), 1:n]) \ ((_qr.Q' * B)[1:n, 1:size(B, 2)])
314323
B[1:n, 1:size(B, 2)] .= _x
315324
unsafe_free!(_x)
316325
return B[1:n, 1:size(B, 2)]
317326
end
318327

319-
function LinearAlgebra.ldiv!(x::CuArray,_qr::CuQR, b::CuArray)
328+
function LinearAlgebra.ldiv!(x::StridedCuArray,_qr::CuQR, b::StridedCuArray)
320329
_x = ldiv!(_qr, b)
321330
x .= vec(_x)
322331
unsafe_free!(_x)

0 commit comments

Comments
 (0)