Skip to content

Commit 2c3fe9b

Browse files
authored
return SVD from pinv(::SVD) (#1398)
As suggested by @andreasnoack in #1387 (comment), this changes the `pinv(::SVD)` function to instead return the SVD of the pseudo-inverse, rather than an explicit matrix, so that it can be applied stably. (No backwards-compatibility issue since the `pinv(::SVD)` method was introduced in #1387.)
1 parent 9e2ed1c commit 2c3fe9b

File tree

3 files changed

+21
-10
lines changed

3 files changed

+21
-10
lines changed

src/dense.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1756,7 +1756,8 @@ SVD-based algorithm, it is better to employ the SVD directly via `svd(M; rtol, a
17561756
or `ldiv!(svd(M), b; rtol, atol)`.
17571757
17581758
One can also pass `M = svd(A)` as the argument to `pinv` in order to re-use
1759-
an existing [`SVD`](@ref) factorization.
1759+
an existing [`SVD`](@ref) factorization. In this case, `pinv` will return
1760+
the SVD of the pseudo-inverse, which can be applied accurately, instead of an explicit matrix.
17601761
17611762
!!! compat "Julia 1.13"
17621763
Passing an `SVD` object to `pinv` requires Julia 1.13 or later.

src/svd.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -308,17 +308,22 @@ end
308308

309309
function pinv(F::SVD{T}; atol::Real=0, rtol::Real = (eps(real(float(oneunit(T))))*min(size(F)...))*iszero(atol)) where T
310310
k = _count_svdvals(F.S, atol, rtol)
311-
@views (F.S[1:k] .\ F.Vt[1:k, :])' * F.U[:,1:k]'
311+
@views SVD(copy(F.Vt[k:-1:1, :]'), inv.(F.S[k:-1:1]), copy(F.U[:,k:-1:1]'))
312312
end
313313

314314
function inv(F::SVD)
315315
checksquare(F)
316316
@inbounds for i in eachindex(F.S)
317317
iszero(F.S[i]) && throw(SingularException(i))
318318
end
319-
pinv(F; rtol=eps(real(eltype(F))))
319+
k = _count_svdvals(F.S, 0, eps(real(eltype(F))))
320+
return @views (F.S[1:k] .\ F.Vt[1:k, :])' * F.U[:,1:k]'
320321
end
321322

323+
# multiplying SVD by matrix/vector, mainly useful for pinv(::SVD) output
324+
(*)(F::SVD, A::AbstractVecOrMat{<:Number}) = F.U * (Diagonal(F.S) * (F.Vt * A))
325+
(*)(A::AbstractMatrix{<:Number}, F::SVD) = ((A*F.U) * Diagonal(F.S)) * F.Vt
326+
322327
size(A::SVD, dim::Integer) = dim == 1 ? size(A.U, dim) : size(A.Vt, dim)
323328
size(A::SVD) = (size(A, 1), size(A, 2))
324329

test/svd.jl

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ using LinearAlgebra: BlasComplex, BlasFloat, BlasReal, QRPivoted
5656
@test_throws DimensionMismatch inv(svd(Matrix(I, 3, 2)))
5757
@test inv(svd(Matrix(I, 2, 2))) I
5858
@test inv(svd([1 2; 3 4])) [-2.0 1.0; 1.5 -0.5]
59-
@test pinv(svd([1 0 1; 0 1 0])) [0.5 0.0; 0.0 1.0; 0.5 0.0]
59+
@test Matrix(pinv(svd([1 0 1; 0 1 0]))) [0.5 0.0; 0.0 1.0; 0.5 0.0]
6060
@test_throws SingularException inv(svd([0 0; 0 0]))
6161
@test inv(svd([1+2im 3+4im; 5+6im 7+8im])) [-0.5 + 0.4375im 0.25 - 0.1875im; 0.375 - 0.3125im -0.125 + 0.0625im]
6262
end
@@ -239,14 +239,19 @@ end
239239
@testset "SVD pinv and truncation" begin
240240
m, n = 10,5
241241
A = randn(m,n) * [1/(i+j-1) for i = 1:n, j=1:n] # badly conditioned Hilbert matrix
242-
@test pinv(A) pinv(svd(A)) rtol=1e-13
242+
F = svd(A)
243+
@test pinv(A) Matrix(pinv(F)) rtol=1e-13
243244
pinv_3 = pinv(A, rtol=1e-3)
244-
@test pinv_3 pinv(svd(A), rtol=1e-3) rtol=1e-13
245-
@test pinv_3 pinv(svd(A, rtol=1e-3)) rtol=1e-13
245+
F_3 = svd(A, rtol=1e-3)
246+
@test pinv_3 Matrix(pinv(F, rtol=1e-3)) rtol=1e-13
247+
@test pinv_3 Matrix(pinv(F_3)) rtol=1e-13
246248
b = float([1:m;]) # arbitrary rhs
247-
@test pinv_3 * b svd(A, rtol=1e-3) \ b rtol=1e-13
248-
@test pinv_3 * b ldiv!(svd(A), copy(b), rtol=1e-3)[1:n] rtol=1e-13
249-
@test pinv(A, atol=100) == pinv(svd(A), atol=100) == pinv(svd(A, atol=100)) == zeros(5,10)
249+
@test pinv_3 * b F_3 \ b rtol=1e-13
250+
@test pinv_3 * b pinv(F_3) * b rtol=1e-13
251+
@test pinv_3 * b ldiv!(F, copy(b), rtol=1e-3)[1:n] rtol=1e-13
252+
c = float([1:n;]) # arbitrary rhs
253+
@test c' * pinv_3 c' * pinv(F_3) rtol=1e-13
254+
@test pinv(A, atol=100) == Matrix(pinv(F, atol=100)) == Matrix(pinv(svd(A, atol=100))) == zeros(5,10)
250255
end
251256

252257
@testset "Issue 40944. ldiv!(SVD) should update rhs" begin

0 commit comments

Comments
 (0)