Skip to content

Commit 337e59d

Browse files
committed
Minor LU cleanup
1 parent e636bbb commit 337e59d

File tree

2 files changed

+12
-10
lines changed

2 files changed

+12
-10
lines changed

src/array/lu.jl

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ function LinearAlgebra.lu!(A::DMatrix{T}, ::LinearAlgebra.NoPivot; check::Bool=t
77
mzone = -one(T)
88
Ac = A.chunks
99
mt, nt = size(Ac)
10+
mb, nb = A.partitioning.blocksize
11+
12+
mb != nb && error("Unequal block sizes are not supported: mb = $mb, nb = $nb")
1013

1114
Dagger.spawn_datadeps() do
1215
for k in range(1, min(mt, nt))
@@ -39,7 +42,7 @@ function update_ipiv!(ipivl::AbstractVector{Int}, piv_idx::AbstractVector{Int},
3942
max_piv_val = piv_val[max_piv_idx]
4043
abs_max_piv_val = max_piv_val isa Real ? abs(max_piv_val) : abs(real(max_piv_val)) + abs(imag(max_piv_val))
4144
isapprox(abs_max_piv_val, zero(T); atol=eps(real(T))) && throw(LinearAlgebra.SingularException(k))
42-
ipivl[1] = (max_piv_idx+k-2)*nb + piv_idx[max_piv_idx]
45+
ipivl[1] = (max_piv_idx+k-2)*nb + piv_idx[max_piv_idx]
4346
end
4447

4548
function swaprows_panel!(A::AbstractMatrix{T}, M::AbstractMatrix{T}, ipivl::AbstractVector{Int}, m::Int, p::Int, nb::Int) where T
@@ -51,7 +54,7 @@ function swaprows_panel!(A::AbstractMatrix{T}, M::AbstractMatrix{T}, ipivl::Abst
5154
end
5255

5356
function update_panel!(M::AbstractMatrix{T}, A::AbstractMatrix{T}, p::Int) where T
54-
Acinv = one(T) / A[p,p]
57+
Acinv = one(T) / A[p,p]
5558
LinearAlgebra.BLAS.scal!(Acinv, view(M, :, p))
5659
LinearAlgebra.BLAS.ger!(-one(T), view(M, :, p), conj.(view(A, p, p+1:size(A,2))), view(M, :, p+1:size(M,2)))
5760
end
@@ -78,15 +81,15 @@ function LinearAlgebra.lu!(A::DMatrix{T}, ::LinearAlgebra.RowMaximum; check::Boo
7881
mt, nt = size(Ac)
7982
m, n = size(A)
8083
mb, nb = A.partitioning.blocksize
81-
84+
8285
mb != nb && error("Unequal block sizes are not supported: mb = $mb, nb = $nb")
8386

8487
ipiv = DVector(collect(1:min(m, n)), Blocks(mb))
8588
ipivc = ipiv.chunks
8689

8790
max_piv_idx = zeros(Int,mt)
8891
max_piv_val = zeros(T, mt)
89-
92+
9093
Dagger.spawn_datadeps() do
9194
for k in 1:min(mt, nt)
9295
for p in 1:min(nb, m-(k-1)*nb, n-(k-1)*nb)
@@ -96,17 +99,16 @@ function LinearAlgebra.lu!(A::DMatrix{T}, ::LinearAlgebra.RowMaximum; check::Boo
9699
end
97100
Dagger.@spawn update_ipiv!(InOut(view(ipivc[k],p:p)), In(view(max_piv_idx, k:mt)), In(view(max_piv_val, k:mt)), k, nb)
98101
for i in k:mt
99-
Dagger.@spawn swaprows_panel!(InOut(Ac[k, k]), InOut(Ac[i, k]), InOut(view(ipivc[k],p:p)), i, p, nb)
102+
Dagger.@spawn swaprows_panel!(InOut(Ac[k, k]), InOut(Ac[i, k]), In(view(ipivc[k],p:p)), i, p, nb)
100103
end
101104
Dagger.@spawn update_panel!(InOut(view(Ac[k,k],p+1:min(nb,m-(k-1)*nb),:)), In(Ac[k,k]), p)
102105
for i in k+1:mt
103106
Dagger.@spawn update_panel!(InOut(Ac[i, k]), In(Ac[k,k]), p)
104107
end
105-
106108
end
107109
for j in Iterators.flatten((1:k-1, k+1:nt))
108110
for i in k:mt
109-
Dagger.@spawn swaprows_trail!(InOut(Ac[k, j]), InOut(Ac[i, j]), In(ipivc[k]), i, mb)
111+
Dagger.@spawn swaprows_trail!(InOut(Ac[k, j]), InOut(Ac[i, j]), In(ipivc[k]), i, mb)
110112
end
111113
end
112114
for j in k+1:nt
@@ -118,5 +120,5 @@ function LinearAlgebra.lu!(A::DMatrix{T}, ::LinearAlgebra.RowMaximum; check::Boo
118120
end
119121
end
120122

121-
return LinearAlgebra.LU{T,DMatrix{T},DVector{Int}}(A, ipiv, 0)
122-
end
123+
return LinearAlgebra.LU{T,DMatrix{T},DVector{Int}}(A, ipiv, 0)
124+
end

test/array/linalg/lu.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
# Out-of-place
77
lu_A = lu(A, pivot)
88
lu_DA = lu(DA, pivot)
9-
@test lu_DA isa LU{T,DMatrix{T},DVector{Int}}
9+
@test lu_DA isa LU{T,DMatrix{T},DVector{Int}}
1010
if !(T in (Float32, ComplexF32)) && pivot == NoPivot() # FIXME: NoPivot is unstable for FP32
1111
@test lu_A.L lu_DA.L
1212
@test lu_A.U lu_DA.U

0 commit comments

Comments
 (0)