Skip to content

Commit bb73490

Browse files
Add column-pivoted QR factorization fallback on failed LU factorization
In action: ```julia using LinearSolve A = [1.0 0 0 0 0 1 0 0 0 0 1 0 0 0 0 0] b = rand(4) prob = LinearProblem(A, b) sol = solve(prob, LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.LUFactorization; safetyfallback=false)) @test sol.retcode === ReturnCode.Failure @test sol.u == zeros(4) sol = solve(prob) @test sol.u ≈ svd(A)\b ``` Previously it would just fail, now it falls back to the column-pivoted QR in order to successfully factorize when necessary.
1 parent 9944f34 commit bb73490

File tree

4 files changed

+67
-6
lines changed

4 files changed

+67
-6
lines changed

docs/src/solvers/solvers.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,12 @@ use `Krylov_GMRES()`.
8282

8383
## Full List of Methods
8484

85+
### Polyalgorithms
86+
87+
```@docs
88+
LinearSolve.DefaultLinearSolver
89+
```
90+
8591
### RecursiveFactorization.jl
8692

8793
!!! note

src/LinearSolve.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,23 @@ EnumX.@enumx DefaultAlgorithmChoice begin
119119
KrylovJL_LSMR
120120
end
121121

122+
"""
123+
DefaultLinearSolver(;safetyfallback=true)
124+
125+
The default linear solver. This is the algorithm chosen when `solve(prob)`
126+
is called. It's a polyalgorithm that detects the optimal method for a given
127+
`A, b` and hardware (Intel, AMD, GPU, etc.).
128+
129+
## Keyword Arguments
130+
131+
* `safetyfallback`: determines whether to fallback to a column-pivoted QR factorization
132+
when an LU factorization fails. This can be required if `A` is rank-deficient. Defaults
133+
to true.
134+
"""
122135
struct DefaultLinearSolver <: SciMLLinearSolveAlgorithm
123136
alg::DefaultAlgorithmChoice.T
137+
safetyfallback::Bool
138+
DefaultLinearSolver(alg; safetyfallback=true) = new(alg,safetyfallback)
124139
end
125140

126141
const BLASELTYPES = Union{Float32, Float64, ComplexF32, ComplexF64}

src/default.jl

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,11 @@ end
4141
ex = Expr(:if, ex.args...)
4242
end
4343

44+
# Handle special case of Column-pivoted QR fallback for LU
45+
function __setfield!(cache::DefaultLinearSolverInit, alg::DefaultLinearSolver, v::QRFactorization{ColumnNorm})
46+
47+
end
48+
4449
# Legacy fallback
4550
# For SciML algorithms already using `defaultalg`, all assume square matrix.
4651
defaultalg(A, b) = defaultalg(A, b, OperatorAssumptions(true))
@@ -352,11 +357,32 @@ end
352357
kwargs...)
353358
ex = :()
354359
for alg in first.(EnumX.symbol_map(DefaultAlgorithmChoice.T))
355-
newex = quote
356-
sol = SciMLBase.solve!(cache, $(algchoice_to_alg(alg)), args...; kwargs...)
357-
SciMLBase.build_linear_solution(alg, sol.u, sol.resid, sol.cache;
358-
retcode = sol.retcode,
359-
iters = sol.iters, stats = sol.stats)
360+
if alg in Symbol.((DefaultAlgorithmChoice.LUFactorization,
361+
DefaultAlgorithmChoice.RFLUFactorization,
362+
DefaultAlgorithmChoice.MKLLUFactorization,
363+
DefaultAlgorithmChoice.AppleAccelerateLUFactorization,
364+
DefaultAlgorithmChoice.GenericLUFactorization))
365+
newex = quote
366+
sol = SciMLBase.solve!(cache, $(algchoice_to_alg(alg)), args...; kwargs...)
367+
if sol.retcode === ReturnCode.Failure && alg.safetyfallback
368+
## TODO: Add verbosity logging here about using the fallback
369+
sol = SciMLBase.solve!(cache, QRFactorization(ColumnNorm()), args...; kwargs...)
370+
SciMLBase.build_linear_solution(alg, sol.u, sol.resid, sol.cache;
371+
retcode = sol.retcode,
372+
iters = sol.iters, stats = sol.stats)
373+
else
374+
SciMLBase.build_linear_solution(alg, sol.u, sol.resid, sol.cache;
375+
retcode = sol.retcode,
376+
iters = sol.iters, stats = sol.stats)
377+
end
378+
end
379+
else
380+
newex = quote
381+
sol = SciMLBase.solve!(cache, $(algchoice_to_alg(alg)), args...; kwargs...)
382+
SciMLBase.build_linear_solution(alg, sol.u, sol.resid, sol.cache;
383+
retcode = sol.retcode,
384+
iters = sol.iters, stats = sol.stats)
385+
end
360386
end
361387
alg_enum = getproperty(LinearSolve.DefaultAlgorithmChoice, alg)
362388
ex = if ex == :()

test/default_algs.jl

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,4 +158,18 @@ prob = LinearProblem(A, b)
158158
@test_broken SciMLBase.successful_retcode(solve(prob))
159159

160160
prob2 = LinearProblem(A2, b)
161-
@test SciMLBase.successful_retcode(solve(prob2))
161+
@test SciMLBase.successful_retcode(solve(prob2))
162+
163+
# Column-Pivoted QR fallback on failed LU
164+
A = [1.0 0 0 0
165+
0 1 0 0
166+
0 0 1 0
167+
0 0 0 0]
168+
b = rand(4)
169+
prob = LinearProblem(A, b)
170+
sol = solve(prob, LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.LUFactorization; safetyfallback=false))
171+
@test sol.retcode === ReturnCode.Failure
172+
@test sol.u == zeros(4)
173+
174+
sol = solve(prob)
175+
@test sol.u svd(A)\b

0 commit comments

Comments
 (0)