From eb2fb9c3ba473b594fc36b2f764ecc5b74879dc1 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Fri, 31 Jan 2020 10:32:50 -0500 Subject: [PATCH 1/2] Upstream ldiv! overload DiffEqBase.jl has been carrying an ldiv! overload to make it work for awhile (https://github.com/JuliaDiffEq/DiffEqBase.jl/blob/master/src/init.jl#L148-L152), and I think it might be a good time to upstream it. --- src/solver/linalg.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/solver/linalg.jl b/src/solver/linalg.jl index e191bc9f..f06c4cc2 100644 --- a/src/solver/linalg.jl +++ b/src/solver/linalg.jl @@ -53,6 +53,12 @@ LinearAlgebra.lmul!(adjA::Adjoint{T,<:CuQRPackedQ{T,S}}, B::CuVecOrMat{T}) where LinearAlgebra.lmul!(trA::Transpose{T,<:CuQRPackedQ{T,S}}, B::CuVecOrMat{T}) where {T<:Number, S<:CuMatrix} = ormqr!('L', 'T', parent(trA).factors, parent(trA).τ, B) +function LinearAlgebra.ldiv!(x::CuArrays.CuArray,_qr::CuArrays.CUSOLVER.CuQR,b::CuArrays.CuArray) + _x = UpperTriangular(_qr.R) \ (_qr.Q' * reshape(b,length(b),1)) + x .= vec(_x) + CuArrays.unsafe_free!(_x) +end + function Base.getindex(A::CuQRPackedQ{T, S}, i::Integer, j::Integer) where {T, S} x = CuArrays.zeros(T, size(A, 2)) x[j] = 1 From bad8412d4f1942f5cf74753eb7169a07b86b742d Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Fri, 31 Jan 2020 21:46:09 -0500 Subject: [PATCH 2/2] add ldiv! test --- test/solver.jl | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/test/solver.jl b/test/solver.jl index 8a1b956e..9cefb248 100644 --- a/test/solver.jl +++ b/test/solver.jl @@ -131,6 +131,20 @@ k = 1 F = qr!(A) @test h_B ≈ B*Array(F.Q) end + + @testset "ldiv!" begin + A = rand(elty, n, n) + d_A = CuArray(A) + B = rand(elty, n, n) + d_B = CuArray(B) + C = rand(elty, n, n) + d_C = CuArray(C) + F = qr!(A) + d_F = qr!(d_A) + ldiv!(C,F,B) + ldiv!(d_C,d_F,d_B) + @test C ≈ Array(d_C) + end @testset "orgqr!" begin A = rand(elty,n,m)