Skip to content

Commit 8d709a4

Browse files
hyrodiumthchr
andauthored
Add pinv(::StaticMatrix) (#873)
* add pinv * update for pinv * add doctests, resolve type unstability for pinv * add tests for pinv * add method for Diagonal{T,<:StaticVector} * add more tests * Update src/pinv.jl Co-authored-by: Thomas Christensen <tchr@mit.edu> * remove docstring for pinv * replace `@test` with `@test_broken` because of bug in Julia * replace istril && istriu with isdiag * Rename `_pinv_M` to `_pinv_diag` Co-authored-by: Thomas Christensen <tchr@mit.edu> * Rename `_pinv_M` to `_pinv_diag` * Rename `_pinv_V` to `_pinv_vector` * Update pinv tests for the pinv bug in Julia * replace unnecessary `@test_broken` with `@test` Co-authored-by: Thomas Christensen <tchr@mit.edu> * replace unnecessary `@test_broken` with `@test` Co-authored-by: Thomas Christensen <tchr@mit.edu> * fix pinv tests for Julia v1.6 Co-authored-by: Thomas Christensen <tchr@mit.edu>
1 parent a34cb17 commit 8d709a4

File tree

4 files changed

+145
-1
lines changed

4 files changed

+145
-1
lines changed

src/StaticArrays.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ using Core.Compiler: return_type
1616
import Base: sqrt, exp, log, float, real
1717
using LinearAlgebra
1818
import LinearAlgebra: transpose, adjoint, dot, eigvals, eigen, lyap, tr,
19-
kron, diag, norm, dot, diagm, lu, svd, svdvals,
19+
kron, diag, norm, dot, diagm, lu, svd, svdvals, pinv,
2020
factorize, ishermitian, issymmetric, isposdef, issuccess, normalize,
2121
normalize!, Eigen, det, logdet, logabsdet, cross, diff, qr, \
2222
using LinearAlgebra: checksquare
@@ -143,6 +143,7 @@ include("qr.jl")
143143
include("deque.jl")
144144
include("flatten.jl")
145145
include("io.jl")
146+
include("pinv.jl")
146147

147148
include("precompile.jl")
148149
_precompile_()

src/pinv.jl

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# Moore-Penrose pseudoinverse
2+
3+
@inline function pinv(A::StaticMatrix{m,n,T} where m where n; atol::Real = 0.0, rtol::Real = (eps(real(float(one(T))))*min(size(A)...))*iszero(atol)) where T
4+
# This function is a StaticMatrix version of `LinearAlgebra.pinv`.
5+
S = typeof(zero(T)/sqrt(one(T) + one(T)))
6+
A_S = convert(similar_type(A,S),A)
7+
return _pinv(A_S, atol, rtol)
8+
end
9+
10+
@inline function _pinv(A::StaticMatrix{m,n,T}, atol::Real, rtol::Real) where T where m where n
11+
if m == 0 || n == 0
12+
return similar_type(A, Size(n,m))()
13+
end
14+
if isdiag(A)
15+
maxabsA = maximum(abs.(diag(A)))
16+
tol = max(rtol*maxabsA, atol)
17+
return _pinv_diag(A, tol)
18+
end
19+
ssvd = svd(A, full = false)
20+
tol = max(rtol*maximum(ssvd.S), atol)
21+
sinv = _pinv_vector(ssvd.S, tol)
22+
return ssvd.Vt'*SDiagonal(sinv)*ssvd.U'
23+
end
24+
25+
@inline function pinv(D::Diagonal{T,<:StaticVector}) where T
26+
V = D.diag
27+
S = typeof(zero(T)/sqrt(one(T) + one(T)))
28+
V_S = convert(similar_type(V,S),V)
29+
return Diagonal(_pinv_vector(V_S))
30+
end
31+
32+
@generated function _pinv_diag(A::StaticMatrix{m,n,T}, tol) where m where n where T
33+
minlen = min(m,n)
34+
exprs = [:(zero($T)) for i in 1:n, j in 1:m]
35+
for i in 1:minlen
36+
exprs[i,i] = :(ifelse(A[$i,$i] > tol, inv(A[$i,$i]), zero($T)))
37+
end
38+
return quote
39+
Base.@_inline_meta
40+
@inbounds return similar_type(A, Size($n, $m))(tuple($(exprs...)))
41+
end
42+
end
43+
44+
@generated function _pinv_vector(v::StaticVector{n,T}, tol) where n where T
45+
exprs = [
46+
:(ifelse(v[$i] > tol, inv(v[$i]), zero(T)))
47+
for i in 1:n
48+
]
49+
return quote
50+
Base.@_inline_meta
51+
@inbounds return similar_type(v, Size($n))(tuple($(exprs...)))
52+
end
53+
end
54+
55+
@generated function _pinv_vector(v::StaticVector{n,T}) where n where T
56+
exprs = [
57+
:(pinv(v[$i]))
58+
for i in 1:n
59+
]
60+
return quote
61+
Base.@_inline_meta
62+
@inbounds return similar_type(v, Size($n))(tuple($(exprs...)))
63+
end
64+
end

test/pinv.jl

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
using StaticArrays, Test, LinearAlgebra
2+
3+
tol = 1e-13
4+
5+
@testset "Moore–Penrose inverse (Peseudo-inverse)" begin
6+
M1 = @SMatrix [1.5 1.3; 1.2 1.9]
7+
N1 = pinv(M1)
8+
@test norm(M1*N1*M1 - M1) < tol
9+
@test norm(N1*M1*N1 - N1) < tol
10+
@test N1 isa SMatrix{2,2,Float64}
11+
@test N1 pinv(Matrix(M1))
12+
13+
M2 = @SMatrix [1//2 0 0;3//2 5//3 8//7;9//4 -1//3 -8//7;0 0 0]
14+
N2 = pinv(M2)
15+
@test norm(M2*N2*M2 - M2) < tol
16+
@test norm(N2*M2*N2 - N2) < tol
17+
@test N2 isa SMatrix{3,4,Float64}
18+
@test N2 pinv(Matrix(M2))
19+
20+
M3 = SDiagonal(0,1,2,3)
21+
N3 = pinv(M3)
22+
@test norm(M3*N3*M3 - M3) < tol
23+
@test norm(N3*M3*N3 - N3) < tol
24+
@test N3 isa Diagonal{Float64, <:SVector{4,Float64}}
25+
@test N3 pinv(Matrix(M3))
26+
27+
M4 = @SMatrix randn(2,5)
28+
N4 = pinv(M4)
29+
@test norm(M4*N4*M4 - M4) < tol
30+
@test norm(N4*M4*N4 - N4) < tol
31+
@test N4 isa SMatrix{5,2,Float64}
32+
@test N4 pinv(Matrix(M4))
33+
34+
M5 = SMatrix{0,5,Int}()
35+
N5 = pinv(M5)
36+
@test norm(M5*N5*M5 - M5) < tol
37+
@test norm(N5*(M5*N5) - N5) < tol
38+
@test N5 isa SMatrix{5,0,Float64}
39+
@test N5 pinv(Matrix(M5))
40+
41+
M6 = @SMatrix [1/2 0 0;0 5/3 0;0 0 0;0 0 0]
42+
N6 = pinv(M6)
43+
@test norm(M6*N6*M6 - M6) < tol
44+
@test norm(N6*M6*N6 - N6) < tol
45+
@test N6 isa SMatrix{3,4,Float64}
46+
@test N6 I(3)/Matrix(M6)
47+
# @test N6 ≈ pinv(Matrix(M6)) # Fails on Julia ≥v1.7 https://github.com/JuliaLang/julia/issues/44234
48+
49+
M7 = M6'
50+
N7 = pinv(M7)
51+
@test norm(M7*N7*M7 - M7) < tol
52+
@test norm(N7*M7*N7 - N7) < tol
53+
@test N7 isa SMatrix{4,3,Float64}
54+
@test N7 I(4)/Matrix(M7)
55+
# @test N7 ≈ pinv(Matrix(M7)) # Fails on Julia ≥v1.7 https://github.com/JuliaLang/julia/issues/44234
56+
57+
M8 = @MMatrix [0.5 1.1 0.0;0.0 -2.8 0.0;0.0 0.0 0.0;0.0 0.0 0.0]
58+
N8 = pinv(M8)
59+
@test norm(M8*N8*M8 - M8) < tol
60+
@test norm(N8*M8*N8 - N8) < tol
61+
@test N8 isa MMatrix{3,4,Float64}
62+
@test N8 pinv(Matrix(M8))
63+
64+
M9 = M8'
65+
N9 = pinv(M9)
66+
@test norm(M9*N9*M9 - M9) < tol
67+
@test norm(N9*M9*N9 - N9) < tol
68+
@test N9 isa MMatrix{4,3,Float64}
69+
@test N9 pinv(Matrix(M9))
70+
71+
M10 = @SMatrix randn(3,3)
72+
N10 = pinv(M10)
73+
@test N10 inv(M10)
74+
@test norm(M10*N10*M10 - M10) < tol
75+
@test norm(N10*M10*N10 - N10) < tol
76+
@test N10 isa StaticMatrix
77+
@test N10 pinv(Matrix(M10))
78+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ addtests("matrix_multiply_add.jl")
5656
addtests("triangular.jl")
5757
addtests("det.jl")
5858
addtests("inv.jl")
59+
addtests("pinv.jl")
5960
addtests("solve.jl")
6061
addtests("eigen.jl")
6162
addtests("expm.jl")

0 commit comments

Comments
 (0)