Skip to content

Commit 609aa34

Browse files
authored
Add direct calls to BLAS to compute SVDs (#1259)
* Add direct calls to BLAS to compute SVD * Move funcs with direct BLAS interface to new file * Use @CCall * Restrict BLAS interface to Julia v1.7 or higher * Bump version to 1.9.5
1 parent c4092a1 commit 609aa34

File tree

5 files changed

+168
-4
lines changed

5 files changed

+168
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "StaticArrays"
22
uuid = "90137ffa-7385-5640-81b9-e52037218182"
3-
version = "1.9.4"
3+
version = "1.9.5"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/StaticArrays.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,10 @@ include("flatten.jl")
133133
include("io.jl")
134134
include("pinv.jl")
135135

136+
@static if VERSION >= v"1.7"
137+
include("blas.jl")
138+
end
139+
136140
@static if !isdefined(Base, :get_extension) # VERSION < v"1.9-"
137141
include("../ext/StaticArraysStatisticsExt.jl")
138142
end

src/blas.jl

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
# This file contains functions that use a direct interface to the BLAS library. We use this
2+
# approach to reduce allocations.
3+
4+
import LinearAlgebra: BLAS, LAPACK, libblastrampoline
5+
6+
# == Singular Value Decomposition ==========================================================
7+
8+
# Implement direct call to BLAS functions that computes the SVD values for `SMatrix` and
9+
# `MMatrix` reducing allocations. In this case, we use `MMatrix` to call the library and
10+
# convert the result back to the input type. Since the former does not exit this scope, we
11+
# can reduce allocations.
12+
#
13+
# We are implementing here the following functions:
14+
#
15+
# svdvals(A::SMatrix{M, N, Float64}) where {M, N}
16+
# svdvals(A::SMatrix{M, N, Float32}) where {M, N}
17+
# svdvals(A::MMatrix{M, N, Float64}) where {M, N}
18+
# svdvals(A::MMatrix{M, N, Float32}) where {M, N}
19+
#
20+
# Generate `svdvals` methods that call the LAPACK routine `?gesdd` directly through
# `libblastrampoline`, one method per combination of element type (`Float64`, `Float32`)
# and static matrix type (`SMatrix`, `MMatrix`).
for (gesdd, elty) in ((:dgesdd_, :Float64), (:sgesdd_, :Float32)),
    (mtype, vtype) in ((SMatrix, SVector), (MMatrix, MVector))

    # Resolve the backend-specific symbol name of the routine.
    blas_func = @eval BLAS.@blasfunc($gesdd)

    @eval begin
        # Return the singular values of `A` as a static vector of length `min(M, N)`
        # (`SVector` for `SMatrix` input, `MVector` for `MMatrix` input).
        #
        # Throws whatever `LAPACK.chklapackerror` throws when the routine reports a
        # non-zero `info`.
        function svdvals(A::$mtype{M, N, $elty}) where {M, N}
            K = min(M, N)

            # Copy the input into a mutable matrix (the routine overwrites it) and
            # allocate the output vector.
            Am = MMatrix{M, N, $elty}(A)
            Sm = MVector{K, $elty}(undef)

            # We compute the `lwork` (size of the work array) by obtaining the maximum
            # value from the possibilities shown in:
            #     https://docs.oracle.com/cd/E19422-01/819-3691/dgesdd.html
            lwork = max(8N, 3N + max(M, 7N), 8M, 3M + max(N, 7M))
            work  = MVector{lwork, $elty}(undef)
            iwork = MVector{8K, BLAS.BlasInt}(undef)

            # `info` must have the integer width the library expects, which is not
            # necessarily `Int`. It is pre-seeded with a non-zero value so that a call
            # that never writes it is reported as a failure.
            info = Ref{BLAS.BlasInt}(1)

            @ccall libblastrampoline.$blas_func(
                'N'::Ref{UInt8},           # JOBZ: compute singular values only
                M::Ref{BLAS.BlasInt},      # M
                N::Ref{BLAS.BlasInt},      # N
                Am::Ptr{$elty},            # A (overwritten)
                M::Ref{BLAS.BlasInt},      # LDA
                Sm::Ptr{$elty},            # S
                C_NULL::Ptr{$elty},        # U (not referenced when JOBZ = 'N')
                M::Ref{BLAS.BlasInt},      # LDU
                C_NULL::Ptr{$elty},        # VT (not referenced when JOBZ = 'N')
                K::Ref{BLAS.BlasInt},      # LDVT
                work::Ptr{$elty},          # WORK
                lwork::Ref{BLAS.BlasInt},  # LWORK
                iwork::Ptr{BLAS.BlasInt},  # IWORK
                info::Ptr{BLAS.BlasInt},   # INFO
                1::Clong                   # hidden length of the JOBZ character argument
            )::Cvoid

            # Raise an error if the routine reported a failure.
            LAPACK.chklapackerror(info[])

            # Convert the result to the static vector type matching the input.
            return $vtype{K, $elty}(Sm)
        end
    end
end
69+
70+
# For matrices with integer elements, we promote them to float and then call `svdvals`.
71+
# Integer-valued static matrices are converted with `float` first and then handed back to
# `svdvals`, so that the floating-point methods do the actual work.
@inline function svdvals(A::StaticMatrix{<:Any, <:Any, <:Integer})
    return svdvals(float(A))
end
72+
73+
# Implement direct call to BLAS functions that computes the SVD for `SMatrix` and `MMatrix`
74+
# reducing allocations. In this case, we use `MMatrix` to call the library and convert the
75+
# result back to the input type. Since the former does not exit this scope, we can reduce
76+
# allocations.
77+
#
78+
# We are implementing here the following functions:
79+
#
80+
# _svd(A::SMatrix{M, N, Float64}, full::Val{false}) where {M, N}
81+
# _svd(A::SMatrix{M, N, Float64}, full::Val{true}) where {M, N}
82+
# _svd(A::SMatrix{M, N, Float32}, full::Val{false}) where {M, N}
83+
# _svd(A::SMatrix{M, N, Float32}, full::Val{true}) where {M, N}
84+
# _svd(A::MMatrix{M, N, Float64}, full::Val{false}) where {M, N}
85+
# _svd(A::MMatrix{M, N, Float64}, full::Val{true}) where {M, N}
86+
# _svd(A::MMatrix{M, N, Float32}, full::Val{false}) where {M, N}
87+
# _svd(A::MMatrix{M, N, Float32}, full::Val{true}) where {M, N}
88+
#
89+
# Generate `_svd` methods that call the LAPACK routine `?gesvd` directly through
# `libblastrampoline`, one method per combination of element type (`Float64`, `Float32`),
# thin/full factorization (`Val{false}` / `Val{true}`) and static matrix type (`SMatrix`,
# `MMatrix`).
for (gesvd, elty) in ((:dgesvd_, :Float64), (:sgesvd_, :Float32)),
    full in (false, true),
    (mtype, vtype) in ((SMatrix, SVector), (MMatrix, MVector))

    # Resolve the backend-specific symbol name of the routine.
    blas_func = @eval BLAS.@blasfunc($gesvd)

    @eval begin
        # Return `SVD(U, S, Vt)` for `A`. With `K = min(M, N)`, `U` is `M × K` and `Vt` is
        # `K × N` for the thin factorization (`Val{false}`); `U` is `M × M` and `Vt` is
        # `N × N` for the full one (`Val{true}`). `S` always has length `K`.
        #
        # Throws whatever `LAPACK.chklapackerror` throws when the routine reports a
        # non-zero `info`.
        function _svd(A::$mtype{M, N, $elty}, full::Val{$full}) where {M, N}
            K = min(M, N)

            # Copy the input into a mutable matrix (the routine overwrites it) and
            # allocate the output arrays.
            Am  = MMatrix{M, N, $elty}(A)
            Um  = MMatrix{M, $(full ? :M : :K), $elty}(undef)
            Sm  = MVector{K, $elty}(undef)
            Vtm = MMatrix{$(full ? :N : :K), N, $elty}(undef)

            # Size of the work array handed to the routine.
            lwork = max(3K + max(M, N), 5K)
            work  = MVector{lwork, $elty}(undef)

            # `info` must have the integer width the library expects, which is not
            # necessarily `Int`. It is pre-seeded with a non-zero value so that a call
            # that never writes it is reported as a failure.
            info = Ref{BLAS.BlasInt}(1)

            @ccall libblastrampoline.$blas_func(
                $(full ? 'A' : 'S')::Ref{UInt8},       # JOBU: all / first K columns of U
                $(full ? 'A' : 'S')::Ref{UInt8},       # JOBVT: all / first K rows of V'
                M::Ref{BLAS.BlasInt},                  # M
                N::Ref{BLAS.BlasInt},                  # N
                Am::Ptr{$elty},                        # A (overwritten)
                M::Ref{BLAS.BlasInt},                  # LDA
                Sm::Ptr{$elty},                        # S
                Um::Ptr{$elty},                        # U
                M::Ref{BLAS.BlasInt},                  # LDU
                Vtm::Ptr{$elty},                       # VT
                $(full ? :N : :K)::Ref{BLAS.BlasInt},  # LDVT
                work::Ptr{$elty},                      # WORK
                lwork::Ref{BLAS.BlasInt},              # LWORK
                info::Ptr{BLAS.BlasInt},               # INFO
                1::Clong,                              # hidden length of JOBU
                1::Clong                               # hidden length of JOBVT
            )::Cvoid

            # Raise an error if the routine reported a failure.
            LAPACK.chklapackerror(info[])

            # Convert the results to the static array types matching the input.
            U  = $mtype{M, $(full ? :M : :K), $elty}(Um)
            S  = $vtype{K, $elty}(Sm)
            Vt = $mtype{$(full ? :N : :K), N, $elty}(Vtm)

            return SVD(U, S, Vt)
        end
    end
end
139+
140+
# For matrices with integer elements, we promote them to float and then call `svd`.
141+
# Integer-valued static matrices are converted with `float` and passed back to `svd`.
# Keyword arguments (e.g. `full`) are forwarded unchanged: this method is more specific
# than the generic `StaticMatrix` one, so without forwarding a call such as
# `svd(A, full = true)` on an integer matrix would be rejected as an unsupported keyword.
@inline function svd(A::StaticMatrix{<:Any, <:Any, <:Integer}; kwargs...)
    return svd(float(A); kwargs...)
end

src/svd.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,3 +73,4 @@ function diagmult(sd, sB, d, B)
7373
ind = SOneTo(sd[1])
7474
return isa(B, AbstractVector) ? Diagonal(d)*B[ind] : Diagonal(d)*B[ind,:]
7575
end
76+

test/svd.jl

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@ using StaticArrays, Test, LinearAlgebra
22

33
@testset "SVD factorization" begin
44
m3 = @SMatrix Float64[3 9 4; 6 6 2; 3 7 9]
5+
m3_f32 = @SMatrix Float32[3 9 4; 6 6 2; 3 7 9]
56
m3c = ComplexF64.(m3)
67
m23 = @SMatrix Float64[3 9 4; 6 6 2]
8+
m23_f32 = @SMatrix Float32[3 9 4; 6 6 2]
79
m_sing = @SMatrix [2.0 3.0 5.0; 4.0 6.0 10.0; 1.0 1.0 1.0]
810
m_sing2 = @SMatrix [1 1; 1 0; 0 1]
911
v = @SVector [1, 2, 3]
@@ -18,16 +20,22 @@ using StaticArrays, Test, LinearAlgebra
1820
@testinf svdvals((@SMatrix [2 -2; 1 1]) / sqrt(2)) [2, 1]
1921

2022
@testinf svdvals(m3) svdvals(Matrix(m3))
23+
@testinf svdvals(m3_f32) svdvals(Matrix(m3_f32))
2124
@testinf svdvals(m3c) isa SVector{3,Float64}
2225

2326
@testinf svd(m3).U::StaticMatrix svd(Matrix(m3)).U
2427
@testinf svd(m3).S::StaticVector svd(Matrix(m3)).S
2528
@testinf svd(m3).V::StaticMatrix svd(Matrix(m3)).V
2629
@testinf svd(m3).Vt::StaticMatrix svd(Matrix(m3)).Vt
2730

28-
@testinf svd(@SMatrix [2 0; 0 0]).U === one(SMatrix{2,2})
29-
@testinf svd(@SMatrix [2 0; 0 0]).S === SVector(2.0, 0.0)
30-
@testinf svd(@SMatrix [2 0; 0 0]).Vt === one(SMatrix{2,2})
31+
@test svd(m3_f32).U::StaticMatrix svd(Matrix(m3_f32)).U atol = 5e-7
32+
@test svd(m3_f32).S::StaticVector svd(Matrix(m3_f32)).S atol = 5e-7
33+
@test svd(m3_f32).V::StaticMatrix svd(Matrix(m3_f32)).V atol = 5e-7
34+
@test svd(m3_f32).Vt::StaticMatrix svd(Matrix(m3_f32)).Vt atol = 5e-7
35+
36+
@testinf svd(@SMatrix [2 0; 0 0]).U one(SMatrix{2,2})
37+
@testinf svd(@SMatrix [2 0; 0 0]).S SVector(2.0, 0.0)
38+
@testinf svd(@SMatrix [2 0; 0 0]).Vt one(SMatrix{2,2})
3139

3240
@testinf svd((@SMatrix [2 -2; 1 1]) / sqrt(2)).U [-1 0; 0 1]
3341
@testinf svd((@SMatrix [2 -2; 1 1]) / sqrt(2)).S [2, 1]
@@ -41,6 +49,16 @@ using StaticArrays, Test, LinearAlgebra
4149
@testinf svd(m23').S svd(Matrix(m23')).S
4250
@testinf svd(m23').Vt svd(Matrix(m23')).Vt
4351

52+
@test svd(m23_f32).U::StaticMatrix svd(Matrix(m23_f32)).U atol = 5e-7
53+
@test svd(m23_f32).S::StaticVector svd(Matrix(m23_f32)).S atol = 5e-7
54+
@test svd(m23_f32).V::StaticMatrix svd(Matrix(m23_f32)).V atol = 5e-7
55+
@test svd(m23_f32).Vt::StaticMatrix svd(Matrix(m23_f32)).Vt atol = 5e-7
56+
57+
@test svd(m23_f32').U::StaticMatrix svd(Matrix(m23_f32')).U atol = 5e-7
58+
@test svd(m23_f32').S::StaticVector svd(Matrix(m23_f32')).S atol = 5e-7
59+
@test svd(m23_f32').V::StaticMatrix svd(Matrix(m23_f32')).V atol = 5e-7
60+
@test svd(m23_f32').Vt::StaticMatrix svd(Matrix(m23_f32')).Vt atol = 5e-7
61+
4462
@testinf svd(m23, full=true).U::StaticMatrix svd(Matrix(m23), full=true).U
4563
@testinf svd(m23, full=true).S::StaticVector svd(Matrix(m23), full=true).S
4664
@testinf svd(m23, full=true).Vt::StaticMatrix svd(Matrix(m23), full=true).Vt

0 commit comments

Comments
 (0)