Skip to content

Commit d315d13

Browse files
match the v1.13 format and setup different versions for generic_lufact
1 parent 7e8c2fc commit d315d13

File tree

2 files changed

+120
-47
lines changed

2 files changed

+120
-47
lines changed

src/factorization.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::GenericLUFactoriz
192192
if length(ipiv) != min(size(A)...)
193193
ipiv = Vector{LinearAlgebra.BlasInt}(undef, min(size(A)...))
194194
end
195-
fact = generic_lufact!(A, alg.pivot; check = false, ipiv)
195+
fact = generic_lufact!(A, alg.pivot, ipiv; check = false)
196196
cache.cacheval = (fact, ipiv)
197197

198198
if !LinearAlgebra.issuccess(fact)
@@ -221,6 +221,7 @@ function init_cacheval(alg::Union{LUFactorization, GenericLUFactorization},
221221
return lu(A; check = false)
222222
else
223223
A isa GPUArraysCore.AnyGPUArray && return nothing
224+
ipiv = Vector{LinearAlgebra.BlasInt}(undef, 0)
224225
return LinearAlgebra.generic_lufact!(copy(A), alg.pivot; check = false)
225226
end
226227
end

src/generic_lufact.jl

Lines changed: 118 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,62 +1,134 @@
11
# From LinearAlgebra.lu.jl
22
# Modified to be non-allocating
3-
function generic_lufact!(A::AbstractMatrix{T}, pivot::Union{RowMaximum,NoPivot,RowNonZero} = lupivottype(T);
4-
check::Bool = true, ipiv = Vector{BlasInt}(undef, min(size(A)...))) where {T}
5-
check && LinearAlgebra.LAPACK.chkfinite(A)
6-
# Extract values
7-
m, n = size(A)
8-
minmn = min(m,n)
3+
@static if VERSION < v"1.11"
4+
function generic_lufact!(A::AbstractMatrix{T}, pivot::Union{RowMaximum,NoPivot,RowNonZero} = LinearAlgebra.lupivottype(T),
5+
ipiv = Vector{LinearAlgebra.BlasInt}(undef, min(size(A)...));
6+
check::Bool = true, allowsingular::Bool = false) where {T}
7+
check && LinearAlgebra.LAPACK.chkfinite(A)
8+
# Extract values
9+
m, n = size(A)
10+
minmn = min(m,n)
911

10-
# Initialize variables
11-
info = 0
12-
13-
@inbounds begin
14-
for k = 1:minmn
15-
# find index max
16-
kp = k
17-
if pivot === LinearAlgebra.RowMaximum() && k < m
18-
amax = abs(A[k, k])
19-
for i = k+1:m
20-
absi = abs(A[i,k])
21-
if absi > amax
22-
kp = i
23-
amax = absi
12+
# Initialize variables
13+
info = 0
14+
15+
@inbounds begin
16+
for k = 1:minmn
17+
# find index max
18+
kp = k
19+
if pivot === LinearAlgebra.RowMaximum() && k < m
20+
amax = abs(A[k, k])
21+
for i = k+1:m
22+
absi = abs(A[i,k])
23+
if absi > amax
24+
kp = i
25+
amax = absi
26+
end
2427
end
28+
elseif pivot === LinearAlgebra.RowNonZero()
29+
for i = k:m
30+
if !iszero(A[i,k])
31+
kp = i
32+
break
33+
end
34+
end
35+
end
36+
ipiv[k] = kp
37+
if !iszero(A[kp,k])
38+
if k != kp
39+
# Interchange
40+
for i = 1:n
41+
tmp = A[k,i]
42+
A[k,i] = A[kp,i]
43+
A[kp,i] = tmp
44+
end
45+
end
46+
# Scale first column
47+
Akkinv = inv(A[k,k])
48+
for i = k+1:m
49+
A[i,k] *= Akkinv
50+
end
51+
elseif info == 0
52+
info = k
2553
end
26-
elseif pivot === LinearAlgebra.RowNonZero()
27-
for i = k:m
28-
if !iszero(A[i,k])
29-
kp = i
30-
break
54+
# Update the rest
55+
for j = k+1:n
56+
for i = k+1:m
57+
A[i,j] -= A[i,k]*A[k,j]
3158
end
3259
end
3360
end
34-
ipiv[k] = kp
35-
if !iszero(A[kp,k])
36-
if k != kp
37-
# Interchange
38-
for i = 1:n
39-
tmp = A[k,i]
40-
A[k,i] = A[kp,i]
41-
A[kp,i] = tmp
61+
end
62+
check && LinearAlgebra.checknonsingular(info, pivot)
63+
return LinearAlgebra.LU{T,typeof(A),typeof(ipiv)}(A, ipiv, convert(LinearAlgebra.BlasInt, info))
64+
end
65+
elseif VERSION < v"1.13"
66+
function generic_lufact!(A::AbstractMatrix{T}, pivot::Union{RowMaximum,NoPivot,RowNonZero} = LinearAlgebra.lupivottype(T),
67+
ipiv = Vector{LinearAlgebra.BlasInt}(undef, min(size(A)...));
68+
check::Bool = true, allowsingular::Bool = false) where {T}
69+
check && LAPACK.chkfinite(A)
70+
# Extract values
71+
m, n = size(A)
72+
minmn = min(m,n)
73+
74+
# Initialize variables
75+
info = 0
76+
77+
@inbounds begin
78+
for k = 1:minmn
79+
# find index max
80+
kp = k
81+
if pivot === LinearAlgebra.RowMaximum() && k < m
82+
amax = abs(A[k, k])
83+
for i = k+1:m
84+
absi = abs(A[i,k])
85+
if absi > amax
86+
kp = i
87+
amax = absi
88+
end
89+
end
90+
elseif pivot === LinearAlgebra.RowNonZero()
91+
for i = k:m
92+
if !iszero(A[i,k])
93+
kp = i
94+
break
95+
end
4296
end
4397
end
44-
# Scale first column
45-
Akkinv = inv(A[k,k])
46-
for i = k+1:m
47-
A[i,k] *= Akkinv
98+
ipiv[k] = kp
99+
if !iszero(A[kp,k])
100+
if k != kp
101+
# Interchange
102+
for i = 1:n
103+
tmp = A[k,i]
104+
A[k,i] = A[kp,i]
105+
A[kp,i] = tmp
106+
end
107+
end
108+
# Scale first column
109+
Akkinv = inv(A[k,k])
110+
for i = k+1:m
111+
A[i,k] *= Akkinv
112+
end
113+
elseif info == 0
114+
info = k
48115
end
49-
elseif info == 0
50-
info = k
51-
end
52-
# Update the rest
53-
for j = k+1:n
54-
for i = k+1:m
55-
A[i,j] -= A[i,k]*A[k,j]
116+
# Update the rest
117+
for j = k+1:n
118+
for i = k+1:m
119+
A[i,j] -= A[i,k]*A[k,j]
120+
end
56121
end
57122
end
58123
end
124+
if pivot === LinearAlgebra.NoPivot()
125+
# Use a negative value to distinguish a failed factorization (zero in pivot
126+
# position during unpivoted LU) from a valid but rank-deficient factorization
127+
info = -info
128+
end
129+
check && LinearAlgebra._check_lu_success(info, allowsingular)
130+
return LinearAlgebra.LU{T,typeof(A),typeof(ipiv)}(A, ipiv, convert(LinearAlgebra.BlasInt, info))
59131
end
60-
check && LinearAlgebra.checknonsingular(info, pivot)
61-
return LinearAlgebra.LU{T,typeof(A),typeof(ipiv)}(A, ipiv, convert(LinearAlgebra.BlasInt, info))
132+
else
133+
generic_lufact!(args...; kwargs...) = LinearAlgebra.generic_lufact!(args...; kwargs...)
62134
end

0 commit comments

Comments
 (0)