Skip to content

Commit cd8b771

Browse files
Fix lu/qr pivot deprecations and allow PivotingStrategy input (#1002)
* Remove deprecation due to lu pivot choice change in v1.7 `lu` on larger static arrays now hits a deprecation in Base. * fix `Val` deprecations for `lu` & `qr` and allow new `PivotingStrategy` inputs * fix a test warning Co-authored-by: Christopher Rackauckas <Contact@ChrisRackauckas.com>
1 parent 650729d commit cd8b771

File tree

4 files changed

+54
-15
lines changed

4 files changed

+54
-15
lines changed

src/lu.jl

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,16 +30,21 @@ function Base.show(io::IO, mime::MIME{Symbol("text/plain")}, F::LU)
3030
end
3131

3232
# LU decomposition
33-
for pv in (:true, :false)
33+
pivot_options = if isdefined(LinearAlgebra, :PivotingStrategy) # introduced in Julia v1.7
34+
(:(Val{true}), :(Val{false}), :NoPivot, :RowMaximum)
35+
else
36+
(:(Val{true}), :(Val{false}))
37+
end
38+
for pv in pivot_options
3439
# ... define each `pivot::Val{true/false}` method individually to avoid ambiguties
35-
@eval function lu(A::StaticMatrix, pivot::Val{$pv}; check = true)
40+
@eval function lu(A::StaticMatrix, pivot::$pv; check = true)
3641
L, U, p = _lu(A, pivot, check)
3742
LU(L, U, p)
3843
end
3944

4045
# For the square version, return explicit lower and upper triangular matrices.
4146
# We would do this for the rectangular case too, but Base doesn't support that.
42-
@eval function lu(A::StaticMatrix{N,N}, pivot::Val{$pv}; check = true) where {N}
47+
@eval function lu(A::StaticMatrix{N,N}, pivot::$pv; check = true) where {N}
4348
L, U, p = _lu(A, pivot, check)
4449
LU(LowerTriangular(L), UpperTriangular(U), p)
4550
end
@@ -69,18 +74,28 @@ issuccess(F::LU) = _first_zero_on_diagonal(F.U) == 0
6974

7075
@generated function _lu(A::StaticMatrix{M,N,T}, pivot, check) where {M,N,T}
7176
if M*N 14*14
77+
_pivot = if isdefined(LinearAlgebra, :PivotingStrategy) # v1.7 feature
78+
pivot === RowMaximum ? Val(true) : pivot === NoPivot ? Val(false) : pivot()
79+
else
80+
pivot()
81+
end
7282
quote
73-
L, U, P = __lu(A, pivot)
83+
L, U, P = __lu(A, $(_pivot))
7484
if check
7585
i = _first_zero_on_diagonal(U)
7686
i == 0 || throw(SingularException(i))
7787
end
7888
L, U, P
7989
end
8090
else
91+
_pivot = if isdefined(LinearAlgebra, :PivotingStrategy) # v1.7 feature
92+
pivot === Val{true} ? RowMaximum() : pivot === Val{false} ? NoPivot() : pivot()
93+
else
94+
pivot()
95+
end
8196
quote
8297
# call through to Base to avoid excessive time spent on type inference for large matrices
83-
f = lu(Matrix(A), pivot; check = check)
98+
f = lu(Matrix(A), $(_pivot); check = check)
8499
# Trick to get the output eltype - can't rely on the result of f.L as
85100
# it's not type inferrable.
86101
T2 = arithmetic_closure(T)

src/qr.jl

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,14 @@ Base.iterate(S::QR, ::Val{:R}) = (S.R, Val(:p))
1111
Base.iterate(S::QR, ::Val{:p}) = (S.p, Val(:done))
1212
Base.iterate(S::QR, ::Val{:done}) = nothing
1313

14-
for pv in (:true, :false)
14+
pivot_options = if isdefined(LinearAlgebra, :PivotingStrategy) # introduced in Julia v1.7
15+
(:(Val{true}), :(Val{false}), :NoPivot, :ColumnNorm)
16+
else
17+
(:(Val{true}), :(Val{false}))
18+
end
19+
for pv in pivot_options
1520
@eval begin
16-
@inline function qr(A::StaticMatrix, pivot::Val{$pv})
21+
@inline function qr(A::StaticMatrix, pivot::$pv)
1722
QRp = _qr(Size(A), A, pivot)
1823
if length(QRp) === 2
1924
# create an identity permutation since that is cheap,
@@ -28,7 +33,8 @@ for pv in (:true, :false)
2833
end
2934
end
3035
"""
31-
qr(A::StaticMatrix, pivot::Union{Val{true}, Val{false}} = Val(false))
36+
qr(A::StaticMatrix,
37+
pivot::Union{Val{true}, Val{false}, LinearAlgebra.PivotingStrategy} = Val(false))
3238
3339
Compute the QR factorization of `A`. The factors can be obtained by iteration:
3440
@@ -58,16 +64,17 @@ end
5864

5965
_qreltype(::Type{T}) where T = typeof(zero(T)/sqrt(abs2(one(T))))
6066

61-
62-
@generated function _qr(::Size{sA}, A::StaticMatrix{<:Any, <:Any, TA}, pivot::Union{Val{false}, Val{true}} = Val(false)) where {sA, TA}
67+
@generated function _qr(::Size{sA}, A::StaticMatrix{<:Any, <:Any, TA},
68+
pivot = Val(false)) where {sA, TA}
6369

6470
SizeQ = Size( sA[1], diagsize(Size(A)) )
6571
SizeR = Size( diagsize(Size(A)), sA[2] )
6672

67-
if pivot === Val{true}
73+
if pivot === Val{true} || (isdefined(LinearAlgebra, :PivotingStrategy) && pivot === ColumnNorm)
74+
_pivot = isdefined(LinearAlgebra, :PivotingStrategy) ? ColumnNorm() : Val(true)
6875
return quote
6976
@_inline_meta
70-
Q0, R0, p0 = qr(Matrix(A), pivot)
77+
Q0, R0, p0 = qr(Matrix(A), $(_pivot))
7178
T = _qreltype(TA)
7279
return similar_type(A, T, $(SizeQ))(Matrix(Q0)),
7380
similar_type(A, T, $(SizeR))(R0),
@@ -77,12 +84,13 @@ _qreltype(::Type{T}) where T = typeof(zero(T)/sqrt(abs2(one(T))))
7784
if (sA[1]*sA[1] + sA[1]*sA[2])÷2 * diagsize(Size(A)) < 17*17*17
7885
return quote
7986
@_inline_meta
80-
return qr_unrolled(Size(A), A, pivot)
87+
return qr_unrolled(Size(A), A, Val(false))
8188
end
8289
else
90+
_pivot = isdefined(LinearAlgebra, :PivotingStrategy) ? NoPivot() : Val(false)
8391
return quote
8492
@_inline_meta
85-
Q0R0 = qr(Matrix(A), pivot)
93+
Q0R0 = qr(Matrix(A), $(_pivot))
8694
Q0, R0 = Matrix(Q0R0.Q), Q0R0.R
8795
T = _qreltype(TA)
8896
return similar_type(A, T, $(SizeQ))(Q0),

test/lu.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,4 +79,12 @@ end
7979
end
8080
end
8181

82+
if isdefined(LinearAlgebra, :PivotingStrategy)
83+
for N = (3, 15)
84+
A = (@SMatrix randn(N,N))
85+
@test lu(A, Val(false)) == lu(A, NoPivot())
86+
@test lu(A, Val(true)) == lu(A, RowMaximum())
87+
end
88+
end
89+
8290
end # @testset "LU"

test/qr.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ Random.seed!(42)
2727

2828
# pivot=true cases has no StaticArrays specific version yet
2929
# but fallbacks to LAPACK
30-
pivot = Val(true)
30+
pivot = isdefined(LinearAlgebra, :PivotingStrategy) ? ColumnNorm() : Val(true)
3131
QRp = @inferred qr(arr, pivot)
3232
@test QRp isa StaticArrays.QR
3333
Q, R, p = QRp
@@ -59,6 +59,14 @@ Random.seed!(42)
5959
]
6060
test_qr(arr)
6161
end
62+
63+
if isdefined(LinearAlgebra, :PivotingStrategy)
64+
for N = (3, 18)
65+
A = (@SMatrix randn(N,N))
66+
@test qr(A, Val(false)) == qr(A, NoPivot())
67+
@test qr(A, Val(true)) == qr(A, ColumnNorm())
68+
end
69+
end
6270
end
6371

6472
@testset "QR method ambiguity" begin

0 commit comments

Comments
 (0)