Skip to content

Commit db95a2a

Browse files
authored
stable pinv least-squares via rtol,atol in svd, ldiv! (#1387)
1 parent db48f0d commit db95a2a

File tree

3 files changed

+99
-35
lines changed

3 files changed

+99
-35
lines changed

src/dense.jl

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1748,10 +1748,20 @@ both with the value of `M` and the intended application of the pseudoinverse.
17481748
The default relative tolerance is `n*ϵ`, where `n` is the size of the smallest
17491749
dimension of `M`, and `ϵ` is the [`eps`](@ref) of the element type of `M`.
17501750
1751-
For inverting dense ill-conditioned matrices in a least-squares sense,
1752-
`rtol = sqrt(eps(real(float(oneunit(eltype(M))))))` is recommended.
1751+
For solving dense, ill-conditioned equations in a least-squares sense, it
1752+
is better to *not* explicitly form the pseudoinverse matrix, since this
1753+
can lead to numerical instability at low tolerances. The default `M \\ b`
1754+
algorithm instead uses pivoted QR factorization ([`qr`](@ref)). To use an
1755+
SVD-based algorithm, it is better to employ the SVD directly via `svd(M; rtol, atol) \\ b`
1756+
or `ldiv!(svd(M), b; rtol, atol)`.
17531757
1754-
For more information, see [^issue8859], [^B96], [^S84], [^KY88].
1758+
One can also pass `M = svd(A)` as the argument to `pinv` in order to re-use
1759+
an existing [`SVD`](@ref) factorization.
1760+
1761+
!!! compat "Julia 1.13"
1762+
Passing an `SVD` object to `pinv` requires Julia 1.13 or later.
1763+
1764+
For more information, see [^pr1387], [^B96], [^S84], [^KY88].
17551765
17561766
# Examples
17571767
```jldoctest
@@ -1771,15 +1781,15 @@ julia> M * N
17711781
4.44089e-16 1.0
17721782
```
17731783
1774-
[^issue8859]: Issue 8859, "Fix least squares", [https://github.com/JuliaLang/julia/pull/8859](https://github.com/JuliaLang/julia/pull/8859)
1784+
[^pr1387]: PR 1387, "stable pinv least-squares", [LinearAlgebra.jl#1387](https://github.com/JuliaLang/LinearAlgebra.jl/pull/1387)
17751785
17761786
[^B96]: Åke Björck, "Numerical Methods for Least Squares Problems", SIAM Press, Philadelphia, 1996, "Other Titles in Applied Mathematics", Vol. 51. [doi:10.1137/1.9781611971484](http://epubs.siam.org/doi/book/10.1137/1.9781611971484)
17771787
17781788
[^S84]: G. W. Stewart, "Rank Degeneracy", SIAM Journal on Scientific and Statistical Computing, 5(2), 1984, 403-413. [doi:10.1137/0905030](http://epubs.siam.org/doi/abs/10.1137/0905030)
17791789
17801790
[^KY88]: Konstantinos Konstantinides and Kung Yao, "Statistical analysis of effective singular values in matrix rank determination", IEEE Transactions on Acoustics, Speech and Signal Processing, 36(5), 1988, 757-763. [doi:10.1109/29.1585](https://doi.org/10.1109/29.1585)
17811791
"""
1782-
function pinv(A::AbstractMatrix{T}; atol::Real = 0.0, rtol::Real = (eps(real(float(oneunit(T))))*min(size(A)...))*iszero(atol)) where T
1792+
function pinv(A::AbstractMatrix{T}; atol::Real=0, rtol::Real = (eps(real(float(oneunit(T))))*min(size(A)...))*iszero(atol)) where T
17831793
m, n = size(A)
17841794
Tout = typeof(zero(T)/sqrt(oneunit(T) + oneunit(T)))
17851795
if m == 0 || n == 0
@@ -1847,7 +1857,7 @@ julia> nullspace(M, atol=0.95)
18471857
1.0
18481858
```
18491859
"""
1850-
function nullspace(A::AbstractVecOrMat; atol::Real = 0.0, rtol::Real = (min(size(A, 1), size(A, 2))*eps(real(float(oneunit(eltype(A))))))*iszero(atol))
1860+
function nullspace(A::AbstractVecOrMat; atol::Real=0, rtol::Real = (min(size(A, 1), size(A, 2))*eps(real(float(oneunit(eltype(A))))))*iszero(atol))
18511861
m, n = size(A, 1), size(A, 2)
18521862
(m == 0 || n == 0) && return Matrix{eigtype(eltype(A))}(I, n, n)
18531863
SVD = svd(A; full=true)

src/svd.jl

Lines changed: 68 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -92,21 +92,22 @@ default_svd_alg(A) = DivideAndConquer()
9292

9393

9494
"""
95-
svd!(A; full::Bool = false, alg::Algorithm = default_svd_alg(A)) -> SVD
95+
svd!(A; full::Bool = false, alg::Algorithm = default_svd_alg(A), atol::Real=0, rtol::Real=0) -> SVD
9696
9797
`svd!` is the same as [`svd`](@ref), but saves space by
9898
overwriting the input `A`, instead of creating a copy. See documentation of [`svd`](@ref) for details.
9999
"""
100-
function svd!(A::StridedMatrix{T}; full::Bool = false, alg::Algorithm = default_svd_alg(A)) where {T<:BlasFloat}
100+
function svd!(A::StridedMatrix{T}; full::Bool = false, alg::Algorithm = default_svd_alg(A), atol::Real=0, rtol::Real=0) where {T<:BlasFloat}
101101
m, n = size(A)
102102
if m == 0 || n == 0
103103
u, s, vt = (Matrix{T}(I, m, full ? m : n), real(zeros(T,0)), Matrix{T}(I, n, n))
104104
else
105105
u, s, vt = _svd!(A, full, alg)
106106
end
107-
SVD(u, s, vt)
107+
s[_count_svdvals(s, atol, rtol)+1:end] .= 0
108+
return SVD(u, s, vt)
108109
end
109-
function svd!(A::StridedVector{T}; full::Bool = false, alg::Algorithm = default_svd_alg(A)) where {T<:BlasFloat}
110+
function svd!(A::StridedVector{T}; full::Bool = false, alg::Algorithm = default_svd_alg(A), atol::Real=0, rtol::Real=0) where {T<:BlasFloat}
110111
m = length(A)
111112
normA = norm(A)
112113
if iszero(normA)
@@ -116,6 +117,7 @@ function svd!(A::StridedVector{T}; full::Bool = false, alg::Algorithm = default_
116117
return SVD(reshape(A, (m, 1)), [normA], ones(T, 1, 1))
117118
else
118119
u, s, vt = _svd!(reshape(A, (m, 1)), full, alg)
120+
s[_count_svdvals(s, atol, rtol)+1:end] .= 0
119121
return SVD(u, s, vt)
120122
end
121123
end
@@ -129,10 +131,15 @@ function _svd!(A::StridedMatrix{T}, full::Bool, alg::QRIteration) where {T<:Blas
129131
u, s, vt = LAPACK.gesvd!(c, c, A)
130132
end
131133

132-
134+
# count singular values S ≥ max(atol, rtol*S[1]) (0 if the largest is zero); S assumed sorted in descending order
135+
function _count_svdvals(S, atol::Real, rtol::Real)
136+
isempty(S) && return 0
137+
tol = max(rtol * S[1], atol)
138+
return iszero(S[1]) ? 0 : searchsortedlast(S, tol, rev=true)
139+
end
133140

134141
"""
135-
svd(A; full::Bool = false, alg::Algorithm = default_svd_alg(A)) -> SVD
142+
svd(A; full::Bool = false, alg::Algorithm = default_svd_alg(A), atol::Real=0, rtol::Real=0) -> SVD
136143
137144
Compute the singular value decomposition (SVD) of `A` and return an `SVD` object.
138145
@@ -151,11 +158,18 @@ number of singular values.
151158
152159
`alg` specifies which algorithm and LAPACK method to use for SVD:
153160
- `alg = LinearAlgebra.DivideAndConquer()` (default): Calls `LAPACK.gesdd!`.
154-
- `alg = LinearAlgebra.QRIteration()`: Calls `LAPACK.gesvd!` (typically slower but more accurate) .
161+
- `alg = LinearAlgebra.QRIteration()`: Calls `LAPACK.gesvd!` (typically slower but more accurate).
162+
163+
The `atol` and `rtol` parameters specify optional tolerances to truncate the SVD,
164+
dropping (setting to zero) singular values less than `max(atol, rtol*σ₁)` where
165+
`σ₁` is the largest singular value of `A`.
155166
156167
!!! compat "Julia 1.3"
157168
The `alg` keyword argument requires Julia 1.3 or later.
158169
170+
!!! compat "Julia 1.13"
171+
The `atol` and `rtol` arguments require Julia 1.13 or later.
172+
159173
# Examples
160174
```jldoctest
161175
julia> A = rand(4,3);
@@ -176,25 +190,25 @@ julia> Uonly == U
176190
true
177191
```
178192
"""
179-
function svd(A::AbstractVecOrMat{T}; full::Bool = false, alg::Algorithm = default_svd_alg(A)) where {T}
180-
svd!(eigencopy_oftype(A, eigtype(T)), full = full, alg = alg)
193+
function svd(A::AbstractVecOrMat{T}; full::Bool = false, alg::Algorithm = default_svd_alg(A), atol::Real=0, rtol::Real=0) where {T}
194+
svd!(eigencopy_oftype(A, eigtype(T)); full, alg, atol, rtol)
181195
end
182-
function svd(A::AbstractVecOrMat{T}; full::Bool = false, alg::Algorithm = default_svd_alg(A)) where {T <: Union{Float16,Complex{Float16}}}
183-
A = svd!(eigencopy_oftype(A, eigtype(T)), full = full, alg = alg)
196+
function svd(A::AbstractVecOrMat{T}; full::Bool = false, alg::Algorithm = default_svd_alg(A), atol::Real=0, rtol::Real=0) where {T <: Union{Float16,Complex{Float16}}}
197+
A = svd!(eigencopy_oftype(A, eigtype(T)); full, alg, atol, rtol)
184198
return SVD{T}(A)
185199
end
186-
function svd(x::Number; full::Bool = false, alg::Algorithm = default_svd_alg(x))
187-
SVD(x == 0 ? fill(one(x), 1, 1) : fill(x/abs(x), 1, 1), [abs(x)], fill(one(x), 1, 1))
200+
function svd(x::Number; full::Bool = false, alg::Algorithm = default_svd_alg(x), atol::Real=0, rtol::Real=0)
201+
SVD(iszero(x) || abs(x) < atol ? fill(one(x), 1, 1) : fill(x/abs(x), 1, 1), [abs(x)], fill(one(x), 1, 1))
188202
end
189-
function svd(x::Integer; full::Bool = false, alg::Algorithm = default_svd_alg(x))
190-
svd(float(x), full = full, alg = alg)
203+
function svd(x::Integer; full::Bool = false, alg::Algorithm = default_svd_alg(x), atol::Real=0, rtol::Real=0)
204+
svd(float(x); full, alg, atol, rtol)
191205
end
192-
function svd(A::Adjoint; full::Bool = false, alg::Algorithm = default_svd_alg(A))
193-
s = svd(A.parent, full = full, alg = alg)
206+
function svd(A::Adjoint; full::Bool = false, alg::Algorithm = default_svd_alg(A), atol::Real=0, rtol::Real=0)
207+
s = svd(A.parent; full, alg, atol, rtol)
194208
return SVD(s.Vt', s.S, s.U')
195209
end
196-
function svd(A::Transpose; full::Bool = false, alg::Algorithm = default_svd_alg(A))
197-
s = svd(A.parent, full = full, alg = alg)
210+
function svd(A::Transpose; full::Bool = false, alg::Algorithm = default_svd_alg(A), atol::Real=0, rtol::Real=0)
211+
s = svd(A.parent; full, alg, atol, rtol)
198212
return SVD(transpose(s.Vt), s.S, transpose(s.U))
199213
end
200214

@@ -248,7 +262,7 @@ svdvals(S::SVD{<:Any,T}) where {T} = (S.S)::Vector{T}
248262
"""
249263
rank(S::SVD{<:Any, T}; atol::Real=0, rtol::Real=min(n,m)*ϵ) where {T}
250264
251-
Compute the numerical rank of the SVD object `S` by counting how many singular values are greater
265+
Compute the numerical rank of the SVD object `S` by counting how many singular values are greater
252266
than `max(atol, rtol*σ₁)` where `σ₁` is the largest calculated singular value.
253267
This is equivalent to the default [`rank(::AbstractMatrix)`](@ref) method except that it re-uses an existing SVD factorization.
254268
`atol` and `rtol` are the absolute and relative tolerances, respectively.
@@ -258,25 +272,51 @@ and `ϵ` is the [`eps`](@ref) of the element type of `S`.
258272
!!! compat "Julia 1.12"
259273
The `rank(::SVD)` method requires at least Julia 1.12.
260274
"""
261-
function rank(S::SVD; atol::Real = 0.0, rtol::Real = (min(size(S)...)*eps(real(float(eltype(S))))))
275+
function rank(S::SVD; atol::Real=0, rtol::Real = (min(size(S)...)*eps(real(float(eltype(S))))))
262276
tol = max(atol, rtol*S.S[1])
263277
count(>(tol), S.S)
264278
end
265279

266280
### SVD least squares ###
267-
function ldiv!(A::SVD{T}, B::AbstractVecOrMat) where T
268-
m, n = size(A)
269-
k = searchsortedlast(A.S, eps(real(T))*A.S[1], rev=true)
270-
mul!(view(B, 1:n, :), view(A.Vt, 1:k, :)', view(A.S, 1:k) .\ (view(A.U, :, 1:k)' * _cut_B(B, 1:m)))
281+
"""
282+
ldiv!(F::SVD, B; atol::Real=0, rtol::Real=atol>0 ? 0 : n*ϵ)
283+
284+
Given the SVD `F` of an ``m \\times n`` matrix, multiply the first ``m`` rows of `B` in-place
285+
by the Moore-Penrose pseudoinverse, storing the result in the first ``n`` rows of `B`, returning `B`.
286+
This is equivalent to a least-squares solution (for ``m \\ge n``) or a minimum-norm solution (for ``m \\le n``).
287+
288+
Similar to the [`pinv`](@ref) function, the solution can be regularized by truncating the SVD,
289+
dropping any singular values less than `max(atol, rtol*σ₁)` where `σ₁` is the largest singular value.
290+
The default relative tolerance is `n*ϵ`, where `n` is the size of the smallest dimension of `F`, and
291+
`ϵ` is the [`eps`](@ref) of the element type of `F`.
292+
293+
!!! compat "Julia 1.13"
294+
The `atol` and `rtol` arguments require Julia 1.13 or later.
295+
"""
296+
function ldiv!(F::SVD{T}, B::AbstractVecOrMat; atol::Real=0, rtol::Real = (eps(real(float(oneunit(T))))*min(size(F)...))*iszero(atol)) where T
297+
m, n = size(F)
298+
k = _count_svdvals(F.S, atol, rtol)
299+
if k == 0
300+
B[1:n, :] .= 0
301+
else
302+
temp = view(F.U, :, 1:k)' * _cut_B(B, 1:m)
303+
ldiv!(Diagonal(view(F.S, 1:k)), temp)
304+
mul!(view(B, 1:n, :), view(F.Vt, 1:k, :)', temp)
305+
end
271306
return B
272307
end
273308

274-
function inv(F::SVD{T}) where T
309+
function pinv(F::SVD{T}; atol::Real=0, rtol::Real = (eps(real(float(oneunit(T))))*min(size(F)...))*iszero(atol)) where T
310+
k = _count_svdvals(F.S, atol, rtol)
311+
@views (F.S[1:k] .\ F.Vt[1:k, :])' * F.U[:,1:k]'
312+
end
313+
314+
function inv(F::SVD)
315+
# TODO: checksquare(F)
275316
@inbounds for i in eachindex(F.S)
276317
iszero(F.S[i]) && throw(SingularException(i))
277318
end
278-
k = searchsortedlast(F.S, eps(real(T))*F.S[1], rev=true)
279-
@views (F.S[1:k] .\ F.Vt[1:k, :])' * F.U[:,1:k]'
319+
pinv(F; rtol=eps(real(eltype(F))))
280320
end
281321

282322
size(A::SVD, dim::Integer) = dim == 1 ? size(A.U, dim) : size(A.Vt, dim)

test/svd.jl

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ using LinearAlgebra: BlasComplex, BlasFloat, BlasReal, QRPivoted
5353
@test sf2.U*Diagonal(sf2.S)*sf2.Vt' ≈ m2
5454

5555
@test ldiv!([0., 0.], svd(Matrix(I, 2, 2)), [1., 1.]) ≈ [1., 1.]
56+
# @test_throws DimensionMismatch inv(svd(Matrix(I, 3, 2)))
5657
@test inv(svd(Matrix(I, 2, 2))) ≈ I
5758
@test inv(svd([1 2; 3 4])) ≈ [-2.0 1.0; 1.5 -0.5]
5859
@test inv(svd([1 0 1; 0 1 0])) ≈ [0.5 0.0; 0.0 1.0; 0.5 0.0]
@@ -235,7 +236,20 @@ end
235236
@test Uc * diagm(0=>Sc) * transpose(V) ≈ complex.(A) rtol=1e-3
236237
end
237238

238-
@testset "Issue 40944. ldiV!(SVD) should update rhs" begin
239+
@testset "SVD pinv and truncation" begin
240+
m, n = 10,5
241+
A = randn(m,n) * [1/(i+j-1) for i = 1:n, j=1:n] # badly conditioned Hilbert matrix
242+
@test pinv(A) ≈ pinv(svd(A)) rtol=1e-13
243+
pinv_3 = pinv(A, rtol=1e-3)
244+
@test pinv_3 ≈ pinv(svd(A), rtol=1e-3) rtol=1e-13
245+
@test pinv_3 ≈ pinv(svd(A, rtol=1e-3)) rtol=1e-13
246+
b = float([1:m;]) # arbitrary rhs
247+
@test pinv_3 * b ≈ svd(A, rtol=1e-3) \ b rtol=1e-13
248+
@test pinv_3 * b ≈ ldiv!(svd(A), copy(b), rtol=1e-3)[1:n] rtol=1e-13
249+
@test pinv(A, atol=100) == pinv(svd(A), atol=100) == pinv(svd(A, atol=100)) == zeros(5,10)
250+
end
251+
252+
@testset "Issue 40944. ldiv!(SVD) should update rhs" begin
239253
F = svd(randn(2, 2))
240254
b = randn(2)
241255
x = ldiv!(F, b)

0 commit comments

Comments
 (0)