Skip to content

Commit d406524

Browse files
authored
implements a rank(::SVD) method and adds unit tests. fixes #1126 (#1127)
* added rank method for SVD object alongwith unit tests for the method. issue #1126 * simplified DOCSTRING for rank(::SVD) * simplified rank(::SVD) implementation
1 parent f13f940 commit d406524

File tree

2 files changed

+30
-0
lines changed

2 files changed

+30
-0
lines changed

src/svd.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,24 @@ svdvals(A::AbstractVector{T}) where {T} = [convert(eigtype(T), norm(A))]
245245
svdvals(x::Number) = abs(x)
246246
svdvals(S::SVD{<:Any,T}) where {T} = (S.S)::Vector{T}
247247

248+
"""
249+
rank(S::SVD{<:Any, T}; atol::Real=0, rtol::Real=min(n,m)*ϵ) where {T}
250+
251+
Compute the numerical rank of the SVD object `S` by counting how many singular values are greater
252+
than `max(atol, rtol*σ₁)` where `σ₁` is the largest calculated singular value.
253+
This is equivalent to the default [`rank(::AbstractMatrix)`](@ref) method except that it re-uses an existing SVD factorization.
254+
`atol` and `rtol` are the absolute and relative tolerances, respectively.
255+
The default relative tolerance is `n*ϵ`, where `n` is the size of the smallest dimension of UΣV'
256+
and `ϵ` is the [`eps`](@ref) of the element type of `S`.
257+
258+
!!! compat "Julia 1.12"
259+
The `rank(::SVD)` method requires at least Julia 1.12.
260+
"""
261+
function rank(S::SVD; atol::Real = 0.0, rtol::Real = (min(size(S)...)*eps(real(float(eltype(S))))))
262+
tol = max(atol, rtol*S.S[1])
263+
count(>(tol), S.S)
264+
end
265+
248266
### SVD least squares ###
249267
function ldiv!(A::SVD{T}, B::AbstractVecOrMat) where T
250268
m, n = size(A)

test/svd.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,4 +294,16 @@ end
294294
@test F.S F32.S
295295
end
296296

297+
@testset "rank svd" begin
298+
# Test that the rank of an svd is computed correctly
299+
@test rank(svd([1.0 0.0; 0.0 1.0])) == 2
300+
@test rank(svd([1.0 0.0; 0.0 0.9]), rtol=0.95) == 1
301+
@test rank(svd([1.0 0.0; 0.0 0.9]), atol=0.95) == 1
302+
@test rank(svd([1.0 0.0; 0.0 1.0]), rtol=1.01) == 0
303+
@test rank(svd([1.0 0.0; 0.0 1.0]), atol=1.01) == 0
304+
305+
@test rank(svd([1.0 2.0; 2.0 4.0])) == 1
306+
@test rank(svd([1.0 2.0 3.0; 4.0 5.0 6.0 ; 7.0 8.0 9.0])) == 2
307+
end
308+
297309
end # module TestSVD

0 commit comments

Comments
 (0)