Skip to content

Commit 10b9b4c

Browse files
committed
Add QR decomposition without pivoting
1 parent 715fefe commit 10b9b4c

File tree

2 files changed

+211
-19
lines changed

2 files changed

+211
-19
lines changed

src/qr.jl

Lines changed: 193 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,200 @@ _thin_must_hold(thin) =
22
thin || throw(ArgumentError("For the sake of type stability, `thin = true` must hold."))
33
import Base.qr
44

5-
function qr(A::StaticMatrix, pivot::Type{Val{true}}; thin::Bool=true)
5+
6+
@inline function qr(A::StaticMatrix, pivot::Type{Val{true}}; thin::Bool=true)
67
_thin_must_hold(thin)
7-
Q0, R0, p0 = Base.qr(Matrix(A), pivot)
8-
T = arithmetic_closure(eltype(A))
9-
QT = similar_type(A, T, Size(diagsize(A), diagsize(A)))
10-
RT = similar_type(A, T)
11-
PT = similar_type(A, Int, Size(Size(A)[2]))
12-
QT(Q0), RT(R0), PT(p0)
8+
return qr(Size(A), A, pivot, Val{true})
9+
end
10+
11+
@generated function qr(SA::Size{sA}, A::StaticMatrix{<:Any, <:Any, TA}, pivot::Type{Val{true}}, thin::Union{Type{Val{false}}, Type{Val{true}}}) where {sA, TA}
12+
mQ = nQ = mR = sA[1]
13+
nR = sA[2]
14+
sA[1] > sA[2] && (mR = sA[2])
15+
sA[1] > sA[2] && thin <: Type{Val{true}} && (nQ = sA[2])
16+
T = arithmetic_closure(TA)
17+
QT = similar_type(A, T, Size(mQ, nQ))
18+
RT = similar_type(A, T, Size(mR, nR))
19+
PT = similar_type(A, Int, Size(sA[2]))
20+
return quote
21+
@_inline_meta
22+
Q0, R0, p0 = Base.qr(Matrix(A), pivot)
23+
return $QT(Q0), $RT(R0), $PT(p0)
24+
end
1325
end
14-
function qr(A::StaticMatrix, pivot::Type{Val{false}}; thin::Bool=true)
26+
27+
28+
@inline function qr(A::StaticMatrix, pivot::Type{Val{false}}; thin::Bool=true)
1529
_thin_must_hold(thin)
16-
Q0, R0 = Base.qr(Matrix(A), pivot)
17-
T = arithmetic_closure(eltype(A))
18-
QT = similar_type(A, T, Size(diagsize(A), diagsize(A)))
19-
RT = similar_type(A, T)
20-
QT(Q0), RT(R0)
30+
return qr(Size(A), A, pivot, Val{true})
31+
end
32+
33+
@generated function qr(SA::Size{sA}, A::StaticMatrix{<:Any, <:Any, TA}, pivot::Type{Val{false}}, thin::Union{Type{Val{false}}, Type{Val{true}}}) where {sA, TA}
34+
if sA[1] < 17 && sA[2] < 17
35+
return quote
36+
@_inline_meta
37+
return qr_householder_unrolled(SA, A, thin)
38+
end
39+
else
40+
mQ = nQ = mR = sA[1]
41+
nR = sA[2]
42+
sA[1] > sA[2] && (mR = sA[2])
43+
sA[1] > sA[2] && thin <: Type{Val{true}} && (nQ = sA[2])
44+
T = arithmetic_closure(TA)
45+
QT = similar_type(A, T, Size(mQ, nQ))
46+
RT = similar_type(A, T, Size(mR, nR))
47+
return quote
48+
@_inline_meta
49+
Q0, R0 = Base.qr(Matrix(A), pivot)
50+
return $QT(Q0), $RT(R0)
51+
end
52+
end
53+
end
54+
55+
56+
# Compute the QR decomposition of `A` such that `A = Q*R`
57+
# by Householder reflections without pivoting.
58+
#
59+
# `thin=true` (reduced) method will produce `Q` and `R` in truncated form,
60+
# in the case of `thin=false` Q is full, but R is still reduced, see [`qr`](@ref).
61+
#
62+
# For original source code see below.
63+
@generated function qr_householder_unrolled(::Size{sA}, A::StaticMatrix{<:Any, <:Any, TA}, thin::Union{Type{Val{false}},Type{Val{true}}}) where {sA, TA}
64+
mQ = nQ = mR = m = sA[1]
65+
nR = n = sA[2]
66+
# truncate Q and R for thin case
67+
m > n && (mR = n)
68+
m > n && thin <: Type{Val{true}} && (nQ = n)
69+
70+
Q = [Symbol("Q_$(i)_$(j)") for i = 1:m, j = 1:m]
71+
R = [Symbol("R_$(i)_$(j)") for i = 1:m, j = 1:n]
72+
73+
initQ = [:($(Q[i, j]) = $(i == j ? one : zero)(T)) for i = 1:m, j = 1:m] # Q .= eye(A)
74+
initR = [:($(R[i, j]) = T(A[$i, $j])) for i = 1:m, j = 1:n] # R .= A
75+
76+
code = quote end
77+
for k = 1:min(m - 1 + !(TA<:Real), n)
78+
#x = view(R, k:m, k)
79+
#τk = reflector!(x)
80+
push!(code.args, :(ξ1 = $(R[k, k])))
81+
ex = :(normu = abs2(ξ1))
82+
for i = k+1:m
83+
ex = :($ex + abs2($(R[i, k])))
84+
end
85+
push!(code.args, :(normu = sqrt($ex)))
86+
push!(code.args, :(ν = copysign(normu, real(ξ1))))
87+
push!(code.args, :(ξ1 += ν))
88+
push!(code.args, :(invξ1 = ξ1 == zero(T) ? zero(T) : inv(ξ1)))
89+
push!(code.args, :($(R[k, k]) = -ν))
90+
for i = k+1:m
91+
push!(code.args, :($(R[i, k]) *= invξ1))
92+
end
93+
push!(code.args, :(τk = ν == zero(T) ? zero(T) : ξ1/ν))
94+
95+
#reflectorApply!(x, τk, view(R, k:m, k+1:n))
96+
for j = k+1:n
97+
ex = :($(R[k, j]))
98+
for i = k+1:m
99+
ex = :($ex + $(R[i, k])'*$(R[i, j]))
100+
end
101+
push!(code.args, :(vRj = τk'*$ex))
102+
push!(code.args, :($(R[k, j]) -= vRj))
103+
for i = k+1:m
104+
push!(code.args, :($(R[i, j]) -= $(R[i, k])*vRj))
105+
end
106+
end
107+
108+
#reflectorApplyRight!(x, τk, view(Q, 1:m, k:m))
109+
for i = 1:m
110+
ex = :($(Q[i, k]))
111+
for j = k+1:m
112+
ex = :($ex + $(Q[i, j])*$(R[j, k]))
113+
end
114+
push!(code.args, :(Qiv = $ex*τk))
115+
push!(code.args, :($(Q[i, k]) -= Qiv))
116+
for j = k+1:m
117+
push!(code.args, :($(Q[i, j]) -= Qiv*$(R[j, k])'))
118+
end
119+
end
120+
121+
for i = k+1:m
122+
push!(code.args, :($(R[i, k]) = zero(T)))
123+
end
124+
end
125+
126+
return quote
127+
@_inline_meta
128+
T = arithmetic_closure(TA)
129+
@inbounds $(Expr(:block, initQ...))
130+
@inbounds $(Expr(:block, initR...))
131+
@inbounds $code
132+
@inbounds return similar_type(A, T, $(Size(mQ,nQ)))(tuple($(Q[1:mQ,1:nQ]...))),
133+
similar_type(A, T, $(Size(mR,nR)))(tuple($(R[1:mR,1:nR]...)))
134+
end
135+
21136
end
137+
138+
## source for @generated function above
139+
## derived from base/linalg/qr.jl
140+
## thin version of QR
141+
#function qr_householder_unrolled(A::StaticMatrix{<:Any, <:Any, TA}) where {TA}
142+
# m, n = size(A)
143+
# T = arithmetic_closure(TA)
144+
# Q = eye(MMatrix{m,m,T,m*m})
145+
# R = MMatrix{m,n,T,m*n}(A)
146+
# for k = 1:min(m - 1 + !(TA<:Real), n)
147+
# #x = view(R, k:m, k)
148+
# #τk = reflector!(x)
149+
# ξ1 = R[k, k]
150+
# normu = abs2(ξ1)
151+
# for i = k+1:m
152+
# normu += abs2(R[i, k])
153+
# end
154+
# normu = sqrt(normu)
155+
# ν = copysign(normu, real(ξ1))
156+
# ξ1 += ν
157+
# invξ1 = ξ1 == zero(T) ? zero(T) : inv(ξ1)
158+
# R[k, k] = -ν
159+
# for i = k+1:m
160+
# R[i, k] *= invξ1
161+
# end
162+
# τk = ν == zero(T) ? zero(T) : ξ1/ν
163+
#
164+
# #reflectorApply!(x, τk, view(R, k:m, k+1:n))
165+
# for j = k+1:n
166+
# vRj = R[k, j]
167+
# for i = k+1:m
168+
# vRj += R[i, k]'*R[i, j]
169+
# end
170+
# vRj = τk'*vRj
171+
# R[k, j] -= vRj
172+
# for i = k+1:m
173+
# R[i, j] -= R[i, k]*vRj
174+
# end
175+
# end
176+
#
177+
# #reflectorApplyRight!(x, τk, view(Q, 1:m, k:m))
178+
# for i = 1:m
179+
# Qiv = Q[i, k]
180+
# for j = k+1:m
181+
# Qiv += Q[i, j]*R[j, k]
182+
# end
183+
# Qiv = Qiv*τk
184+
# Q[i, k] -= Qiv
185+
# for j = k+1:m
186+
# Q[i, j] -= Qiv*R[j, k]'
187+
# end
188+
# end
189+
#
190+
# for i = k+1:m
191+
# R[i, k] = zero(T)
192+
# end
193+
#
194+
# end
195+
# if m > n
196+
# return (similar_type(A, T, Size(m, n))(Q[1:m,1:n]), similar_type(A, T, Size(n, n))(R[1:n,1:n]))
197+
# else
198+
# return (similar_type(A, T, Size(m, m))(Q), similar_type(A, T, Size(n, n))(R))
199+
# end
200+
#end
201+

test/qr.jl

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
@testset "qr" begin
1+
@testset "QR decomposition" begin
22
function test_qr(arr)
33
QR = @inferred qr(arr)
44
@test QR isa Tuple
@@ -8,14 +8,19 @@
88
@test R isa StaticMatrix
99

1010
Q_ref,R_ref = qr(Matrix(arr))
11-
@test Q Q_ref
12-
@test R R_ref
11+
@test abs.(Q) abs.(Q_ref) # QR is unique up to diag(Q) signs
12+
@test abs.(R) abs.(R_ref)
13+
@test Q*R arr
14+
@test Q'*Q eye(Q'*Q)
15+
@test istriu(R)
16+
@test abs.(Q_ref'*Q) eye(Q_ref'*Q)
17+
@test Q_ref'*Q * R R_ref
1318

1419
pivot = Val{true}
1520
QRp = @inferred qr(arr,pivot)
1621
@test QRp isa Tuple
1722
@test length(QRp) == 3
18-
Q, R, p= QRp
23+
Q, R, p = QRp
1924
@test Q isa StaticMatrix
2025
@test R isa StaticMatrix
2126
@test p isa StaticVector
@@ -30,10 +35,17 @@
3035

3136
for arr in [
3237
(@MMatrix randn(2,2)),
38+
[(@SMatrix randn(i, j)) for i=3:3 for j=max(i-1,1):i+1]...,
3339
(@SMatrix randn(10, 12)),
40+
(@SMatrix([0 1 2; 0 2 3; 0 3 4; 0 4 5])),
41+
(@SMatrix zeros(Int,5,5)),
42+
map(Complex128, @SMatrix randn(2,3)),
43+
map(Complex128, @SMatrix randn(3,2)),
3444
map(BigFloat, @SMatrix randn(1,1)),
35-
(@SMatrix zeros(Int,10,10)),
36-
map(Complex128, @MMatrix randn(2,3))
45+
map(Complex{BigFloat}, @MMatrix randn(2,3)),
46+
map(Complex{BigFloat}, @SMatrix randn(3,2)),
47+
(@SMatrix randn(19, 2)),
48+
(@SMatrix randn(2, 19))
3749
]
3850
test_qr(arr)
3951
end

0 commit comments

Comments
 (0)