Skip to content

Commit 3819c74

Browse files
authored
Merge pull request #231 from JuliaArrays/matrix-factorizations
Matrix factorizations
2 parents c089124 + ac3ed4d commit 3819c74

File tree

10 files changed

+185
-1
lines changed

10 files changed

+185
-1
lines changed

src/StaticArrays.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import Base: getindex, setindex!, size, similar, vec, show,
99
mapreduce, broadcast, broadcast!, conj, transpose, ctranspose,
1010
hcat, vcat, ones, zeros, eye, one, cross, vecdot, reshape, fill,
1111
fill!, det, inv, eig, eigvals, expm, sqrtm, trace, vecnorm, norm, dot, diagm, diag,
12+
lu, svd, svdvals, svdfact,
1213
sum, diff, prod, count, any, all, minimum,
1314
maximum, extrema, mean, copy, rand, randn, randexp, rand!, randn!,
1415
randexp!, normalize, normalize!, read, read!, write
@@ -93,6 +94,8 @@ include("eigen.jl")
9394
include("expm.jl")
9495
include("sqrtm.jl")
9596
include("cholesky.jl")
97+
include("svd.jl")
98+
include("lu.jl")
9699
include("deque.jl")
97100
include("io.jl")
98101

src/linalg.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,3 +323,4 @@ end
323323

324324
@inline Base.LinAlg.Symmetric(A::StaticMatrix, uplo::Char='U') = (Base.LinAlg.checksquare(A);Symmetric{eltype(A),typeof(A)}(A, uplo))
325325
@inline Base.LinAlg.Hermitian(A::StaticMatrix, uplo::Char='U') = (Base.LinAlg.checksquare(A);Hermitian{eltype(A),typeof(A)}(A, uplo))
326+

src/lu.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# LU decomposition
2+
function lu(A::StaticMatrix, pivot::Union{Type{Val{false}},Type{Val{true}}}=Val{true})
3+
L,U,p = _lu(Size(A), A, pivot)
4+
(L,U,p)
5+
end
6+
7+
# For the square version, return explicit lower and upper triangular matrices.
8+
# We would do this for the rectangular case too, but Base doesn't support that.
9+
function lu(A::StaticMatrix{N,N}, pivot::Union{Type{Val{false}},Type{Val{true}}}=Val{true}) where {N}
10+
L,U,p = _lu(Size(A), A, pivot)
11+
(LowerTriangular(L), UpperTriangular(U), p)
12+
end
13+
14+
15+
@inline function _lu(::Size{S}, A::StaticMatrix, pivot) where {S}
16+
# For now, just call through to Base.
17+
# TODO: statically sized LU without allocations!
18+
f = lufact(Matrix(A), pivot)
19+
T = eltype(A)
20+
# Trick to get the output eltype - can't rely on the result of f[:L] as
21+
# it's not type inferrable.
22+
T2 = typeof((one(T)*zero(T) + zero(T))/one(T))
23+
L = similar_type(A, T2, Size(Size(A)[1], diagsize(A)))(f[:L])
24+
U = similar_type(A, T2, Size(diagsize(A), Size(A)[2]))(f[:U])
25+
p = similar_type(A, Int, Size(Size(A)[1]))(f[:p])
26+
(L,U,p)
27+
end
28+
29+
# Base.lufact() interface is fairly inherently type unstable. Punt on
30+
# implementing that, for now...

src/svd.jl

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# Singular Value Decomposition
2+
3+
# We need our own SVD factorization struct, as Base.LinAlg.SVD assumes
4+
# Base.Vector for `S`, and that the `U` and `Vt` have the same
5+
struct SVD{T,TU,TS,TVt} <: Factorization{T}
6+
U::TU
7+
S::TS
8+
Vt::TVt
9+
end
10+
SVD(U::AbstractArray{T}, S::AbstractVector, Vt::AbstractArray{T}) where {T} = SVD{T,typeof(U),typeof(S),typeof(Vt)}(U, S, Vt)
11+
12+
getindex(::SVD, ::Symbol) = error("In order to avoid type instability, StaticArrays.SVD doesn't support indexing the output of svdfact with a symbol. Instead, you can access the fields of the factorization directly as f.U, f.S, and f.Vt")
13+
14+
function svdvals(A::StaticMatrix)
15+
sv = svdvals(Matrix(A))
16+
# We should be using `T2=eltype(sv)`, but it's not inferrable for complex
17+
# eltypes. See https://github.com/JuliaLang/julia/pull/22443
18+
T = eltype(A)
19+
T2 = promote_type(Float32, real(typeof(one(T)/norm(one(T)))))
20+
similar_type(A, T2, Size(diagsize(A)))(sv)
21+
end
22+
23+
function svdfact(A::StaticMatrix)
24+
# "Thin" SVD only for now.
25+
f = svdfact(Matrix(A))
26+
U = similar_type(A, eltype(f.U), Size(Size(A)[1], diagsize(A)))(f.U)
27+
S = similar_type(A, eltype(f.S), Size(diagsize(A)))(f.S)
28+
Vt = similar_type(A, eltype(f.Vt), Size(diagsize(A), Size(A)[2]))(f.Vt)
29+
SVD(U,S,Vt)
30+
end
31+
32+
function svd(A::StaticMatrix)
33+
# Need our own version of `svd()`, as `Base` passes the `thin` argument
34+
# which makes the resulting dimensions uninferrable.
35+
f = svdfact(A)
36+
(f.U, f.S, f.Vt')
37+
end

src/traits.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,3 +123,7 @@ end
123123
@noinline function _throw_size_mismatch(as...)
124124
throw(DimensionMismatch("Sizes $(map(_size, as)) of input arrays do not match"))
125125
end
126+
127+
# Return the "diagonal size" of a matrix - the minimum of the two dimensions
128+
diagsize(A::StaticMatrix) = diagsize(Size(A))
129+
@pure diagsize(::Size{S}) where {S} = min(S...)

test/linalg.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
using StaticArrays, Base.Test
2+
13
@testset "Linear algebra" begin
24

35
@testset "SVector as a (mathematical) vector space" begin
@@ -138,7 +140,7 @@
138140
@test trace(@SMatrix [1.0 2.0; 3.0 4.0]) === 5.0
139141
@test_throws DimensionMismatch trace(@SMatrix rand(5,4))
140142
end
141-
143+
142144
@testset "size zero" begin
143145
@test vecdot(SVector{0, Float64}(()), SVector{0, Float64}(())) === 0.
144146
@test vecnorm(SVector{0, Float64}(())) === 0.

test/lu.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
using StaticArrays, Base.Test
2+
3+
@testset "LU decomposition" begin
4+
# Square case
5+
m22 = @SMatrix [1 2; 3 4]
6+
@test @inferred(lu(m22)) isa Tuple{LowerTriangular{Float64,SMatrix{2,2,Float64,4}}, UpperTriangular{Float64,SMatrix{2,2,Float64,4}}, SVector{2,Int}}
7+
@test lu(m22)[1]::LowerTriangular{<:Any,<:StaticMatrix} lu(Matrix(m22))[1]
8+
@test lu(m22)[2]::UpperTriangular{<:Any,<:StaticMatrix} lu(Matrix(m22))[2]
9+
@test lu(m22)[3]::StaticVector lu(Matrix(m22))[3]
10+
11+
# Rectangular case
12+
m23 = @SMatrix Float64[3 9 4; 6 6 2]
13+
@test @inferred(lu(m23)) isa Tuple{SMatrix{2,2,Float64,4}, SMatrix{2,3,Float64,6}, SVector{2,Int}}
14+
@test lu(m23)[1] lu(Matrix(m23))[1]
15+
@test lu(m23)[2] lu(Matrix(m23))[2]
16+
@test lu(m23)[3] lu(Matrix(m23))[3]
17+
18+
@test lu(m23')[1] lu(Matrix(m23'))[1]
19+
@test lu(m23')[2] lu(Matrix(m23'))[2]
20+
@test lu(m23')[3] lu(Matrix(m23'))[3]
21+
end

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
using StaticArrays
22
using Base.Test
33

4+
include("testutil.jl")
5+
46
@testset "StaticArrays" begin
57
include("SVector.jl")
68
include("MVector.jl")
@@ -31,6 +33,7 @@ using Base.Test
3133
include("chol.jl")
3234
include("deque.jl")
3335
include("io.jl")
36+
include("svd.jl")
3437

3538
include("fixed_size_arrays.jl")
3639
end

test/svd.jl

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
using StaticArrays, Base.Test
2+
3+
@testset "SVD factorization" begin
4+
m3 = @SMatrix Float64[3 9 4; 6 6 2; 3 7 9]
5+
m3c = Complex128.(m3)
6+
m23 = @SMatrix Float64[3 9 4; 6 6 2]
7+
8+
@testset "svd" begin
9+
@testinf svdvals(@SMatrix [2 0; 0 0])::StaticVector [2, 0]
10+
@testinf svdvals((@SMatrix [2 -2; 1 1]) / sqrt(2)) [2, 1]
11+
12+
@testinf svdvals(m3) svdvals(Matrix(m3))
13+
@testinf svdvals(m3c) isa SVector{3,Float64}
14+
15+
@testinf svd(m3)[1]::StaticMatrix svd(Matrix(m3))[1]
16+
@testinf svd(m3)[2]::StaticVector svd(Matrix(m3))[2]
17+
@testinf svd(m3)[3]::StaticMatrix svd(Matrix(m3))[3]
18+
end
19+
20+
@testset "svdfact" begin
21+
@test_throws ErrorException svdfact(@SMatrix [1 0; 0 1])[:U]
22+
23+
@testinf svdfact(@SMatrix [2 0; 0 0]).U === eye(SMatrix{2,2})
24+
@testinf svdfact(@SMatrix [2 0; 0 0]).S === SVector(2.0, 0.0)
25+
@testinf svdfact(@SMatrix [2 0; 0 0]).Vt === eye(SMatrix{2,2})
26+
27+
@testinf svdfact((@SMatrix [2 -2; 1 1]) / sqrt(2)).U [-1 0; 0 1]
28+
@testinf svdfact((@SMatrix [2 -2; 1 1]) / sqrt(2)).S [2, 1]
29+
@testinf svdfact((@SMatrix [2 -2; 1 1]) / sqrt(2)).Vt [-1 1; 1 1]/sqrt(2)
30+
31+
@testinf svdfact(m23).U svdfact(Matrix(m23))[:U]
32+
@testinf svdfact(m23).S svdfact(Matrix(m23))[:S]
33+
@testinf svdfact(m23).Vt svdfact(Matrix(m23))[:Vt]
34+
35+
@testinf svdfact(m23').U svdfact(Matrix(m23'))[:U]
36+
@testinf svdfact(m23').S svdfact(Matrix(m23'))[:S]
37+
@testinf svdfact(m23').Vt svdfact(Matrix(m23'))[:Vt]
38+
39+
@testinf svdfact(m3c).U svdfact(Matrix(m3c))[:U]
40+
@testinf svdfact(m3c).S svdfact(Matrix(m3c))[:S]
41+
@testinf svdfact(m3c).Vt svdfact(Matrix(m3c))[:Vt]
42+
43+
@testinf svdfact(m3c).U isa SMatrix{3,3,Complex128}
44+
@testinf svdfact(m3c).S isa SVector{3,Float64}
45+
@testinf svdfact(m3c).Vt isa SMatrix{3,3,Complex128}
46+
end
47+
end

test/testutil.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
using Base.Test, StaticArrays
2+
3+
"""
4+
x ≊ y
5+
6+
Inexact equality comparison. Like `≈` this calls `isapprox`, but with a
7+
tighter tolerance of `rtol=10*eps()`. Input with "\\approxeq".
8+
"""
9+
(x,y) = isapprox(x, y, rtol=10*eps())
10+
11+
"""
12+
@testinf a op b
13+
14+
Test that the type of the first argument `a` is inferred, and that `a op b` is
15+
true. For example, the following are equivalent:
16+
17+
@testinf SVector(1,2) + SVector(1,2) == SVector(2,4)
18+
@test @inferred(SVector(1,2) + SVector(1,2)) == SVector(2,4)
19+
"""
20+
macro testinf(ex)
21+
@assert ex.head == :call
22+
infarg = ex.args[2]
23+
if !(infarg isa Expr) || infarg.head != :call
24+
# Workaround for an oddity in @inferred
25+
infarg = :(identity($infarg))
26+
end
27+
ex.args[2] = :(@inferred($infarg))
28+
esc(:(@test $ex))
29+
end
30+
31+
@testset "@testinf" begin
32+
@testinf [1,2] == [1,2]
33+
x = [1,2]
34+
@testinf x == [1,2]
35+
@testinf (@SVector [1,2]) == (@SVector [1,2])
36+
end

0 commit comments

Comments
 (0)