@@ -20,12 +20,15 @@ _copywitheltype(::Type{T}, As...) where {T} = map(A -> copyto!(similar(A, T), A)
20
20
21
21
# matrix division
22
22
23
- const CuMatOrAdj{T} = Union{CuMatrix,
24
- LinearAlgebra. Adjoint{T, <: CuMatrix{T} },
25
- LinearAlgebra. Transpose{T, <: CuMatrix{T} }}
26
- const CuOrAdj{T} = Union{CuVecOrMat,
27
- LinearAlgebra. Adjoint{T, <: CuVecOrMat{T} },
28
- LinearAlgebra. Transpose{T, <: CuVecOrMat{T} }}
23
+ const CuMatOrAdj{T} = Union{StridedCuMatrix,
24
+ LinearAlgebra. Adjoint{T, <: StridedCuMatrix{T} },
25
+ LinearAlgebra. Transpose{T, <: StridedCuMatrix{T} }}
26
+ const CuOrAdj{T} = Union{StridedCuVector,
27
+ LinearAlgebra. Adjoint{T, <: StridedCuVector{T} },
28
+ LinearAlgebra. Transpose{T, <: StridedCuVector{T} },
29
+ StridedCuMatrix,
30
+ LinearAlgebra. Adjoint{T, <: StridedCuMatrix{T} },
31
+ LinearAlgebra. Transpose{T, <: StridedCuMatrix{T} }}
29
32
30
33
function Base.:\ (_A:: CuMatOrAdj , _B:: CuOrAdj )
31
34
A, B = copy_cublasfloat (_A, _B)
@@ -101,31 +104,34 @@ using LinearAlgebra: Factorization, AbstractQ, QRCompactWY, QRCompactWYQ, QRPack
101
104
102
105
if VERSION >= v " 1.8-"
103
106
107
+
108
+
104
109
LinearAlgebra. qr! (A:: StridedCuMatrix{T} ) where T = QR (geqrf! (A:: StridedCuMatrix{T} )... )
105
110
111
+
106
112
# conversions
107
113
CuMatrix (F:: Union{QR,QRCompactWY} ) = CuArray (AbstractArray (F))
108
114
CuArray (F:: Union{QR,QRCompactWY} ) = CuMatrix (F)
109
115
CuMatrix (F:: QRPivoted ) = CuArray (AbstractArray (F))
110
116
CuArray (F:: QRPivoted ) = CuMatrix (F)
111
117
112
- function LinearAlgebra. ldiv! (_qr:: QR , b:: CuVector )
118
+ function LinearAlgebra. ldiv! (_qr:: QR , b:: StridedCuVector )
113
119
m,n = size (_qr)
114
120
_x = UpperTriangular (_qr. R[1 : min (m,n), 1 : n]) \ ((_qr. Q' * b)[1 : n])
115
121
b[1 : n] .= _x
116
122
unsafe_free! (_x)
117
123
return b[1 : n]
118
124
end
119
125
120
- function LinearAlgebra. ldiv! (_qr:: QR , B:: CuMatrix )
126
+ function LinearAlgebra. ldiv! (_qr:: QR , B:: StridedCuMatrix )
121
127
m,n = size (_qr)
122
128
_x = UpperTriangular (_qr. R[1 : min (m,n), 1 : n]) \ ((_qr. Q' * B)[1 : n, 1 : size (B, 2 )])
123
129
B[1 : n, 1 : size (B, 2 )] .= _x
124
130
unsafe_free! (_x)
125
131
return B[1 : n, 1 : size (B, 2 )]
126
132
end
127
133
128
- function LinearAlgebra. ldiv! (x:: CuArray , _qr:: QR , b:: CuArray )
134
+ function LinearAlgebra. ldiv! (x:: StridedCuArray , _qr:: QR , b:: StridedCuArray )
129
135
_x = ldiv! (_qr, b)
130
136
x .= vec (_x)
131
137
unsafe_free! (_x)
@@ -146,71 +152,74 @@ CuMatrix{T}(Q::QRCompactWYQ) where {T} = error("QRCompactWY format is not suppor
146
152
Matrix {T} (Q:: QRPackedQ{S,<:CuArray,<:CuArray} ) where {T,S} = Array (CuMatrix {T} (Q))
147
153
Matrix {T} (Q:: QRCompactWYQ{S,<:CuArray,<:CuArray} ) where {T,S} = Array (CuMatrix {T} (Q))
148
154
155
+
156
+
149
157
# extracting the full matrix can be done with `collect` (which defaults to `Array`)
150
- function Base. collect (src:: Union {QRPackedQ{<: Any ,<: CuArray ,<: CuArray },
151
- QRCompactWYQ{<: Any ,<: CuArray ,<: CuArray }})
158
+ function Base. collect (src:: Union {QRPackedQ{<: Any ,<: StridedCuArray ,<: StridedCuArray },
159
+ QRCompactWYQ{<: Any ,<: StridedCuArray ,<: StridedCuArray }})
152
160
dest = similar (src)
153
161
copyto! (dest, I)
154
162
lmul! (src, dest)
155
163
collect (dest)
156
164
end
157
165
158
166
# avoid the generic similar fallback that returns a CPU array
159
- Base. similar (Q:: Union {QRPackedQ{<: Any ,<: CuArray ,<: CuArray },
160
- QRCompactWYQ{<: Any ,<: CuArray ,<: CuArray }},
167
+ Base. similar (Q:: Union {QRPackedQ{<: Any ,<: StridedCuArray ,<: StridedCuArray },
168
+ QRCompactWYQ{<: Any ,<: StridedCuArray ,<: StridedCuArray }},
161
169
:: Type{T} , dims:: Dims{N} ) where {T,N} =
162
170
CuArray {T,N} (undef, dims)
163
171
164
- function Base. getindex (Q:: QRPackedQ{<:Any, <:CuArray } , :: Colon , j:: Int )
172
+ function Base. getindex (Q:: QRPackedQ{<:Any, <:StridedCuArray } , :: Colon , j:: Int )
165
173
y = CUDA. zeros (eltype (Q), size (Q, 2 ))
166
174
y[j] = 1
167
175
lmul! (Q, y)
168
176
end
169
177
178
+
170
179
# multiplication by Q
171
- LinearAlgebra. lmul! (A:: QRPackedQ{T,<:CuArray ,<:CuArray } ,
180
+ LinearAlgebra. lmul! (A:: QRPackedQ{T,<:StridedCuArray ,<:StridedCuArray } ,
172
181
B:: CuVecOrMat{T} ) where {T<: BlasFloat } =
173
182
ormqr! (' L' , ' N' , A. factors, A. τ, B)
174
- LinearAlgebra. lmul! (adjA:: Adjoint{T,<:QRPackedQ{T,<:CuArray ,<:CuArray }} ,
183
+ LinearAlgebra. lmul! (adjA:: Adjoint{T,<:QRPackedQ{T,<:StridedCuArray ,<:StridedCuArray }} ,
175
184
B:: CuVecOrMat{T} ) where {T<: BlasReal } =
176
185
ormqr! (' L' , ' T' , parent (adjA). factors, parent (adjA). τ, B)
177
- LinearAlgebra. lmul! (adjA:: Adjoint{T,<:QRPackedQ{T,<:CuArray ,<:CuArray }} ,
186
+ LinearAlgebra. lmul! (adjA:: Adjoint{T,<:QRPackedQ{T,<:StridedCuArray ,<:StridedCuArray }} ,
178
187
B:: CuVecOrMat{T} ) where {T<: BlasComplex } =
179
188
ormqr! (' L' , ' C' , parent (adjA). factors, parent (adjA). τ, B)
180
- LinearAlgebra. lmul! (trA:: Transpose{T,<:QRPackedQ{T,<:CuArray ,<:CuArray }} ,
189
+ LinearAlgebra. lmul! (trA:: Transpose{T,<:QRPackedQ{T,<:StridedCuArray ,<:StridedCuArray }} ,
181
190
B:: CuVecOrMat{T} ) where {T<: BlasFloat } =
182
191
ormqr! (' L' , ' T' , parent (trA). factors, parent (trA). τ, B)
183
192
184
193
LinearAlgebra. rmul! (A:: CuVecOrMat{T} ,
185
- B:: QRPackedQ{T,<:CuArray ,<:CuArray } ) where {T<: BlasFloat } =
194
+ B:: QRPackedQ{T,<:StridedCuArray ,<:StridedCuArray } ) where {T<: BlasFloat } =
186
195
ormqr! (' R' , ' N' , B. factors, B. τ, A)
187
196
LinearAlgebra. rmul! (A:: CuVecOrMat{T} ,
188
- adjB:: Adjoint{<:Any,<:QRPackedQ{T,<:CuArray ,<:CuArray }} ) where {T<: BlasReal } =
197
+ adjB:: Adjoint{<:Any,<:QRPackedQ{T,<:StridedCuArray ,<:StridedCuArray }} ) where {T<: BlasReal } =
189
198
ormqr! (' R' , ' T' , parent (adjB). factors, parent (adjB). τ, A)
190
199
LinearAlgebra. rmul! (A:: CuVecOrMat{T} ,
191
- adjB:: Adjoint{<:Any,<:QRPackedQ{T,<:CuArray ,<:CuArray }} ) where {T<: BlasComplex } =
200
+ adjB:: Adjoint{<:Any,<:QRPackedQ{T,<:StridedCuArray ,<:StridedCuArray }} ) where {T<: BlasComplex } =
192
201
ormqr! (' R' , ' C' , parent (adjB). factors, parent (adjB). τ, A)
193
202
LinearAlgebra. rmul! (A:: CuVecOrMat{T} ,
194
- trA:: Transpose{<:Any,<:QRPackedQ{T,<:CuArray ,<:CuArray }} ) where {T<: BlasFloat } =
203
+ trA:: Transpose{<:Any,<:QRPackedQ{T,<:StridedCuArray ,<:StridedCuArray }} ) where {T<: BlasFloat } =
195
204
ormqr! (' R' , ' T' , parent (trA). factors, parent (adjB). τ, A)
196
205
197
206
else
198
207
199
208
struct CuQR{T} <: Factorization{T}
200
- factors:: CuMatrix
201
- τ:: CuVector {T}
202
- CuQR {T} (factors:: CuMatrix {T} , τ:: CuVector {T} ) where {T} = new (factors, τ)
209
+ factors:: StridedCuMatrix
210
+ τ:: StridedCuVector {T}
211
+ CuQR {T} (factors:: StridedCuMatrix {T} , τ:: StridedCuVector {T} ) where {T} = new (factors, τ)
203
212
end
204
213
205
214
struct CuQRPackedQ{T} <: AbstractQ{T}
206
- factors:: CuMatrix {T}
207
- τ:: CuVector {T}
208
- CuQRPackedQ {T} (factors:: CuMatrix {T} , τ:: CuVector {T} ) where {T} = new (factors, τ)
215
+ factors:: StridedCuMatrix {T}
216
+ τ:: StridedCuVector {T}
217
+ CuQRPackedQ {T} (factors:: StridedCuMatrix {T} , τ:: StridedCuVector {T} ) where {T} = new (factors, τ)
209
218
end
210
219
211
- CuQR (factors:: CuMatrix {T} , τ:: CuVector {T} ) where {T} =
220
+ CuQR (factors:: StridedCuMatrix {T} , τ:: StridedCuVector {T} ) where {T} =
212
221
CuQR {T} (factors, τ)
213
- CuQRPackedQ (factors:: CuMatrix {T} , τ:: CuVector {T} ) where {T} =
222
+ CuQRPackedQ (factors:: StridedCuMatrix {T} , τ:: StridedCuVector {T} ) where {T} =
214
223
CuQRPackedQ {T} (factors, τ)
215
224
216
225
# AbstractQ's `size` is the size of the full matrix,
@@ -245,7 +254,7 @@ Base.Matrix(A::CuQRPackedQ) = Matrix(CuMatrix(A))
245
254
function Base. getproperty (A:: CuQR , d:: Symbol )
246
255
m, n = size (getfield (A, :factors ))
247
256
if d == :R
248
- return triu! (A. factors[ 1 : min (m, n), 1 : n] )
257
+ return triu! (view ( A. factors, 1 : min (m, n), 1 : n) )
249
258
elseif d == :Q
250
259
return CuQRPackedQ (A. factors, A. τ)
251
260
else
@@ -259,25 +268,25 @@ Base.iterate(S::CuQR, ::Val{:R}) = (S.R, Val(:done))
259
268
Base. iterate (S:: CuQR , :: Val{:done} ) = nothing
260
269
261
270
# Apply changes Q from the left
262
- LinearAlgebra. lmul! (A:: CuQRPackedQ{T} , B:: CuVecOrMat {T} ) where {T<: BlasFloat } =
271
+ LinearAlgebra. lmul! (A:: CuQRPackedQ{T} , B:: StridedCuVecOrMat {T} ) where {T<: BlasFloat } =
263
272
ormqr! (' L' , ' N' , A. factors, A. τ, B)
264
- LinearAlgebra. lmul! (adjA:: Adjoint{T,<:CuQRPackedQ{T}} , B:: CuVecOrMat {T} ) where {T<: BlasReal } =
273
+ LinearAlgebra. lmul! (adjA:: Adjoint{T,<:CuQRPackedQ{T}} , B:: StridedCuVecOrMat {T} ) where {T<: BlasReal } =
265
274
ormqr! (' L' , ' T' , parent (adjA). factors, parent (adjA). τ, B)
266
- LinearAlgebra. lmul! (adjA:: Adjoint{T,<:CuQRPackedQ{T}} , B:: CuVecOrMat {T} ) where {T<: BlasComplex } =
275
+ LinearAlgebra. lmul! (adjA:: Adjoint{T,<:CuQRPackedQ{T}} , B:: StridedCuVecOrMat {T} ) where {T<: BlasComplex } =
267
276
ormqr! (' L' , ' C' , parent (adjA). factors, parent (adjA). τ, B)
268
- LinearAlgebra. lmul! (trA:: Transpose{T,<:CuQRPackedQ{T}} , B:: CuVecOrMat {T} ) where {T<: BlasFloat } =
277
+ LinearAlgebra. lmul! (trA:: Transpose{T,<:CuQRPackedQ{T}} , B:: StridedCuVecOrMat {T} ) where {T<: BlasFloat } =
269
278
ormqr! (' L' , ' T' , parent (trA). factors, parent (trA). τ, B)
270
279
271
280
# Apply changes Q from the right
272
- LinearAlgebra. rmul! (A:: CuVecOrMat {T} , B:: CuQRPackedQ{T} ) where {T<: BlasFloat } =
281
+ LinearAlgebra. rmul! (A:: StridedCuVecOrMat {T} , B:: CuQRPackedQ{T} ) where {T<: BlasFloat } =
273
282
ormqr! (' R' , ' N' , B. factors, B. τ, A)
274
- LinearAlgebra. rmul! (A:: CuVecOrMat {T} ,
283
+ LinearAlgebra. rmul! (A:: StridedCuVecOrMat {T} ,
275
284
adjB:: Adjoint{<:Any,<:CuQRPackedQ{T}} ) where {T<: BlasReal } =
276
285
ormqr! (' R' , ' T' , parent (adjB). factors, parent (adjB). τ, A)
277
- LinearAlgebra. rmul! (A:: CuVecOrMat {T} ,
286
+ LinearAlgebra. rmul! (A:: StridedCuVecOrMat {T} ,
278
287
adjB:: Adjoint{<:Any,<:CuQRPackedQ{T}} ) where {T<: BlasComplex } =
279
288
ormqr! (' R' , ' C' , parent (adjB). factors, parent (adjB). τ, A)
280
- LinearAlgebra. rmul! (A:: CuVecOrMat {T} ,
289
+ LinearAlgebra. rmul! (A:: StridedCuVecOrMat {T} ,
281
290
trA:: Transpose{<:Any,<:CuQRPackedQ{T}} ) where {T<: BlasFloat } =
282
291
ormqr! (' R' , ' T' , parent (trA). factors, parent (adjB). τ, A)
283
292
@@ -300,23 +309,23 @@ end
300
309
LinearAlgebra. det (Q:: CuQRPackedQ{<:Real} ) = isodd (count (! iszero, Q. τ)) ? - 1 : 1
301
310
LinearAlgebra. det (Q:: CuQRPackedQ ) = prod (τ -> iszero (τ) ? one (τ) : - sign (τ)^ 2 , Q. τ)
302
311
303
- function LinearAlgebra. ldiv! (_qr:: CuQR , b:: CuVector )
312
+ function LinearAlgebra. ldiv! (_qr:: CuQR , b:: StridedCuVector )
304
313
m,n = size (_qr)
305
314
_x = UpperTriangular (_qr. R[1 : min (m,n), 1 : n]) \ ((_qr. Q' * b)[1 : n])
306
315
b[1 : n] .= _x
307
316
unsafe_free! (_x)
308
317
return b[1 : n]
309
318
end
310
319
311
- function LinearAlgebra. ldiv! (_qr:: CuQR , B:: CuMatrix )
320
+ function LinearAlgebra. ldiv! (_qr:: CuQR , B:: StridedCuMatrix )
312
321
m,n = size (_qr)
313
322
_x = UpperTriangular (_qr. R[1 : min (m,n), 1 : n]) \ ((_qr. Q' * B)[1 : n, 1 : size (B, 2 )])
314
323
B[1 : n, 1 : size (B, 2 )] .= _x
315
324
unsafe_free! (_x)
316
325
return B[1 : n, 1 : size (B, 2 )]
317
326
end
318
327
319
- function LinearAlgebra. ldiv! (x:: CuArray ,_qr:: CuQR , b:: CuArray )
328
+ function LinearAlgebra. ldiv! (x:: StridedCuArray ,_qr:: CuQR , b:: StridedCuArray )
320
329
_x = ldiv! (_qr, b)
321
330
x .= vec (_x)
322
331
unsafe_free! (_x)
0 commit comments