Skip to content

Commit 4f47345

Browse files
authored
Various fixes for diagm (#848)
* Implement the trivial diagm(v) for zeroth diagonal * Make diagm use similar_type for the output * Avoid calling problematic generic functions within the generator body (including promote_type)
1 parent 970ce8b commit 4f47345

File tree

2 files changed

+22
-11
lines changed

2 files changed

+22
-11
lines changed

src/linalg.jl

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -134,22 +134,26 @@ end
134134
end
135135

136136
@generated function diagm(kvs::Pair{<:Val,<:StaticVector}...)
137-
N = maximum(abs(kv.parameters[1].parameters[1]) + length(kv.parameters[2]) for kv in kvs)
138-
X = [Symbol("x_$(i)_$(j)") for i in 1:N, j in 1:N]
139-
T = promote_type((eltype(kv.parameters[2]) for kv in kvs)...)
140-
exprs = fill(:(zero($T)), N*N)
141-
for m in eachindex(kvs)
142-
kv = kvs[m]
143-
ind = diagind(N, N, kv.parameters[1].parameters[1])
144-
for n = 1:length(kv.parameters[2])
145-
exprs[ind[n]] = :(kvs[$m].second[$n])
137+
diag_ind_and_length = [(kv.parameters[1].parameters[1], length(kv.parameters[2])) for kv in kvs]
138+
N = maximum(abs(di) + dl for (di,dl) in diag_ind_and_length)
139+
vs = [Symbol("v$i") for i=1:length(kvs)]
140+
vs_exprs = [:(@inbounds $(vs[i]) = kvs[$i].second) for i=eachindex(kvs)]
141+
element_exprs = Any[false for _=1:N*N]
142+
for (i, (di, dl)) in enumerate(diag_ind_and_length)
143+
diaginds = diagind(N, N, di)
144+
for n = 1:dl
145+
element_exprs[diaginds[n]] = :($(vs_exprs[i])[$n])
146146
end
147147
end
148148
return quote
149149
$(Expr(:meta, :inline))
150-
@inbounds return SMatrix{$N,$N,$T}(tuple($(exprs...)))
150+
$(vs_exprs...)
151+
@inbounds elements = tuple($(element_exprs...))
152+
T = promote_tuple_eltype(elements)
153+
@inbounds return similar_type(v1, T, Size($N,$N))(elements)
151154
end
152155
end
156+
@inline diagm(v::StaticVector) = diagm(Val(0)=>v)
153157

154158
@inline diag(m::StaticMatrix, k::Type{Val{D}}=Val{0}) where {D} = _diag(Size(m), m, k)
155159
@generated function _diag(::Size{S}, m::StaticMatrix, ::Type{Val{D}}) where {S,D}

test/linalg.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,17 @@ StaticArrays.similar_type(::Union{RotMat2,Type{RotMat2}}) = SMatrix{2,2,Float64,
8686
end
8787

8888
@testset "diagm()" begin
89+
@test @inferred(diagm(SA[1,2])) === SA[1 0; 0 2]
8990
@test @inferred(diagm(Val(0) => SVector(1,2))) === @SMatrix [1 0; 0 2]
9091
@test @inferred(diagm(Val(2) => SVector(1,2,3)))::SMatrix == diagm(2 => [1,2,3])
9192
@test @inferred(diagm(Val(-2) => SVector(1,2,3)))::SMatrix == diagm(-2 => [1,2,3])
92-
@test @inferred(diagm(Val(-2) => SVector(1,2,3), Val(1) => SVector(4,5)))::SMatrix == diagm(-2 => [1,2,3], 1 => [4,5])
93+
@test @inferred(diagm(Val(-2) => SVector(1,2,3), Val(1) => SVector(4,5)))::SMatrix ==
94+
diagm(-2 => [1,2,3], 1 => [4,5])
95+
# numeric promotion
96+
@test @inferred(diagm(Val(0) => SA[1,2,3], Val(1) => SA[4.0im,5.0im]))::SMatrix{3,3,ComplexF64} ==
97+
diagm(0 => [1.0,2.0,3.0], 1 => [4.0im,5.0im])
98+
# diagm respects input type
99+
@test @inferred(diagm(MArray(SA[1,2])))::MArray == SA[1 0; 0 2]
93100
end
94101

95102
@testset "diag()" begin

0 commit comments

Comments
 (0)