Skip to content
This repository was archived by the owner on Mar 12, 2021. It is now read-only.

Commit 0922b88

Browse files
Merge #580
580: Upstream ldiv! overload r=maleadt a=ChrisRackauckas DiffEqBase.jl has been carrying an ldiv! overload to make it work for awhile (https://github.com/JuliaDiffEq/DiffEqBase.jl/blob/master/src/init.jl#L148-L152), and I think it might be a good time to upstream it. Co-authored-by: Christopher Rackauckas <accounts@chrisrackauckas.com>
2 parents 2fe60ca + bad8412 commit 0922b88

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

src/solver/linalg.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,12 @@ LinearAlgebra.lmul!(adjA::Adjoint{T,<:CuQRPackedQ{T,S}}, B::CuVecOrMat{T}) where
5353
LinearAlgebra.lmul!(trA::Transpose{T,<:CuQRPackedQ{T,S}}, B::CuVecOrMat{T}) where {T<:Number, S<:CuMatrix} =
5454
ormqr!('L', 'T', parent(trA).factors, parent(trA).τ, B)
5555

56+
function LinearAlgebra.ldiv!(x::CuArrays.CuArray,_qr::CuArrays.CUSOLVER.CuQR,b::CuArrays.CuArray)
57+
_x = UpperTriangular(_qr.R) \ (_qr.Q' * reshape(b,length(b),1))
58+
x .= vec(_x)
59+
CuArrays.unsafe_free!(_x)
60+
end
61+
5662
function Base.getindex(A::CuQRPackedQ{T, S}, i::Integer, j::Integer) where {T, S}
5763
x = CuArrays.zeros(T, size(A, 2))
5864
x[j] = 1

test/solver.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,20 @@ k = 1
131131
F = qr!(A)
132132
@test h_B B*Array(F.Q)
133133
end
134+
135+
@testset "ldiv!" begin
136+
A = rand(elty, n, n)
137+
d_A = CuArray(A)
138+
B = rand(elty, n, n)
139+
d_B = CuArray(B)
140+
C = rand(elty, n, n)
141+
d_C = CuArray(C)
142+
F = qr!(A)
143+
d_F = qr!(d_A)
144+
ldiv!(C,F,B)
145+
ldiv!(d_C,d_F,d_B)
146+
@test C Array(d_C)
147+
end
134148

135149
@testset "orgqr!" begin
136150
A = rand(elty,n,m)

0 commit comments

Comments
 (0)