Skip to content

Commit 1f8eb0d

Browse files
Merge pull request #505 from j-fu/rework_iparm_setting
Rework iparm setting
2 parents 270b56d + a339509 commit 1f8eb0d

File tree

3 files changed

+93
-42
lines changed

3 files changed

+93
-42
lines changed

ext/LinearSolvePardisoExt.jl

Lines changed: 41 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -22,23 +22,28 @@ function LinearSolve.init_cacheval(alg::PardisoJL,
2222
reltol,
2323
verbose::Bool,
2424
assumptions::LinearSolve.OperatorAssumptions)
25-
@unpack nprocs, solver_type, matrix_type, iparm, dparm = alg
25+
@unpack nprocs, solver_type, matrix_type, cache_analysis, iparm, dparm = alg
2626
A = convert(AbstractMatrix, A)
2727

28+
transposed_iparm = 1
2829
solver = if Pardiso.PARDISO_LOADED[]
2930
solver = Pardiso.PardisoSolver()
31+
Pardiso.pardisoinit(solver)
3032
solver_type !== nothing && Pardiso.set_solver!(solver, solver_type)
3133

3234
solver
3335
else
3436
solver = Pardiso.MKLPardisoSolver()
37+
Pardiso.pardisoinit(solver)
3538
nprocs !== nothing && Pardiso.set_nprocs!(solver, nprocs)
3639

40+
# for mkl 1 means conjugated and 2 means transposed.
41+
# https://www.intel.com/content/www/us/en/docs/onemkl/developer-reference-c/2024-0/pardiso-iparm-parameter.html#IPARM37
42+
transposed_iparm = 2
43+
3744
solver
3845
end
3946

40-
Pardiso.pardisoinit(solver) # default initialization
41-
4247
if matrix_type !== nothing
4348
Pardiso.set_matrixtype!(solver, matrix_type)
4449
else
@@ -52,22 +57,6 @@ function LinearSolve.init_cacheval(alg::PardisoJL,
5257
end
5358
verbose && Pardiso.set_msglvl!(solver, Pardiso.MESSAGE_LEVEL_ON)
5459

55-
# pass in vector of tuples like [(iparm::Int, key::Int) ...]
56-
if iparm !== nothing
57-
for i in iparm
58-
Pardiso.set_iparm!(solver, i...)
59-
end
60-
end
61-
62-
if dparm !== nothing
63-
for d in dparm
64-
Pardiso.set_dparm!(solver, d...)
65-
end
66-
end
67-
68-
# Make sure to say it's transposed because its CSC not CSR
69-
Pardiso.set_iparm!(solver, 12, 1)
70-
7160
#=
7261
Note: It is recommended to use IPARM(11)=1 (scaling) and IPARM(13)=1 (matchings) for
7362
highly indefinite symmetric matrices e.g. from interior point optimizations or saddle point problems.
@@ -79,23 +68,44 @@ function LinearSolve.init_cacheval(alg::PardisoJL,
7968
be changed to Pardiso.ANALYSIS_NUM_FACT in the solver loop otherwise instabilities
8069
occur in the example https://github.com/SciML/OrdinaryDiffEq.jl/issues/1569
8170
=#
82-
Pardiso.set_iparm!(solver, 11, 0)
83-
Pardiso.set_iparm!(solver, 13, 0)
84-
85-
Pardiso.set_phase!(solver, Pardiso.ANALYSIS)
71+
if cache_analysis
72+
Pardiso.set_iparm!(solver, 11, 0)
73+
Pardiso.set_iparm!(solver, 13, 0)
74+
end
8675

8776
if alg.solver_type == 1
8877
# PARDISO uses a numerical factorization A = LU for the first system and
8978
# applies these exact factors L and U for the next steps in a
9079
# preconditioned Krylov-Subspace iteration. If the iteration does not
9180
# converge, the solver will automatically switch back to the numerical factorization.
92-
Pardiso.set_iparm!(solver, 3, round(Int, abs(log10(reltol)), RoundDown) * 10 + 1)
81+
# Be aware that in the intel docs, iparm indexes are one lower.
82+
Pardiso.set_iparm!(solver, 4, round(Int, abs(log10(reltol)), RoundDown) * 10 + 1)
9383
end
9484

95-
Pardiso.pardiso(solver,
96-
u,
97-
SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)),
98-
b)
85+
# pass in vector of tuples like [(iparm::Int, key::Int) ...]
86+
if iparm !== nothing
87+
for i in iparm
88+
Pardiso.set_iparm!(solver, i...)
89+
end
90+
end
91+
92+
if dparm !== nothing
93+
for d in dparm
94+
Pardiso.set_dparm!(solver, d...)
95+
end
96+
end
97+
98+
# Make sure to say it's transposed because its CSC not CSR
99+
# This is also the only value which should not be overwritten by users
100+
Pardiso.set_iparm!(solver, 12, transposed_iparm)
101+
102+
if cache_analysis
103+
Pardiso.set_phase!(solver, Pardiso.ANALYSIS)
104+
Pardiso.pardiso(solver,
105+
u,
106+
SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)),
107+
b)
108+
end
99109

100110
return solver
101111
end
@@ -105,13 +115,14 @@ function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::PardisoJL; kwargs
105115
A = convert(AbstractMatrix, A)
106116

107117
if cache.isfresh
108-
Pardiso.set_phase!(cache.cacheval, Pardiso.NUM_FACT)
118+
phase = alg.cache_analysis ? Pardiso.NUM_FACT : Pardiso.ANALYSIS_NUM_FACT
119+
Pardiso.set_phase!(cache.cacheval, phase)
109120
Pardiso.pardiso(cache.cacheval, A, eltype(A)[])
110121
cache.isfresh = false
111122
end
112-
113123
Pardiso.set_phase!(cache.cacheval, Pardiso.SOLVE_ITERATIVE_REFINE)
114124
Pardiso.pardiso(cache.cacheval, u, A, b)
125+
115126
return SciMLBase.build_linear_solution(alg, cache.u, nothing, cache)
116127
end
117128

src/extension_algs.jl

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ end
8686
```julia
8787
MKLPardisoFactorize(; nprocs::Union{Int, Nothing} = nothing,
8888
matrix_type = nothing,
89+
cache_analysis = false,
8990
iparm::Union{Vector{Tuple{Int, Int}}, Nothing} = nothing,
9091
dparm::Union{Vector{Tuple{Int, Int}}, Nothing} = nothing)
9192
```
@@ -98,7 +99,11 @@ A sparse factorization method using MKL Pardiso.
9899
99100
## Keyword Arguments
100101
101-
For the definition of the keyword arguments, see the Pardiso.jl documentation.
102+
Setting `cache_analysis = true` disables Pardiso's scaling and matching defaults
103+
and caches the result of the initial analysis phase for all further computations
104+
with this solver.
105+
106+
For the definition of the other keyword arguments, see the Pardiso.jl documentation.
102107
All values default to `nothing` and the solver internally determines the values
103108
given the input types, and these keyword arguments are only for overriding the
104109
default handling process. This should not be required by most users.
@@ -109,6 +114,7 @@ MKLPardisoFactorize(; kwargs...) = PardisoJL(; solver_type = 0, kwargs...)
109114
```julia
110115
MKLPardisoIterate(; nprocs::Union{Int, Nothing} = nothing,
111116
matrix_type = nothing,
117+
cache_analysis = false,
112118
iparm::Union{Vector{Tuple{Int, Int}}, Nothing} = nothing,
113119
dparm::Union{Vector{Tuple{Int, Int}}, Nothing} = nothing)
114120
```
@@ -121,7 +127,11 @@ A mixed factorization+iterative method using MKL Pardiso.
121127
122128
## Keyword Arguments
123129
124-
For the definition of the keyword arguments, see the Pardiso.jl documentation.
130+
Setting `cache_analysis = true` disables Pardiso's scaling and matching defaults
131+
and caches the result of the initial analysis phase for all further computations
132+
with this solver.
133+
134+
For the definition of the other keyword arguments, see the Pardiso.jl documentation.
125135
All values default to `nothing` and the solver internally determines the values
126136
given the input types, and these keyword arguments are only for overriding the
127137
default handling process. This should not be required by most users.
@@ -133,6 +143,7 @@ MKLPardisoIterate(; kwargs...) = PardisoJL(; solver_type = 1, kwargs...)
133143
PardisoJL(; nprocs::Union{Int, Nothing} = nothing,
134144
solver_type = nothing,
135145
matrix_type = nothing,
146+
cache_analysis = false,
136147
iparm::Union{Vector{Tuple{Int, Int}}, Nothing} = nothing,
137148
dparm::Union{Vector{Tuple{Int, Int}}, Nothing} = nothing)
138149
```
@@ -145,6 +156,10 @@ A generic method using MKL Pardiso. Specifying `solver_type` is required.
145156
146157
## Keyword Arguments
147158
159+
Setting `cache_analysis = true` disables Pardiso's scaling and matching defaults
160+
and caches the result of the initial analysis phase for all further computations
161+
with this solver.
162+
148163
For the definition of the keyword arguments, see the Pardiso.jl documentation.
149164
All values default to `nothing` and the solver internally determines the values
150165
given the input types, and these keyword arguments are only for overriding the
@@ -154,14 +169,16 @@ struct PardisoJL{T1, T2} <: LinearSolve.SciMLLinearSolveAlgorithm
154169
nprocs::Union{Int, Nothing}
155170
solver_type::T1
156171
matrix_type::T2
172+
cache_analysis::Bool
157173
iparm::Union{Vector{Tuple{Int, Int}}, Nothing}
158174
dparm::Union{Vector{Tuple{Int, Int}}, Nothing}
159175

160176
function PardisoJL(; nprocs::Union{Int, Nothing} = nothing,
161177
solver_type = nothing,
162178
matrix_type = nothing,
179+
cache_analysis = false,
163180
iparm::Union{Vector{Tuple{Int, Int}}, Nothing} = nothing,
164-
dparm::Union{Vector{Tuple{Int, Int}}, Nothing} = nothing)
181+
dparm::Union{Vector{Tuple{Int, Float64}}, Nothing} = nothing)
165182
ext = Base.get_extension(@__MODULE__, :LinearSolvePardisoExt)
166183
if ext === nothing
167184
error("PardisoJL requires that Pardiso is loaded, i.e. `using Pardiso`")
@@ -170,7 +187,8 @@ struct PardisoJL{T1, T2} <: LinearSolve.SciMLLinearSolveAlgorithm
170187
T2 = typeof(matrix_type)
171188
@assert T1 <: Union{Int, Nothing, ext.Pardiso.Solver}
172189
@assert T2 <: Union{Int, Nothing, ext.Pardiso.MatrixType}
173-
return new{T1, T2}(nprocs, solver_type, matrix_type, iparm, dparm)
190+
return new{T1, T2}(
191+
nprocs, solver_type, matrix_type, cache_analysis, iparm, dparm)
174192
end
175193
end
176194
end

test/pardiso/pardiso.jl

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using LinearSolve, SparseArrays, Random
1+
using LinearSolve, SparseArrays, Random, LinearAlgebra
22
import Pardiso
33

44
A1 = sparse([1.0 0 -2 3
@@ -14,19 +14,22 @@ e = ones(n)
1414
e2 = ones(n - 1)
1515
A2 = spdiagm(-1 => im * e2, 0 => lambda * e, 1 => -im * e2)
1616
b2 = rand(n) + im * zeros(n)
17-
cache_kwargs = (; verbose = true, abstol = 1e-8, reltol = 1e-8, maxiter = 30)
18-
1917
prob2 = LinearProblem(A2, b2)
2018

19+
cache_kwargs = (; abstol = 1e-8, reltol = 1e-8, maxiter = 30)
20+
2121
for alg in (PardisoJL(), MKLPardisoFactorize(), MKLPardisoIterate())
22-
u = solve(prob1, alg; cache_kwargs...).u
23-
@test A1 * u b1
22+
u = solve(prob1, alg; cache_kwargs...).u
23+
@test A1 * u b1
2424

25-
u = solve(prob2, alg; cache_kwargs...).u
26-
@test eltype(u) <: Complex
27-
@test_broken A2 * u b2
25+
u = solve(prob2, alg; cache_kwargs...).u
26+
@test eltype(u) <: Complex
27+
@test A2 * u b2
2828
end
2929

30+
return
31+
32+
3033
Random.seed!(10)
3134
A = sprand(n, n, 0.8);
3235
A2 = 2.0 .* A;
@@ -53,6 +56,25 @@ sol33 = solve(linsolve)
5356
@test sol12.u sol32.u
5457
@test sol13.u sol33.u
5558

59+
60+
# Test for problem from #497
61+
function makeA()
62+
n = 60
63+
colptr = [1, 4, 7, 11, 15, 17, 22, 26, 30, 34, 38, 40, 46, 50, 54, 58, 62, 64, 70, 74, 78, 82, 86, 88, 94, 98, 102, 106, 110, 112, 118, 122, 126, 130, 134, 136, 142, 146, 150, 154, 158, 160, 166, 170, 174, 178, 182, 184, 190, 194, 198, 202, 206, 208, 214, 218, 222, 224, 226, 228, 232]
64+
rowval = [1, 3, 4, 1, 2, 4, 2, 4, 9, 10, 3, 5, 11, 12, 1, 3, 2, 4, 6, 11, 12, 2, 7, 9, 10, 2, 7, 8, 10, 8, 10, 15, 16, 9, 11, 17, 18, 7, 9, 2, 8, 10, 12, 17, 18, 8, 13, 15, 16, 8, 13, 14, 16, 14, 16, 21, 22, 15, 17, 23, 24, 13, 15, 8, 14, 16, 18, 23, 24, 14, 19, 21, 22, 14, 19, 20, 22, 20, 22, 27, 28, 21, 23, 29, 30, 19, 21, 14, 20, 22, 24, 29, 30, 20, 25, 27, 28, 20, 25, 26, 28, 26, 28, 33, 34, 27, 29, 35, 36, 25, 27, 20, 26, 28, 30, 35, 36, 26, 31, 33, 34, 26, 31, 32, 34, 32, 34, 39, 40, 33, 35, 41, 42, 31, 33, 26, 32, 34, 36, 41, 42, 32, 37, 39, 40, 32, 37, 38, 40, 38, 40, 45, 46, 39, 41, 47, 48, 37, 39, 32, 38, 40, 42, 47, 48, 38, 43, 45, 46, 38, 43, 44, 46, 44, 46, 51, 52, 45, 47, 53, 54, 43, 45, 38, 44, 46, 48, 53, 54, 44, 49, 51, 52, 44, 49, 50, 52, 50, 52, 57, 58, 51, 53, 59, 60, 49, 51, 44, 50, 52, 54, 59, 60, 50, 55, 57, 58, 50, 55, 56, 58, 56, 58, 57, 59, 55, 57, 50, 56, 58, 60]
65+
nzval = [-0.64, 1.0, -1.0, 0.8606811145510832, -13.792569659442691, 1.0, 0.03475000000000006, 1.0, -0.03510101010101016, -0.975, -1.0806825309567203, 1.0, -0.95, -0.025, 2.370597639417811, -2.3705976394178108, -11.083604432603583, -0.2770901108150896, 1.0, -0.025, -0.95, -0.3564, -0.64, 1.0, -1.0, 13.792569659442691, 0.8606811145510832, -13.792569659442691, 1.0, 0.03475000000000006, 1.0, -0.03510101010101016, -0.975, -1.0806825309567203, 1.0, -0.95, -0.025, 2.370597639417811, -2.3705976394178108, 10.698449178570607, -11.083604432603583, -0.2770901108150896, 1.0, -0.025, -0.95, -0.3564, -0.64, 1.0, -1.0, 13.792569659442691, 0.8606811145510832, -13.792569659442691, 1.0, 0.03475000000000006, 1.0, -0.03510101010101016, -0.975, -1.0806825309567203, 1.0, -0.95, -0.025, 2.370597639417811, -2.3705976394178108, 10.698449178570607, -11.083604432603583, -0.2770901108150896, 1.0, -0.025, -0.95, -0.3564, -0.64, 1.0, -1.0, 13.792569659442691, 0.8606811145510832, -13.792569659442691, 1.0, 0.03475000000000006, 1.0, -0.03510101010101016, -0.975, -1.0806825309567203, 1.0, -0.95, -0.025, 2.370597639417811, -2.3705976394178108, 10.698449178570607, -11.083604432603583, -0.2770901108150896, 1.0, -0.025, -0.95, -0.3564, -0.64, 1.0, -1.0, 13.792569659442691, 0.8606811145510832, -13.792569659442691, 1.0, 0.03475000000000006, 1.0, -0.03510101010101016, -0.975, -1.0806825309567203, 1.0, -0.95, -0.025, 2.370597639417811, -2.3705976394178108, 10.698449178570607, -11.083604432603583, -0.2770901108150896, 1.0, -0.025, -0.95, -0.3564, -0.64, 1.0, -1.0, 13.792569659442691, 0.8606811145510832, -13.792569659442691, 1.0, 0.03475000000000006, 1.0, -0.03510101010101016, -0.975, -1.0806825309567203, 1.0, -0.95, -0.025, 2.370597639417811, -2.3705976394178108, 10.698449178570607, -11.083604432603583, -0.2770901108150896, 1.0, -0.025, -0.95, -0.3564, -0.64, 1.0, -1.0, 13.792569659442691, 0.8606811145510832, -13.792569659442691, 1.0, 0.03475000000000006, 1.0, -0.03510101010101016, -0.975, -1.0806825309567203, 1.0, -0.95, -0.025, 2.370597639417811, -2.3705976394178108, 10.698449178570607, -11.083604432603583, -0.2770901108150896, 1.0, -0.025, -0.95, -0.3564, -0.64, 1.0, -1.0, 13.792569659442691, 0.8606811145510832, -13.792569659442691, 1.0, 0.03475000000000006, 1.0, -0.03510101010101016, -0.975, -1.0806825309567203, 1.0, -0.95, -0.025, 2.370597639417811, -2.3705976394178108, 10.698449178570607, -11.083604432603583, -0.2770901108150896, 1.0, -0.025, -0.95, -0.3564, -0.64, 1.0, -1.0, 13.792569659442691, 0.8606811145510832, -13.792569659442691, 1.0, 0.03475000000000006, 1.0, -0.03510101010101016, -0.975, -1.0806825309567203, 1.0, -0.95, -0.025, 2.370597639417811, -2.3705976394178108, 10.698449178570607, -11.083604432603583, -0.2770901108150896, 1.0, -0.025, -0.95, -0.3564, -0.64, 1.0, -1.0, 13.792569659442691, 0.8606811145510832, -13.792569659442691, 1.0, 0.03475000000000006, 1.0, -1.0806825309567203, 1.0, 2.370597639417811, -2.3705976394178108, 10.698449178570607, -11.083604432603583, -0.2770901108150896, 1.0]
66+
A = SparseMatrixCSC(n, n, colptr, rowval, nzval)
67+
return(A)
68+
end
69+
70+
A=makeA()
71+
u0=fill(0.1,size(A,2))
72+
linprob = LinearProblem(A, A*u0)
73+
u = LinearSolve.solve(linprob, PardisoJL())
74+
@test norm(u-u0) < 1.0e-14
75+
76+
77+
5678
# Testing and demonstrating Pardiso.set_iparm! for MKLPardisoSolver
5779
solver = Pardiso.MKLPardisoSolver()
5880
iparm = [

0 commit comments

Comments
 (0)