Skip to content

Commit 613166c

Browse files
Ralph Kubemaleadt
authored andcommitted
Fixed reducedim operations for CuQRPackedQ
1 parent afe8179 commit 613166c

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

lib/cusolver/linalg.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,14 @@ CUDA.CuMatrix(A::CuQRPackedQ) = orgqr!(copy(A.factors), A.τ)
2727
CUDA.CuArray(A::CuQRPackedQ) = CuMatrix(A)
2828
Base.Matrix(A::CuQRPackedQ) = Matrix(CuMatrix(A))
2929

30+
# CuQRPackedQ isn't an AbstractGPUArray, so we miss out on lots of GPU-compatible methods.
31+
# add a couple common methods here that avoid GPU-incompatible fallbacks.
32+
Base.similar(A::CuQRPackedQ{S,T}) where {S,T} = CuArray{S, ndims(a)}(undef, size(a))
33+
Base.sum(xs::CuQRPackedQ) = sum(CuArray(xs))
34+
Base.prod(xs::CuQRPackedQ) = prod(CuArray(xs))
35+
Base.maximum(xs::CuQRPackedQ) = maximum(CuArray(xs))
36+
Base.minimum(xs::CuQRPackedQ) = minimum(CuArray(xs))
37+
3038
function Base.getproperty(A::CuQR, d::Symbol)
3139
m, n = size(getfield(A, :factors))
3240
if d == :R

0 commit comments

Comments
 (0)