Skip to content

Commit caff08e

Browse files
authored
Implement \ and ldiv! for Diagonal (#391)
1 parent 9357154 commit caff08e

File tree

2 files changed

+72
-0
lines changed

2 files changed

+72
-0
lines changed

src/host/linalg.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,33 @@ else
127127
end
128128
end
129129

130+
function Base.:\(D::Diagonal{<:Any, <:AbstractGPUArray}, B::AbstractGPUVecOrMat)
131+
z = D.diag .== 0
132+
if any(z)
133+
i = findfirst(collect(z))
134+
throw(SingularException(i))
135+
else
136+
return D.diag .\ B
137+
end
138+
end
139+
140+
function LinearAlgebra.ldiv!(D::Diagonal{<:Any, <:AbstractGPUArray}, B::StridedVecOrMat)
141+
m, n = size(B, 1), size(B, 2)
142+
if m != length(D.diag)
143+
throw(DimensionMismatch("diagonal matrix is $(length(D.diag)) by $(length(D.diag)) but right hand side has $m rows"))
144+
end
145+
(m == 0 || n == 0) && return B
146+
z = D.diag .== 0
147+
if any(z)
148+
i = findfirst(collect(z))
149+
throw(SingularException(i))
150+
else
151+
B .= D.diag .\ B
152+
end
153+
return B
154+
end
155+
156+
130157
## matrix multiplication
131158

132159
function generic_matmatmul!(C::AbstractArray{R}, A::AbstractArray{T}, B::AbstractArray{S}, a::Number, b::Number) where {T,S,R}

test/testsuite/linalg.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,51 @@
130130
@test cholesky(D, check = false).info == 3
131131
end
132132

133+
@testset "\\ + Diagonal" begin
134+
n = 128
135+
d = AT(rand(Float32, n))
136+
D = Diagonal(d)
137+
b = AT(rand(Float32, n))
138+
B = AT(rand(Float32, n, n))
139+
@test collect(D \ b) Diagonal(collect(d)) \ collect(b)
140+
@test collect(D \ B) Diagonal(collect(d)) \ collect(B)
141+
142+
d = ones(Float32, n)
143+
d[rand(1:n)] = 0
144+
d = AT(d)
145+
D = Diagonal(d)
146+
@test_throws SingularException D \ B
147+
end
148+
149+
@testset "ldiv! + Diagonal" begin
150+
n = 128
151+
d = AT(rand(Float32, n))
152+
D = Diagonal(d)
153+
b = AT(rand(Float32, n))
154+
B = AT(rand(Float32, n, n))
155+
X = AT(zeros(Float32, n, n))
156+
Y = zeros(Float32, n, n)
157+
ldiv!(X, D, B)
158+
ldiv!(Y, Diagonal(collect(d)), collect(B))
159+
@test collect(X) Y
160+
ldiv!(D, B)
161+
@test collect(B) collect(X)
162+
163+
d = ones(Float32, n)
164+
d[rand(1:n)] = 0
165+
d = AT(d)
166+
D = Diagonal(d)
167+
B = AT(rand(Float32, n, n))
168+
169+
# three-argument version does not throw SingularException
170+
ldiv!(X, D, B)
171+
ldiv!(Y, Diagonal(collect(d)), collect(B))
172+
@test collect(X) Y
173+
174+
# two-argument version throws SingularException
175+
@test_throws SingularException ldiv!(D, B)
176+
end
177+
133178
@testset "$f! with diagonal $d" for (f, f!) in ((triu, triu!), (tril, tril!)),
134179
d in -2:2
135180
A = randn(Float32, 10, 10)

0 commit comments

Comments
 (0)