Skip to content

Commit 0ada3f9

Browse files
authored
Fix copying Diagonal matrices (#374)
1 parent e1acd42 commit 0ada3f9

File tree

2 files changed

+14
-0
lines changed

2 files changed

+14
-0
lines changed

src/host/linalg.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,11 @@ function LinearAlgebra.triu!(A::AbstractGPUMatrix{T}, d::Integer = 0) where T
9797
end
9898

9999

100+
## diagonal
101+
102+
Base.copy(D::Diagonal{T, <:AbstractGPUArray{T, N}}) where {T, N} = Diagonal(copy(D.diag))
103+
104+
100105
## matrix multiplication
101106

102107
function generic_matmatmul!(C::AbstractArray{R}, A::AbstractArray{T}, B::AbstractArray{S}, a::Number, b::Number) where {T,S,R}

test/testsuite/linalg.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,15 @@ end
7979
@test collect(B) collect(A) + collect(D)
8080
end
8181

82+
@testset "copy diagonal" begin
83+
a = AT(rand(Float32, 10))
84+
D = Diagonal(a)
85+
C = copy(D)
86+
@test C isa Diagonal
87+
@test C.diag isa AT
88+
@test collect(D) == collect(C)
89+
end
90+
8291
@testset "$f! with diagonal $d" for (f, f!) in ((triu, triu!), (tril, tril!)),
8392
d in -2:2
8493
A = randn(10, 10)

0 commit comments

Comments
 (0)