Skip to content

Commit 12217c7

Browse files
committed
Various fixes for complex SVD
1 parent 0fce2a1 commit 12217c7

File tree

3 files changed

+25
-10
lines changed

3 files changed

+25
-10
lines changed

src/StaticArrays.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import Base: getindex, setindex!, size, similar, vec, show,
99
mapreduce, broadcast, broadcast!, conj, transpose, ctranspose,
1010
hcat, vcat, ones, zeros, eye, one, cross, vecdot, reshape, fill,
1111
fill!, det, inv, eig, eigvals, expm, sqrtm, trace, vecnorm, norm, dot, diagm, diag,
12-
lu, svd, svdvals!, svdfact,
12+
lu, svd, svdvals, svdfact,
1313
sum, diff, prod, count, any, all, minimum,
1414
maximum, extrema, mean, copy, rand, randn, randexp, rand!, randn!,
1515
randexp!, normalize, normalize!, read, read!, write

src/svd.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,18 @@ getindex(::SVD, ::Symbol) = error("In order to avoid type instability, StaticArr
1313

1414
function svdvals(A::StaticMatrix)
1515
sv = svdvals(Matrix(A))
16-
similar_type(A, eltype(sv), Size(diagsize(A)))(sv)
16+
# We should be using `T2=eltype(sv)`, but it's not inferrable for complex
17+
# eltypes. See https://github.com/JuliaLang/julia/pull/22443
18+
T = eltype(A)
19+
T2 = promote_type(Float32, real(typeof(one(T)/norm(one(T)))))
20+
similar_type(A, T2, Size(diagsize(A)))(sv)
1721
end
1822

1923
function svdfact(A::StaticMatrix)
2024
# "Thin" SVD only for now.
2125
f = svdfact(Matrix(A))
2226
U = similar_type(A, eltype(f.U), Size(Size(A)[1], diagsize(A)))(f.U)
23-
S = similar_type(A, Size(diagsize(A)))(f.S)
27+
S = similar_type(A, eltype(f.S), Size(diagsize(A)))(f.S)
2428
Vt = similar_type(A, eltype(f.Vt), Size(diagsize(A), Size(A)[2]))(f.Vt)
2529
SVD(U,S,Vt)
2630
end

test/svd.jl

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,47 @@
11
using StaticArrays, Base.Test
22

33
@testset "SVD factorization" begin
4+
m3 = @SMatrix Float64[3 9 4; 6 6 2; 3 7 9]
5+
m3c = Complex128.(m3)
6+
m23 = @SMatrix Float64[3 9 4; 6 6 2]
7+
48
@testset "svd" begin
59
@testinf svdvals(@SMatrix [2 0; 0 0])::StaticVector [2, 0]
610
@testinf svdvals((@SMatrix [2 -2; 1 1]) / sqrt(2)) [2, 1]
711

8-
m3 = @SMatrix Float64[3 9 4; 6 6 2; 3 7 9]
9-
@testinf svdvals(m3) svdvals(Matrix(m3))
12+
@testinf svdvals(m3) svdvals(Matrix(m3))
13+
@testinf svdvals(m3c) isa SVector{3,Float64}
1014

11-
@testinf svd(m3)[1]::StaticMatrix svd(Matrix(m3))[1]
12-
@testinf svd(m3)[2]::StaticVector svd(Matrix(m3))[2]
13-
@testinf svd(m3)[3]::StaticMatrix svd(Matrix(m3))[3]
15+
@testinf svd(m3)[1]::StaticMatrix svd(Matrix(m3))[1]
16+
@testinf svd(m3)[2]::StaticVector svd(Matrix(m3))[2]
17+
@testinf svd(m3)[3]::StaticMatrix svd(Matrix(m3))[3]
1418
end
1519

1620
@testset "svdfact" begin
1721
@test_throws ErrorException svdfact(@SMatrix [1 0; 0 1])[:U]
1822

1923
@testinf svdfact(@SMatrix [2 0; 0 0]).U === eye(SMatrix{2,2})
20-
@testinf svdfact(@SMatrix [2 0; 0 0]).S === SVector(2, 0)
24+
@testinf svdfact(@SMatrix [2 0; 0 0]).S === SVector(2.0, 0.0)
2125
@testinf svdfact(@SMatrix [2 0; 0 0]).Vt === eye(SMatrix{2,2})
2226

2327
@testinf svdfact((@SMatrix [2 -2; 1 1]) / sqrt(2)).U [-1 0; 0 1]
2428
@testinf svdfact((@SMatrix [2 -2; 1 1]) / sqrt(2)).S [2, 1]
2529
@testinf svdfact((@SMatrix [2 -2; 1 1]) / sqrt(2)).Vt [-1 1; 1 1]/sqrt(2)
2630

27-
m23 = @SMatrix Float64[3 9 4; 6 6 2]
2831
@testinf svdfact(m23).U svdfact(Matrix(m23))[:U]
2932
@testinf svdfact(m23).S svdfact(Matrix(m23))[:S]
3033
@testinf svdfact(m23).Vt svdfact(Matrix(m23))[:Vt]
3134

3235
@testinf svdfact(m23').U svdfact(Matrix(m23'))[:U]
3336
@testinf svdfact(m23').S svdfact(Matrix(m23'))[:S]
3437
@testinf svdfact(m23').Vt svdfact(Matrix(m23'))[:Vt]
38+
39+
@testinf svdfact(m3c).U svdfact(Matrix(m3c))[:U]
40+
@testinf svdfact(m3c).S svdfact(Matrix(m3c))[:S]
41+
@testinf svdfact(m3c).Vt svdfact(Matrix(m3c))[:Vt]
42+
43+
@testinf svdfact(m3c).U isa SMatrix{3,3,Complex128}
44+
@testinf svdfact(m3c).S isa SVector{3,Float64}
45+
@testinf svdfact(m3c).Vt isa SMatrix{3,3,Complex128}
3546
end
3647
end

0 commit comments

Comments
 (0)