Skip to content

Commit 3ef1f61

Browse files
jishnubdkarrasch
andauthored
Define promote_rule for Diagonal matrices (#42142)
Co-authored-by: Daniel Karrasch <daniel.karrasch@posteo.de>
1 parent 17f05b5 commit 3ef1f61

File tree

2 files changed

+34
-1
lines changed

2 files changed

+34
-1
lines changed

stdlib/LinearAlgebra/src/diagonal.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,17 @@ struct Diagonal{T,V<:AbstractVector{T}} <: AbstractMatrix{T}
1010
new{T,V}(diag)
1111
end
1212
end
13+
Diagonal{T,V}(d::Diagonal) where {T,V<:AbstractVector{T}} = Diagonal{T,V}(d.diag)
1314
Diagonal(v::AbstractVector{T}) where {T} = Diagonal{T,typeof(v)}(v)
1415
Diagonal{T}(v::AbstractVector) where {T} = Diagonal(convert(AbstractVector{T}, v)::AbstractVector{T})
1516

17+
function Base.promote_rule(A::Type{<:Diagonal{<:Any,V}}, B::Type{<:Diagonal{<:Any,W}}) where {V,W}
18+
X = promote_type(V, W)
19+
T = eltype(X)
20+
isconcretetype(T) && return Diagonal{T,X}
21+
return typejoin(A, B)
22+
end
23+
1624
"""
1725
Diagonal(V::AbstractVector)
1826
@@ -88,7 +96,7 @@ similar(D::Diagonal, ::Type{T}) where {T} = Diagonal(similar(D.diag, T))
8896

8997
copyto!(D1::Diagonal, D2::Diagonal) = (copyto!(D1.diag, D2.diag); D1)
9098

91-
size(D::Diagonal) = (length(D.diag),length(D.diag))
99+
size(D::Diagonal) = (n = length(D.diag); (n,n))
92100

93101
function size(D::Diagonal,d::Integer)
94102
if d<1

stdlib/LinearAlgebra/test/diagonal.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -884,6 +884,31 @@ end
884884
@test \(x, B) == /(B, x)
885885
end
886886

887+
@testset "promotion" begin
888+
for (v1, v2) in (([true], [1]), ([zeros(2,2)], [zeros(Int, 2,2)]))
889+
T = promote_type(eltype(v1), eltype(v2))
890+
V = promote_type(typeof(v1), typeof(v2))
891+
d1 = Diagonal(v1)
892+
d2 = Diagonal(v2)
893+
v = [d1, d2]
894+
@test (@inferred eltype(v)) == Diagonal{T, V}
895+
end
896+
# test for a type for which promote_type doesn't lead to a concrete eltype
897+
struct MyArrayWrapper{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
898+
a :: A
899+
end
900+
Base.size(M::MyArrayWrapper) = size(M.a)
901+
Base.axes(M::MyArrayWrapper) = axes(M.a)
902+
Base.length(M::MyArrayWrapper) = length(M.a)
903+
Base.getindex(M::MyArrayWrapper, i::Int...) = M.a[i...]
904+
Base.setindex!(M::MyArrayWrapper, v, i::Int...) = M.a[i...] = v
905+
d1 = Diagonal(MyArrayWrapper(1:3))
906+
d2 = Diagonal(MyArrayWrapper(1.0:3.0))
907+
c = [d1, d2]
908+
@test c[1] == d1
909+
@test c[2] == d2
910+
end
911+
887912
@testset "zero and one" begin
888913
D1 = Diagonal(rand(3))
889914
@test D1 + zero(D1) == D1

0 commit comments

Comments
 (0)