From 0e0dc5d2431d4ffc58656814e2d629c150d66fa5 Mon Sep 17 00:00:00 2001 From: Ralph Kube Date: Thu, 26 Aug 2021 11:08:41 -0400 Subject: [PATCH] Fixed reducedim operations for CuQRPackedQ --- lib/cusolver/linalg.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/lib/cusolver/linalg.jl b/lib/cusolver/linalg.jl index 31480337ec..0427cf084b 100644 --- a/lib/cusolver/linalg.jl +++ b/lib/cusolver/linalg.jl @@ -26,6 +26,14 @@ Base.size(A::CuQRPackedQ, dim::Integer) = 0 < dim ? (dim <= 2 ? size(A.factors, CUDA.CuMatrix(A::CuQRPackedQ) = orgqr!(copy(A.factors), A.τ) CUDA.CuArray(A::CuQRPackedQ) = CuMatrix(A) Base.Matrix(A::CuQRPackedQ) = Matrix(CuMatrix(A)) +Base.similar(A::CuQRPackedQ{S,T}) where {S,T} = CuArray{S, ndims(a)}(undef, size(a)) + +# CuQRPackedQ isn't an AbstractGPUArray, so we miss out on lots of GPU-compatible methods. +# add a couple common methods here that avoid GPU-incompatible fallbacks. +Base.sum(xs::CuQRPackedQ) = sum(CuArray(xs)) +Base.prod(xs::CuQRPackedQ) = prod(CuArray(xs)) +Base.maximum(xs::CuQRPackedQ) = maximum(CuArray(xs)) +Base.minimum(xs::CuQRPackedQ) = minimum(CuArray(xs)) function Base.getproperty(A::CuQR, d::Symbol) m, n = size(getfield(A, :factors))