Skip to content

Commit 39eccd8

Browse files
wsshinandyferris
authored andcommitted
Implement diag() with Val as second argument (#226)
1 parent 0cbb281 commit 39eccd8

File tree

3 files changed

+20
-1
lines changed

3 files changed

+20
-1
lines changed

src/StaticArrays.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import Base: getindex, setindex!, size, similar, vec, show,
88
length, convert, promote_op, promote_rule, map, map!, reduce, reducedim, mapreducedim,
99
mapreduce, broadcast, broadcast!, conj, transpose, ctranspose,
1010
hcat, vcat, ones, zeros, eye, one, cross, vecdot, reshape, fill,
11-
fill!, det, inv, eig, eigvals, expm, sqrtm, trace, vecnorm, norm, dot, diagm,
11+
fill!, det, inv, eig, eigvals, expm, sqrtm, trace, vecnorm, norm, dot, diagm, diag,
1212
sum, diff, prod, count, any, all, minimum,
1313
maximum, extrema, mean, copy, rand, randn, randexp, rand!, randn!,
1414
randexp!, normalize, normalize!, read, read!, write

src/linalg.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,19 @@ end
183183
end
184184
end
185185

186+
@inline diag(m::StaticMatrix, k::Type{Val{D}}=Val{0}) where {D} = _diag(Size(m), m, k)
187+
@generated function _diag(::Size{S}, m::StaticMatrix, ::Type{Val{D}}) where {S,D}
188+
S1, S2 = S
189+
rng = D 0 ? range(1-D, S1+1, min(S1+D, S2)) : range(D*S1+1, S1+1, min(S1, S2-D))
190+
Snew = length(rng)
191+
T = eltype(m)
192+
exprs = [:(m[$i]) for i = rng]
193+
return quote
194+
$(Expr(:meta, :inline))
195+
@inbounds return similar_type($m, Size($Snew))(tuple($(exprs...)))
196+
end
197+
end
198+
186199
@inline cross(a::StaticVector, b::StaticVector) = _cross(same_size(a, b), a, b)
187200
_cross(::Size{S}, a::StaticVector, b::StaticVector) where {S} = error("Cross product not defined for $(S[1])-vectors")
188201
@inline function _cross(::Size{(2,)}, a::StaticVector, b::StaticVector)

test/linalg.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,12 @@
4040
@test @inferred(diagm(SVector(1,2))) === @SMatrix [1 0; 0 2]
4141
end
4242

43+
@testset "diag()" begin
44+
@test @inferred(diag(@SMatrix([0 1; 2 3]))) === SVector(0, 3)
45+
@test @inferred(diag(@SMatrix([0 1 2; 3 4 5]), Val{1})) === SVector(1, 5)
46+
@test @inferred(diag(@SMatrix([0 1; 2 3; 4 5]), Val{-1})) === SVector(2, 5)
47+
end
48+
4349
@testset "one() and zero()" begin
4450
@test @inferred(one(SMatrix{2,2,Int})) === @SMatrix [1 0; 0 1]
4551
@test @inferred(one(SMatrix{2,2})) === @SMatrix [1.0 0.0; 0.0 1.0]

0 commit comments

Comments
 (0)