diff --git a/src/solver/linalg.jl b/src/solver/linalg.jl index e191bc9f..f06c4cc2 100644 --- a/src/solver/linalg.jl +++ b/src/solver/linalg.jl @@ -53,6 +53,12 @@ LinearAlgebra.lmul!(adjA::Adjoint{T,<:CuQRPackedQ{T,S}}, B::CuVecOrMat{T}) where LinearAlgebra.lmul!(trA::Transpose{T,<:CuQRPackedQ{T,S}}, B::CuVecOrMat{T}) where {T<:Number, S<:CuMatrix} = ormqr!('L', 'T', parent(trA).factors, parent(trA).τ, B) +function LinearAlgebra.ldiv!(x::CuArrays.CuArray,_qr::CuArrays.CUSOLVER.CuQR,b::CuArrays.CuArray) + _x = UpperTriangular(_qr.R) \ (_qr.Q' * reshape(b,length(b),1)) + x .= vec(_x) + CuArrays.unsafe_free!(_x) +end + function Base.getindex(A::CuQRPackedQ{T, S}, i::Integer, j::Integer) where {T, S} x = CuArrays.zeros(T, size(A, 2)) x[j] = 1 diff --git a/test/solver.jl b/test/solver.jl index 8a1b956e..9cefb248 100644 --- a/test/solver.jl +++ b/test/solver.jl @@ -131,6 +131,20 @@ k = 1 F = qr!(A) @test h_B ≈ B*Array(F.Q) end + + @testset "ldiv!" begin + A = rand(elty, n, n) + d_A = CuArray(A) + B = rand(elty, n, n) + d_B = CuArray(B) + C = rand(elty, n, n) + d_C = CuArray(C) + F = qr!(A) + d_F = qr!(d_A) + ldiv!(C,F,B) + ldiv!(d_C,d_F,d_B) + @test C ≈ Array(d_C) + end @testset "orgqr!" begin A = rand(elty,n,m)