Skip to content

Commit 0e0dc5d

Browse files
Ralph Kubemaleadt
authored andcommitted
Fixed reducedim operations for CuQRPackedQ
1 parent afe8179 commit 0e0dc5d

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

lib/cusolver/linalg.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,14 @@ Base.size(A::CuQRPackedQ, dim::Integer) = 0 < dim ? (dim <= 2 ? size(A.factors,
2626
CUDA.CuMatrix(A::CuQRPackedQ) = orgqr!(copy(A.factors), A.τ)
2727
CUDA.CuArray(A::CuQRPackedQ) = CuMatrix(A)
2828
Base.Matrix(A::CuQRPackedQ) = Matrix(CuMatrix(A))
29+
Base.similar(A::CuQRPackedQ{S,T}) where {S,T} = CuArray{S, ndims(a)}(undef, size(a))
30+
31+
# CuQRPackedQ isn't an AbstractGPUArray, so we miss out on lots of GPU-compatible methods.
32+
# add a couple common methods here that avoid GPU-incompatible fallbacks.
33+
Base.sum(xs::CuQRPackedQ) = sum(CuArray(xs))
34+
Base.prod(xs::CuQRPackedQ) = prod(CuArray(xs))
35+
Base.maximum(xs::CuQRPackedQ) = maximum(CuArray(xs))
36+
Base.minimum(xs::CuQRPackedQ) = minimum(CuArray(xs))
2937

3038
function Base.getproperty(A::CuQR, d::Symbol)
3139
m, n = size(getfield(A, :factors))

0 commit comments

Comments
 (0)