Skip to content

Commit 744f64c

Browse files
committed
Static wrappers for svdvals(), svd() and svdfact()
A bunch of wrappers to ensure that the shape of static matrices is preserved when using various SVD factorization related functions.
1 parent c089124 commit 744f64c

File tree

5 files changed

+129
-1
lines changed

5 files changed

+129
-1
lines changed

src/StaticArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import Base: getindex, setindex!, size, similar, vec, show,
99
mapreduce, broadcast, broadcast!, conj, transpose, ctranspose,
1010
hcat, vcat, ones, zeros, eye, one, cross, vecdot, reshape, fill,
1111
fill!, det, inv, eig, eigvals, expm, sqrtm, trace, vecnorm, norm, dot, diagm, diag,
12+
svd, svdvals!, svdfact,
1213
sum, diff, prod, count, any, all, minimum,
1314
maximum, extrema, mean, copy, rand, randn, randexp, rand!, randn!,
1415
randexp!, normalize, normalize!, read, read!, write

src/linalg.jl

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,3 +323,57 @@ end
323323

324324
@inline Base.LinAlg.Symmetric(A::StaticMatrix, uplo::Char='U') = (Base.LinAlg.checksquare(A);Symmetric{eltype(A),typeof(A)}(A, uplo))
325325
@inline Base.LinAlg.Hermitian(A::StaticMatrix, uplo::Char='U') = (Base.LinAlg.checksquare(A);Hermitian{eltype(A),typeof(A)}(A, uplo))
326+
327+
328+
#-------------------------------------------------------------------------------
329+
# Wrappers for various matrix decompositions which correctly propagate the
330+
# StaticMatrix type in the returned factorization.
331+
332+
333+
#--------------------------------------------------
334+
# lu
335+
# TODO
336+
337+
338+
#--------------------------------------------------
339+
# qr
340+
# TODO
341+
342+
#--------------------------------------------------
343+
# SVD
344+
# We need our own SVD factorization struct, as Base.LinAlg.SVD assumes
345+
# Base.Vector for `S`, and that the `U` and `Vt` have the same
346+
struct SVD{T,TU,TS,TVt} <: Factorization{T}
347+
U::TU
348+
S::TS
349+
Vt::TVt
350+
end
351+
SVD(U::AbstractArray{T}, S::AbstractVector, Vt::AbstractArray{T}) where {T} = SVD{T,typeof(U),typeof(S),typeof(Vt)}(U, S, Vt)
352+
353+
getindex(::SVD, ::Symbol) = error("In order to avoid type instability, StaticArrays.SVD doesn't support indexing the output of svdfact with a symbol. Instead, you can access the fields of the factorization directly as f.U, f.S, and f.Vt")
354+
355+
# Return the "diagonal size" of a matrix - the minimum of the two dimensions
356+
@generated function diagsize(A::StaticMatrix{N,M}) where {N,M}
357+
:($(min(N,M)))
358+
end
359+
360+
function svdvals!(A::StaticMatrix)
361+
sv = svdvals!(Matrix(A))
362+
similar_type(A, eltype(sv), Size(diagsize(A)))(sv)
363+
end
364+
365+
function svdfact(A::StaticMatrix)
366+
# "Thin" SVD only for now.
367+
f = svdfact(Matrix(A))
368+
U = similar_type(A, eltype(f.U), Size(Size(A)[1], diagsize(A)))(f.U)
369+
S = similar_type(A, Size(diagsize(A)))(f.S)
370+
Vt = similar_type(A, eltype(f.Vt), Size(diagsize(A), Size(A)[2]))(f.Vt)
371+
SVD(U,S,Vt)
372+
end
373+
374+
function svd(A::StaticMatrix)
375+
# Need our own version of `svd()`, as `Base` passes the `thin` argument
376+
# which makes the resulting dimensions uninferrable.
377+
f = svdfact(A)
378+
(f.U, f.S, f.Vt')
379+
end

test/linalg.jl

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
using StaticArrays, Base.Test
2+
13
@testset "Linear algebra" begin
24

35
@testset "SVector as a (mathematical) vector space" begin
@@ -138,11 +140,44 @@
138140
@test trace(@SMatrix [1.0 2.0; 3.0 4.0]) === 5.0
139141
@test_throws DimensionMismatch trace(@SMatrix rand(5,4))
140142
end
141-
143+
142144
@testset "size zero" begin
143145
@test vecdot(SVector{0, Float64}(()), SVector{0, Float64}(())) === 0.
144146
@test vecnorm(SVector{0, Float64}(())) === 0.
145147
@test vecnorm(SVector{0, Float64}(()), 1) === 0.
146148
@test trace(SMatrix{0,0,Float64}(())) === 0.
147149
end
150+
151+
@testset "svd" begin
152+
@testinf svdvals(@SMatrix [2 0; 0 0]) [2, 0]
153+
@testinf svdvals((@SMatrix [2 -2; 1 1]) / sqrt(2)) [2, 1]
154+
155+
m3 = @SMatrix Float64[3 9 4; 6 6 2; 3 7 9]
156+
@testinf svdvals(m3) svdvals(Matrix(m3))
157+
158+
@testinf svd(m3)[1] svd(Matrix(m3))[1]
159+
@testinf svd(m3)[2] svd(Matrix(m3))[2]
160+
@testinf svd(m3)[3] svd(Matrix(m3))[3]
161+
end
162+
163+
@testset "svdfact" begin
164+
@test_throws ErrorException svdfact(@SMatrix [1 0; 0 1])[:U]
165+
166+
@testinf svdfact(@SMatrix [2 0; 0 0]).U === eye(SMatrix{2,2})
167+
@testinf svdfact(@SMatrix [2 0; 0 0]).S === SVector(2, 0)
168+
@testinf svdfact(@SMatrix [2 0; 0 0]).Vt === eye(SMatrix{2,2})
169+
170+
@testinf svdfact((@SMatrix [2 -2; 1 1]) / sqrt(2)).U [-1 0; 0 1]
171+
@testinf svdfact((@SMatrix [2 -2; 1 1]) / sqrt(2)).S [2, 1]
172+
@testinf svdfact((@SMatrix [2 -2; 1 1]) / sqrt(2)).Vt [-1 1; 1 1]/sqrt(2)
173+
174+
m23 = @SMatrix Float64[3 9 4; 6 6 2]
175+
@testinf svdfact(m23).U svdfact(Matrix(m23))[:U]
176+
@testinf svdfact(m23).S svdfact(Matrix(m23))[:S]
177+
@testinf svdfact(m23).Vt svdfact(Matrix(m23))[:Vt]
178+
179+
@testinf svdfact(m23').U svdfact(Matrix(m23'))[:U]
180+
@testinf svdfact(m23').S svdfact(Matrix(m23'))[:S]
181+
@testinf svdfact(m23').Vt svdfact(Matrix(m23'))[:Vt]
182+
end
148183
end

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
using StaticArrays
22
using Base.Test
33

4+
include("testutil.jl")
5+
46
@testset "StaticArrays" begin
57
include("SVector.jl")
68
include("MVector.jl")

test/testutil.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
using Base.Test, StaticArrays
2+
3+
"""
4+
x ≊ y
5+
6+
Inexact equality comparison. Like `≈` this calls `isapprox`, but with a
7+
tighter tolerance of `rtol=10*eps()`. Input with "\\approxeq".
8+
"""
9+
(x,y) = isapprox(x, y, rtol=10*eps())
10+
11+
"""
12+
@testinf a op b
13+
14+
Test that the type of the first argument `a` is inferred, and that `a op b` is
15+
true. For example, the following are equivalent:
16+
17+
@testinf SVector(1,2) + SVector(1,2) == SVector(2,4)
18+
@test @inferred(SVector(1,2) + SVector(1,2)) == SVector(2,4)
19+
"""
20+
macro testinf(ex)
21+
@assert ex.head == :call
22+
infarg = ex.args[2]
23+
if !(infarg isa Expr) || infarg.head != :call
24+
# Workaround for an oddity in @inferred
25+
infarg = :(identity($infarg))
26+
end
27+
ex.args[2] = :(@inferred($infarg))
28+
esc(:(@test $ex))
29+
end
30+
31+
@testset "@testinf" begin
32+
@testinf [1,2] == [1,2]
33+
x = [1,2]
34+
@testinf x == [1,2]
35+
@testinf (@SVector [1,2]) == (@SVector [1,2])
36+
end

0 commit comments

Comments
 (0)