Skip to content

Commit a3f715e

Browse files
committed
Fix QR wrappers and interfaces, tests enhancement
1 parent 10b9b4c commit a3f715e

File tree

2 files changed

+93
-62
lines changed

2 files changed

+93
-62
lines changed

src/qr.jl

Lines changed: 49 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -3,51 +3,49 @@ _thin_must_hold(thin) =
33
import Base.qr
44

55

6-
@inline function qr(A::StaticMatrix, pivot::Type{Val{true}}; thin::Bool=true)
6+
@inline function qr(A::StaticMatrix, pivot::Union{Type{Val{false}}, Type{Val{true}}} = Val{false}; thin::Bool=true)
77
_thin_must_hold(thin)
88
return qr(Size(A), A, pivot, Val{true})
99
end
1010

11-
@generated function qr(SA::Size{sA}, A::StaticMatrix{<:Any, <:Any, TA}, pivot::Type{Val{true}}, thin::Union{Type{Val{false}}, Type{Val{true}}}) where {sA, TA}
12-
mQ = nQ = mR = sA[1]
13-
nR = sA[2]
14-
sA[1] > sA[2] && (mR = sA[2])
15-
sA[1] > sA[2] && thin <: Type{Val{true}} && (nQ = sA[2])
16-
T = arithmetic_closure(TA)
17-
QT = similar_type(A, T, Size(mQ, nQ))
18-
RT = similar_type(A, T, Size(mR, nR))
19-
PT = similar_type(A, Int, Size(sA[2]))
20-
return quote
21-
@_inline_meta
22-
Q0, R0, p0 = Base.qr(Matrix(A), pivot)
23-
return $QT(Q0), $RT(R0), $PT(p0)
24-
end
25-
end
2611

12+
"""
13+
qr(Size(A), A::StaticMatrix, pivot=Val{false}, thin=Val{true}) -> Q, R, [p]
2714
28-
@inline function qr(A::StaticMatrix, pivot::Type{Val{false}}; thin::Bool=true)
29-
_thin_must_hold(thin)
30-
return qr(Size(A), A, pivot, Val{true})
31-
end
15+
Compute the QR factorization of `A` such that `A = Q*R` or `A[:,p] = Q*R`, see [`qr`](@ref).
16+
This function is exported to allow bypass the type instability problem in base `qr` function
17+
with keyword `thin` parameter in the interface.
18+
"""
19+
@generated function qr(::Size{sA}, A::StaticMatrix{<:Any, <:Any, TA}, pivot::Union{Type{Val{false}}, Type{Val{true}}} = Val{false}, thin::Union{Type{Val{false}}, Type{Val{true}}} = Val{true}) where {sA, TA}
20+
21+
isthin = thin <: Type{Val{true}}
22+
23+
SizeQ = Size( sA[1], isthin ? diagsize(Size(A)) : sA[1] )
24+
SizeR = Size( diagsize(Size(A)), sA[2] )
3225

33-
@generated function qr(SA::Size{sA}, A::StaticMatrix{<:Any, <:Any, TA}, pivot::Type{Val{false}}, thin::Union{Type{Val{false}}, Type{Val{true}}}) where {sA, TA}
34-
if sA[1] < 17 && sA[2] < 17
26+
if pivot <: Type{Val{true}}
3527
return quote
3628
@_inline_meta
37-
return qr_householder_unrolled(SA, A, thin)
29+
Q0, R0, p0 = Base.qr(Matrix(A), pivot, thin=$isthin)
30+
T = arithmetic_closure(TA)
31+
return similar_type(A, T, $(SizeQ))(Q0),
32+
similar_type(A, T, $(SizeR))(R0),
33+
similar_type(A, Int, $(Size(sA[2])))(p0)
3834
end
3935
else
40-
mQ = nQ = mR = sA[1]
41-
nR = sA[2]
42-
sA[1] > sA[2] && (mR = sA[2])
43-
sA[1] > sA[2] && thin <: Type{Val{true}} && (nQ = sA[2])
44-
T = arithmetic_closure(TA)
45-
QT = similar_type(A, T, Size(mQ, nQ))
46-
RT = similar_type(A, T, Size(mR, nR))
47-
return quote
48-
@_inline_meta
49-
Q0, R0 = Base.qr(Matrix(A), pivot)
50-
return $QT(Q0), $RT(R0)
36+
if (sA[1]*sA[1] + sA[1]*sA[2])÷2 * diagsize(Size(A)) < 17*17*17
37+
return quote
38+
@_inline_meta
39+
return qr_unrolled(Size(A), A, thin)
40+
end
41+
else
42+
return quote
43+
@_inline_meta
44+
Q0, R0 = Base.qr(Matrix(A), pivot, thin=$isthin)
45+
T = arithmetic_closure(TA)
46+
return similar_type(A, T, $(SizeQ))(Q0),
47+
similar_type(A, T, $(SizeR))(R0)
48+
end
5149
end
5250
end
5351
end
@@ -60,12 +58,8 @@ end
6058
# in the case of `thin=false` Q is full, but R is still reduced, see [`qr`](@ref).
6159
#
6260
# For original source code see below.
63-
@generated function qr_householder_unrolled(::Size{sA}, A::StaticMatrix{<:Any, <:Any, TA}, thin::Union{Type{Val{false}},Type{Val{true}}}) where {sA, TA}
64-
mQ = nQ = mR = m = sA[1]
65-
nR = n = sA[2]
66-
# truncate Q and R for thin case
67-
m > n && (mR = n)
68-
m > n && thin <: Type{Val{true}} && (nQ = n)
61+
@generated function qr_unrolled(::Size{sA}, A::StaticMatrix{<:Any, <:Any, TA}, thin::Union{Type{Val{false}},Type{Val{true}}} = Val{true}) where {sA, TA}
62+
m, n = sA[1], sA[2]
6963

7064
Q = [Symbol("Q_$(i)_$(j)") for i = 1:m, j = 1:m]
7165
R = [Symbol("R_$(i)_$(j)") for i = 1:m, j = 1:n]
@@ -123,22 +117,31 @@ end
123117
end
124118
end
125119

120+
# truncate Q and R sizes in LAPACK consilient way
121+
if thin <: Type{Val{true}}
122+
mQ, nQ = m, min(m, n)
123+
else
124+
mQ, nQ = m, m
125+
end
126+
mR, nR = min(m, n), n
127+
126128
return quote
127129
@_inline_meta
128130
T = arithmetic_closure(TA)
129131
@inbounds $(Expr(:block, initQ...))
130132
@inbounds $(Expr(:block, initR...))
131133
@inbounds $code
132-
@inbounds return similar_type(A, T, $(Size(mQ,nQ)))(tuple($(Q[1:mQ,1:nQ]...))),
133-
similar_type(A, T, $(Size(mR,nR)))(tuple($(R[1:mR,1:nR]...)))
134+
@inbounds return similar_type(A, T, $(Size(mQ, nQ)))( tuple($(Q[1:mQ, 1:nQ]...)) ),
135+
similar_type(A, T, $(Size(mR, nR)))( tuple($(R[1:mR, 1:nR]...)) )
134136
end
135137

136138
end
137139

138-
## source for @generated function above
139-
## derived from base/linalg/qr.jl
140-
## thin version of QR
141-
#function qr_householder_unrolled(A::StaticMatrix{<:Any, <:Any, TA}) where {TA}
140+
141+
## Source for @generated qr_unrolled() function above.
142+
## Derived from base/linalg/qr.jl
143+
## thin=true version of QR
144+
#function qr_unrolled(A::StaticMatrix{<:Any, <:Any, TA}) where {TA}
142145
# m, n = size(A)
143146
# T = arithmetic_closure(TA)
144147
# Q = eye(MMatrix{m,m,T,m*m})

test/qr.jl

Lines changed: 44 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,17 @@
1+
using StaticArrays, Base.Test
2+
3+
Base.randn(::Type{BigFloat}) = BigFloat(randn(Float64))
4+
Base.randn(::Type{BigFloat}, I::Integer) = [randn(BigFloat) for i=1:I]
5+
Base.randn(::Type{Int}) = rand(-9:9)
6+
Base.randn(::Type{Int}, I::Integer) = [randn(Int) for i=1:I]
7+
Base.randn(::Type{Complex{T}}) where T = Complex{T}(randn(T,2)...)
8+
Base.randn(::Type{Complex}) = randn(Complex{Float64})
9+
10+
srand(42)
111
@testset "QR decomposition" begin
212
function test_qr(arr)
13+
14+
# thin=true case
315
QR = @inferred qr(arr)
416
@test QR isa Tuple
517
@test length(QR) == 2
@@ -13,11 +25,26 @@
1325
@test Q*R arr
1426
@test Q'*Q eye(Q'*Q)
1527
@test istriu(R)
16-
@test abs.(Q_ref'*Q) eye(Q_ref'*Q)
17-
@test Q_ref'*Q * R R_ref
1828

29+
# fat (thin=false) case
30+
QR = @inferred qr(Size(arr), arr, Val{false}, Val{false})
31+
@test QR isa Tuple
32+
@test length(QR) == 2
33+
Q, R = QR
34+
@test Q isa StaticMatrix
35+
@test R isa StaticMatrix
36+
37+
Q_ref,R_ref = qr(Matrix(arr), thin=false)
38+
@test abs.(Q) abs.(Q_ref) # QR is unique up to diag(Q) signs
39+
@test abs.(R) abs.(R_ref)
40+
R0 = vcat(R, @SMatrix(zeros(size(arr)[1]-size(R)[1], size(R)[2])) )
41+
@test Q*R0 arr
42+
@test Q'*Q eye(Q'*Q)
43+
@test istriu(R)
44+
45+
# pivot=true cases are not released yet
1946
pivot = Val{true}
20-
QRp = @inferred qr(arr,pivot)
47+
QRp = @inferred qr(arr, pivot)
2148
@test QRp isa Tuple
2249
@test length(QRp) == 3
2350
Q, R, p = QRp
@@ -33,20 +60,21 @@
3360

3461
@test_throws ArgumentError qr(@SMatrix randn(1,2); thin=false)
3562

63+
for eltya in (Float32, Float64, BigFloat, Int),
64+
rel in (real, complex),
65+
sz in [(3,3), (3,4), (4,3)]
66+
arr = SMatrix{sz[1], sz[2], rel(eltya), sz[1]*sz[2]}( [randn(rel(eltya)) for i = 1:sz[1], j = 1:sz[2]] )
67+
test_qr(arr)
68+
end
69+
# some special cases
3670
for arr in [
37-
(@MMatrix randn(2,2)),
38-
[(@SMatrix randn(i, j)) for i=3:3 for j=max(i-1,1):i+1]...,
39-
(@SMatrix randn(10, 12)),
40-
(@SMatrix([0 1 2; 0 2 3; 0 3 4; 0 4 5])),
41-
(@SMatrix zeros(Int,5,5)),
42-
map(Complex128, @SMatrix randn(2,3)),
43-
map(Complex128, @SMatrix randn(3,2)),
44-
map(BigFloat, @SMatrix randn(1,1)),
45-
map(Complex{BigFloat}, @MMatrix randn(2,3)),
46-
map(Complex{BigFloat}, @SMatrix randn(3,2)),
47-
(@SMatrix randn(19, 2)),
48-
(@SMatrix randn(2, 19))
49-
]
71+
(@MMatrix randn(3,2)),
72+
(@MMatrix randn(2,3)),
73+
(@SMatrix([0 1 2; 0 2 3; 0 3 4; 0 4 5])),
74+
(@SMatrix zeros(Int,4,4)),
75+
(@SMatrix randn(17,18)), # fallback to LAPACK
76+
(@SMatrix randn(18,17))
77+
]
5078
test_qr(arr)
5179
end
5280
end

0 commit comments

Comments
 (0)