Skip to content

Commit bd4effd

Browse files
authored
[CUSOLVER] Fix Xgesvdr! (#2556)
1 parent fcf3afe commit bd4effd

File tree

2 files changed

+27
-20
lines changed

2 files changed

+27
-20
lines changed

lib/cusolver/dense_generic.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -327,21 +327,22 @@ function Xgesvdr!(jobu::Char, jobv::Char, A::StridedCuMatrix{T}, k::Integer;
327327
niters::Integer=2, p::Integer=2*k) where {T <: BlasFloat}
328328
m, n = size(A)
329329
= min(m,n)
330+
p = min(p, ℓ-k) # Ensure that p + k ≤ ℓ
330331
(1 k ℓ) || throw(ArgumentError("illegal choice of parameter k = $k, which must be between 1 and min(m,n) = $ℓ"))
331332
(k+p ℓ) || throw(ArgumentError("illegal choice of parameters k = $k and p = $p, which must satisfy k+p ≤ min(m,n) = $ℓ"))
332333
R = real(T)
333334
U = if jobu == 'S'
334-
CuMatrix{T}(undef, m, k)
335+
CuMatrix{T}(undef, m, m)
335336
elseif jobu == 'N'
336-
CU_NULL
337+
CuMatrix{T}(undef, m, ℓ)
337338
else
338339
throw(ArgumentError("jobu is incorrect. The values accepted are 'S' and 'N'."))
339340
end
340341
Σ = CuVector{R}(undef, ℓ)
341342
V = if jobv == 'S'
342-
CuMatrix{T}(undef, n, k)
343+
CuMatrix{T}(undef, n, n)
343344
elseif jobv == 'N'
344-
CU_NULL
345+
CuMatrix{T}(undef, n, ℓ)
345346
else
346347
throw(ArgumentError("jobv is incorrect. The values accepted are 'S' and 'N'."))
347348
end

test/libraries/cusolver/dense_generic.jl

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -171,22 +171,28 @@ p = 5
171171
@test A collect(U) * Diagonal(collect(Σ)) * collect(V)'
172172
end
173173

174-
# @testset "gesvdr!" begin
175-
# R = real(elty)
176-
# B = rand(elty,m,m)
177-
# FB = qr(B)
178-
# C = rand(elty,n,n)
179-
# FC = qr(C)
180-
# Σ = zeros(elty,m,n)
181-
# for i = 1:n
182-
# Σ[i,i] = i*one(R)
183-
# end
184-
# k = 3
185-
# A = FB.Q * Σ * FC.Q
186-
# d_A = CuMatrix(A)
187-
# U, Σ, V = CUSOLVER.Xgesvdr!('N', 'N', d_A, k)
188-
# @test A ≈ collect(U[:,1:n] * Diagonal(Σ) * V')
189-
# end
174+
@testset "gesvdr!" begin
175+
R = real(elty)
176+
tol = R == Float32 ? 1e-2 : 1e-5
177+
= min(m, n)
178+
179+
B = rand(elty,m,m)
180+
FB = qr(B)
181+
C = rand(elty,n,n)
182+
FC = qr(C)
183+
Σ = zeros(R,m,n)
184+
for i = 1:
185+
Σ[i,i] = (10-i+1)*one(R)
186+
end
187+
A = FB.Q * Σ * FC.Q
188+
d_A = CuMatrix(A)
189+
190+
d_U, d_Σ, d_V = CUSOLVER.Xgesvdr!('N', 'N', d_A, 3)
191+
@test norm(diag(Σ)[1:3] - collect(d_Σ[1:3])) tol
192+
193+
d_U, d_Σ, d_V = CUSOLVER.Xgesvdr!('S', 'S', d_A, ℓ)
194+
@test norm(diag(Σ) - collect(d_Σ)) tol
195+
end
190196

191197
@testset "syevdx!" begin
192198
R = real(elty)

0 commit comments

Comments
 (0)