Skip to content

Implement RowMaximum Pivoting Strategy for Distributed LU Factorization #631

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 22 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
d508952
Implement RowMaximum Pivoting Strategy for Distributed LU Factorization
AkhilAkkapelli Jul 8, 2025
1242052
Remove debug print statements from pivoting functions in LU factoriza…
AkhilAkkapelli Jul 8, 2025
bb1620d
Refactor function signatures for improved clarity and consistency
AkhilAkkapelli Jul 9, 2025
11ccee7
Add singularity check for robust LU factorization
AkhilAkkapelli Jul 9, 2025
11042f9
Fix lu.jl
AkhilAkkapelli Jul 10, 2025
794e554
Fix precision issue in singularity check
AkhilAkkapelli Jul 10, 2025
afe96f7
Fix pivot search by replacing argmax with BLAS.iamax
AkhilAkkapelli Jul 10, 2025
e636bbb
Refactor LU tests for consistency and clarity in pivot checks
AkhilAkkapelli Jul 10, 2025
337e59d
Minor LU cleanup
jpsamaroo Jul 10, 2025
ffdf186
test/array/linalg/lu: Test for invalid block sizes
jpsamaroo Jul 10, 2025
67fe600
docs: Add LU RowMaximum to supported ops
jpsamaroo Jul 10, 2025
adf11fc
docs: Add SparseArrays sprand to supported ops
jpsamaroo Jul 10, 2025
330a9e3
test/array/linalg/lu: Fix skipped tests
jpsamaroo Jul 10, 2025
f06b594
Fix logical condition for NoPivot stability check
AkhilAkkapelli Jul 11, 2025
ffb03bb
Enhance LU decomposition functions to support allowsingular parameter…
AkhilAkkapelli Jul 11, 2025
c147134
Fix singular value test to use ones instead of rand for RowMaximum pivot
AkhilAkkapelli Jul 11, 2025
d116d51
Refactor LU decomposition functions to handle unequal block sizes and…
AkhilAkkapelli Jul 11, 2025
ce2480e
Update LU tests to support non-square block sizes and improve singula…
AkhilAkkapelli Jul 11, 2025
09b3882
Fix logical condition for singular value exception check in LU tests
AkhilAkkapelli Jul 11, 2025
d508e96
Fix comment for singular value exception check in LU tests
AkhilAkkapelli Jul 11, 2025
c717502
Fix logical condition for NoPivot check in LU tests
AkhilAkkapelli Jul 11, 2025
2d2e122
Fix argument passing for NoPivot in generic_lufact! calls in LU decom…
AkhilAkkapelli Jul 12, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion docs/src/darray.md
Original file line number Diff line number Diff line change
Expand Up @@ -684,6 +684,9 @@ From `Base`:
From `Random`:
- `rand!`/`randn!`

From `SparseArrays`:
- `sprand`

From `Statistics`:
- `mean`
- `var`
Expand All @@ -694,7 +697,7 @@ From `LinearAlgebra`:
- `*` (Out-of-place Matrix-(Matrix/Vector) multiply)
- `mul!` (In-place Matrix-Matrix multiply)
- `cholesky`/`cholesky!` (In-place/Out-of-place Cholesky factorization)
- `lu`/`lu!` (In-place/Out-of-place LU factorization (`NoPivot` only))
- `lu`/`lu!` (In-place/Out-of-place LU factorization (`NoPivot` and `RowMaximum` only))

From `AbstractFFTs`:
- `fft`/`fft!`
Expand Down
2 changes: 1 addition & 1 deletion src/Dagger.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import SparseArrays: sprand, SparseMatrixCSC
import MemPool
import MemPool: DRef, FileRef, poolget, poolset

import Base: collect, reduce
import Base: collect, reduce, view

import LinearAlgebra
import LinearAlgebra: Adjoint, BLAS, Diagonal, Bidiagonal, Tridiagonal, LAPACK, LowerTriangular, PosDefException, Transpose, UpperTriangular, UnitLowerTriangular, UnitUpperTriangular, diagind, ishermitian, issymmetric
Expand Down
138 changes: 132 additions & 6 deletions src/array/lu.jl
Original file line number Diff line number Diff line change
@@ -1,18 +1,35 @@
function LinearAlgebra.lu(A::DMatrix{T}, ::LinearAlgebra.NoPivot; check::Bool=true) where T
LinearAlgebra.lu(A::DMatrix{T}, pivot::Union{LinearAlgebra.RowMaximum,LinearAlgebra.NoPivot} = LinearAlgebra.RowMaximum(); check::Bool=true, allowsingular::Bool=false) where {T<:LinearAlgebra.BlasFloat} = LinearAlgebra.lu(A, pivot; check=check, allowsingular=allowsingular)

LinearAlgebra.lu!(A::DMatrix{T}, pivot::Union{LinearAlgebra.RowMaximum,LinearAlgebra.NoPivot} = LinearAlgebra.RowMaximum(); check::Bool=true, allowsingular::Bool=false) where {T<:LinearAlgebra.BlasFloat} = LinearAlgebra.lu(A, pivot; check=check, allowsingular=allowsingular)

function LinearAlgebra.lu(A::DMatrix{T}, ::LinearAlgebra.NoPivot; check::Bool = true, allowsingular::Bool = false) where {T<:LinearAlgebra.BlasFloat}
A_copy = LinearAlgebra._lucopy(A, LinearAlgebra.lutype(T))
return LinearAlgebra.lu!(A_copy, LinearAlgebra.NoPivot(); check=check)
end
function LinearAlgebra.lu!(A::DMatrix{T}, ::LinearAlgebra.NoPivot; check::Bool=true) where T
function LinearAlgebra.lu!(A::DMatrix{T}, ::LinearAlgebra.NoPivot; check::Bool = true, allowsingular::Bool = false) where {T<:LinearAlgebra.BlasFloat}

check && LinearAlgebra.LAPACK.chkfinite(A)

zone = one(T)
mzone = -one(T)

mb, nb = A.partitioning.blocksize

if mb != nb
mb = nb = min(mb, nb)
A = maybe_copy_buffered(A => Blocks(nb, nb)) do A
A
end
end

Ac = A.chunks
mt, nt = size(Ac)
iscomplex = T <: Complex
trans = iscomplex ? 'C' : 'T'

info = 0

Dagger.spawn_datadeps() do
for k in range(1, min(mt, nt))
Dagger.@spawn LinearAlgebra.generic_lufact!(InOut(Ac[k, k]), LinearAlgebra.NoPivot(); check)
Dagger.@spawn LinearAlgebra.generic_lufact!(InOut(Ac[k, k]), LinearAlgebra.NoPivot(); check=check, allowsingular=allowsingular)
for m in range(k+1, mt)
Dagger.@spawn BLAS.trsm!('R', 'U', 'N', 'N', zone, In(Ac[k, k]), InOut(Ac[m, k]))
end
Expand All @@ -27,5 +44,114 @@ function LinearAlgebra.lu!(A::DMatrix{T}, ::LinearAlgebra.NoPivot; check::Bool=t

ipiv = DVector([i for i in 1:min(size(A)...)])

return LinearAlgebra.LU{T,DMatrix{T},DVector{Int}}(A, ipiv, 0)
check && LinearAlgebra._check_lu_success(info, allowsingular)

return LinearAlgebra.LU{T,DMatrix{T},DVector{Int}}(A, ipiv, info)
end

function searchmax_pivot!(piv_idx::AbstractVector{Int}, piv_val::AbstractVector{T}, A::AbstractMatrix{T}, offset::Int=0) where T
max_idx = LinearAlgebra.BLAS.iamax(A[:])
piv_idx[1] = offset+max_idx
piv_val[1] = A[max_idx]
end

function update_ipiv!(ipivl::AbstractVector{Int}, info::Ref{Int}, piv_idx::AbstractVector{Int}, piv_val::AbstractVector{T}, k::Int, nb::Int) where T
max_piv_idx = LinearAlgebra.BLAS.iamax(piv_val)
max_piv_val = piv_val[max_piv_idx]
abs_max_piv_val = max_piv_val isa Real ? abs(max_piv_val) : abs(real(max_piv_val)) + abs(imag(max_piv_val))
if isapprox(abs_max_piv_val, zero(T); atol=eps(real(T)))
info[] = k
end
ipivl[1] = (max_piv_idx+k-2)*nb + piv_idx[max_piv_idx]
end

function swaprows_panel!(A::AbstractMatrix{T}, M::AbstractMatrix{T}, ipivl::AbstractVector{Int}, m::Int, p::Int, nb::Int) where T
q = div(ipivl[1]-1,nb) + 1
r = (ipivl[1]-1)%nb+1
if m == q
A[p,:], M[r,:] = M[r,:], A[p,:]
end
end

function update_panel!(M::AbstractMatrix{T}, A::AbstractMatrix{T}, p::Int) where T
Acinv = one(T) / A[p,p]
LinearAlgebra.BLAS.scal!(Acinv, view(M, :, p))
LinearAlgebra.BLAS.ger!(-one(T), view(M, :, p), conj.(view(A, p, p+1:size(A,2))), view(M, :, p+1:size(M,2)))
end

function swaprows_trail!(A::AbstractMatrix{T}, M::AbstractMatrix{T}, ipiv::AbstractVector{Int}, m::Int, nb::Int) where T
for p in eachindex(ipiv)
q = div(ipiv[p]-1,nb) + 1
r = (ipiv[p]-1)%nb+1
if m == q
A[p,:], M[r,:] = M[r,:], A[p,:]
end
end
end

function LinearAlgebra.lu(A::DMatrix{T}, ::LinearAlgebra.RowMaximum; check::Bool = true, allowsingular::Bool = false) where {T<:LinearAlgebra.BlasFloat}
A_copy = LinearAlgebra._lucopy(A, LinearAlgebra.lutype(T))
return LinearAlgebra.lu!(A_copy, LinearAlgebra.RowMaximum(); check=check, allowsingular=allowsingular)
end
function LinearAlgebra.lu!(A::DMatrix{T}, ::LinearAlgebra.RowMaximum; check::Bool = true, allowsingular::Bool = false) where {T<:LinearAlgebra.BlasFloat}

check && LinearAlgebra.LAPACK.chkfinite(A)

zone = one(T)
mzone = -one(T)

mb, nb = A.partitioning.blocksize

if mb != nb
mb = nb = min(mb, nb)
A = maybe_copy_buffered(A => Blocks(nb, nb)) do A
A
end
end

Ac = A.chunks
mt, nt = size(Ac)
m, n = size(A)

ipiv = DVector(collect(1:min(m, n)), Blocks(nb))
ipivc = ipiv.chunks

info = Ref(0)

max_piv_idx = zeros(Int,mt)
max_piv_val = zeros(T, mt)

Dagger.spawn_datadeps() do
for k in 1:min(mt, nt)
for p in 1:min(nb, m-(k-1)*nb, n-(k-1)*nb)
Dagger.@spawn searchmax_pivot!(Out(view(max_piv_idx, k:k)), Out(view(max_piv_val, k:k)), In(view(Ac[k,k],p:min(nb,m-(k-1)*nb),p:p)), p-1)
for i in k+1:mt
Dagger.@spawn searchmax_pivot!(Out(view(max_piv_idx, i:i)), Out(view(max_piv_val, i:i)), In(view(Ac[i,k],:,p:p)))
end
Dagger.@spawn update_ipiv!(InOut(view(ipivc[k],p:p)), InOut(info), In(view(max_piv_idx, k:mt)), In(view(max_piv_val, k:mt)), k, nb)
for i in k:mt
Dagger.@spawn swaprows_panel!(InOut(Ac[k, k]), InOut(Ac[i, k]), In(view(ipivc[k],p:p)), i, p, nb)
end
Dagger.@spawn update_panel!(InOut(view(Ac[k,k],p+1:min(nb,m-(k-1)*nb),:)), In(Ac[k,k]), p)
for i in k+1:mt
Dagger.@spawn update_panel!(InOut(Ac[i, k]), In(Ac[k,k]), p)
end
end
for j in Iterators.flatten((1:k-1, k+1:nt))
for i in k:mt
Dagger.@spawn swaprows_trail!(InOut(Ac[k, j]), InOut(Ac[i, j]), In(ipivc[k]), i, mb)
end
end
for j in k+1:nt
Dagger.@spawn BLAS.trsm!('L', 'L', 'N', 'U', zone, In(Ac[k, k]), InOut(Ac[k, j]))
for i in k+1:mt
Dagger.@spawn BLAS.gemm!('N', 'N', mzone, In(Ac[i, k]), In(Ac[k, j]), zone, InOut(Ac[i, j]))
end
end
end
end

check && LinearAlgebra._check_lu_success(info[], allowsingular)

return LinearAlgebra.LU{T,DMatrix{T},DVector{Int}}(A, ipiv, info[])
end
23 changes: 15 additions & 8 deletions test/array/linalg/lu.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
@testset "$T" for T in (Float32, Float64, ComplexF32, ComplexF64)
@testset "$T with $pivot" for T in (Float32, Float64, ComplexF32, ComplexF64), pivot in (NoPivot(), RowMaximum())
A = rand(T, 128, 128)
B = copy(A)
DA = view(A, Blocks(64, 64))

# Out-of-place
lu_A = lu(A, NoPivot())
lu_DA = lu(DA, NoPivot())
lu_A = lu(A, pivot)
lu_DA = lu(DA, pivot)
@test lu_DA isa LU{T,DMatrix{T},DVector{Int}}
if !(T in (Float32, ComplexF32)) # FIXME: NoPivot is unstable for FP32
if !(T in (Float32, ComplexF32) && pivot == NoPivot()) # FIXME: NoPivot is unstable for FP32
@test lu_A.L ≈ lu_DA.L
@test lu_A.U ≈ lu_DA.U
end
Expand All @@ -18,10 +18,10 @@

# In-place
A_copy = copy(A)
lu_A = lu!(A_copy, NoPivot())
lu_DA = lu!(DA, NoPivot())
lu_A = lu!(A_copy, pivot)
lu_DA = lu!(DA, pivot)
@test lu_DA isa LU{T,DMatrix{T},DVector{Int}}
if !(T in (Float32, ComplexF32)) # FIXME: NoPivot is unstable for FP32
if !(T in (Float32, ComplexF32) && pivot == NoPivot()) # FIXME: NoPivot is unstable for FP32
@test lu_A.L ≈ lu_DA.L
@test lu_A.U ≈ lu_DA.U
end
Expand All @@ -30,4 +30,11 @@
# Check that changes propagated to A
@test DA ≈ A
@test !(B ≈ A)
end

# Non-square block sizes
@test lu(rand(Blocks(64, 32), T, 128, 128), pivot) isa LU{T,DMatrix{T},DVector{Int}}
@test lu!(rand(Blocks(64, 32), T, 128, 128), pivot) isa LU{T,DMatrix{T},DVector{Int}}

# Singular Values
@test_throws LinearAlgebra.SingularException lu(ones(Blocks(64,64), T, 128, 128)) # FIXME: NoPivot needs to handle info
end