
tests: Overhaul GPU testing, and test linalg operations on GPUs #635


Open · wants to merge 22 commits into master

Commits (22)
d508952  Implement RowMaximum Pivoting Strategy for Distributed LU Factorization (AkhilAkkapelli, Jul 8, 2025)
1242052  Remove debug print statements from pivoting functions in LU factoriza… (AkhilAkkapelli, Jul 8, 2025)
bb1620d  Refactor function signatures for improved clarity and consistency (AkhilAkkapelli, Jul 9, 2025)
11ccee7  Add singularity check for robust LU factorization (AkhilAkkapelli, Jul 9, 2025)
11042f9  Fix lu.jl (AkhilAkkapelli, Jul 10, 2025)
794e554  Fix precision issue in singularity check (AkhilAkkapelli, Jul 10, 2025)
afe96f7  Fix pivot search by replacing argmax with BLAS.iamax (AkhilAkkapelli, Jul 10, 2025)
e636bbb  Refactor LU tests for consistency and clarity in pivot checks (AkhilAkkapelli, Jul 10, 2025)
337e59d  Minor LU cleanup (jpsamaroo, Jul 10, 2025)
ffdf186  test/array/linalg/lu: Test for invalid block sizes (jpsamaroo, Jul 10, 2025)
67fe600  docs: Add LU RowMaximum to supported ops (jpsamaroo, Jul 10, 2025)
adf11fc  docs: Add SparseArrays sprand to supported ops (jpsamaroo, Jul 10, 2025)
330a9e3  test/array/linalg/lu: Fix skipped tests (jpsamaroo, Jul 10, 2025)
bda2224  tests: Refactor GPU testing, test linalg on GPU (jpsamaroo, Jul 10, 2025)
1731b4d  GPUs: Add ArrayDomain indexing methods (jpsamaroo, Jul 10, 2025)
3eb1648  fixup! tests: Refactor GPU testing, test linalg on GPU (jpsamaroo, Jul 11, 2025)
58d041c  fixup! GPUs: Add ArrayDomain indexing methods (jpsamaroo, Jul 11, 2025)
c72ecf9  DArray/mul: Make copy functions GPU-compatible (jpsamaroo, Jul 11, 2025)
c2261b9  DArray/alloc: Fix view ambiguity (jpsamaroo, Jul 11, 2025)
533a3e0  fixup! fixup! tests: Refactor GPU testing, test linalg on GPU (jpsamaroo, Jul 11, 2025)
1baaeca  fixup! fixup! fixup! tests: Refactor GPU testing, test linalg on GPU (jpsamaroo, Jul 11, 2025)
043fa19  fixup! DArray/alloc: Fix view ambiguity (jpsamaroo, Jul 11, 2025)
5 changes: 4 additions & 1 deletion docs/src/darray.md
@@ -684,6 +684,9 @@ From `Base`:
From `Random`:
- `rand!`/`randn!`

From `SparseArrays`:
- `sprand`

From `Statistics`:
- `mean`
- `var`
@@ -694,7 +697,7 @@ From `LinearAlgebra`:
- `*` (Out-of-place Matrix-(Matrix/Vector) multiply)
- `mul!` (In-place Matrix-Matrix multiply)
- `cholesky`/`cholesky!` (In-place/Out-of-place Cholesky factorization)
- `lu`/`lu!` (In-place/Out-of-place LU factorization (`NoPivot` only))
- `lu`/`lu!` (In-place/Out-of-place LU factorization (`NoPivot` and `RowMaximum` only))

From `AbstractFFTs`:
- `fft`/`fft!`
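As a quick orientation for the newly documented `RowMaximum` support, here is a rough usage sketch (not part of this diff; the matrix and block sizes are arbitrary):

```julia
using Dagger, LinearAlgebra

A  = rand(128, 128)
DA = view(A, Blocks(32, 32))   # partition into 32x32 blocks as a DArray
F  = lu(DA, RowMaximum())      # out-of-place pivoted LU; lu! works in-place
F isa LinearAlgebra.LU         # wraps the distributed factors and a distributed pivot vector
```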
6 changes: 6 additions & 0 deletions ext/CUDAExt.jl
@@ -111,6 +111,12 @@ struct AllocateUndef{S} end
(::AllocateUndef{S})(T, dims::Dims{N}) where {S,N} = CuArray{S,N}(undef, dims)
Dagger.allocate_array_func(::CuArrayDeviceProc, ::Dagger.AllocateUndef{S}) where S = AllocateUndef{S}()

# Indexing
Base.getindex(arr::CuArray, d::Dagger.ArrayDomain) = arr[Dagger.indexes(d)...]

# Views
Base.view(A::CuArray{T,N}, p::Dagger.Blocks{N}) where {T,N} = Dagger._view(A, p)

# In-place
# N.B. These methods assume that later operations will implicitly or
# explicitly synchronize with their associated stream
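For orientation, a minimal sketch of how the two new CUDA methods might be exercised (assumes CUDA.jl and a working GPU; the `ArrayDomain` constructor call and the sizes are illustrative assumptions, not taken from this diff):

```julia
using CUDA, Dagger

A   = CUDA.rand(Float32, 8, 8)
d   = Dagger.ArrayDomain(1:4, 1:4)   # assumed constructor: a domain covering the top-left 4x4 block
sub = A[d]                           # uses the new getindex(::CuArray, ::ArrayDomain) method
DA  = view(A, Dagger.Blocks(4, 4))   # uses the new view(::CuArray, ::Blocks) method to build a DArray
```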
3 changes: 3 additions & 0 deletions ext/IntelExt.jl
@@ -106,6 +106,9 @@ struct AllocateUndef{S} end
(::AllocateUndef{S})(T, dims::Dims{N}) where {S,N} = oneArray{S,N}(undef, dims)
Dagger.allocate_array_func(::oneArrayDeviceProc, ::Dagger.AllocateUndef{S}) where S = AllocateUndef{S}()

# Indexing
Base.getindex(arr::oneArray, d::Dagger.ArrayDomain) = arr[Dagger.indexes(d)...]

# In-place
# N.B. These methods assume that later operations will implicitly or
# explicitly synchronize with their associated stream
3 changes: 3 additions & 0 deletions ext/MetalExt.jl
@@ -94,6 +94,9 @@ struct AllocateUndef{S} end
(::AllocateUndef{S})(T, dims::Dims{N}) where {S,N} = MtlArray{S,N}(undef, dims)
Dagger.allocate_array_func(::MtlArrayDeviceProc, ::Dagger.AllocateUndef{S}) where S = AllocateUndef{S}()

# Indexing
Base.getindex(arr::MtlArray, d::Dagger.ArrayDomain) = arr[Dagger.indexes(d)...]

# In-place
# N.B. These methods assume that later operations will implicitly or
# explicitly synchronize with their associated stream
3 changes: 3 additions & 0 deletions ext/OpenCLExt.jl
@@ -107,6 +107,9 @@ struct AllocateUndef{S} end
(::AllocateUndef{S})(T, dims::Dims{N}) where {S,N} = CLArray{S,N}(undef, dims)
Dagger.allocate_array_func(::CLArrayDeviceProc, ::Dagger.AllocateUndef{S}) where S = AllocateUndef{S}()

# Indexing
Base.getindex(arr::CLArray, d::Dagger.ArrayDomain) = arr[Dagger.indexes(d)...]

# In-place
# N.B. These methods assume that later operations will implicitly or
# explicitly synchronize with their associated stream
3 changes: 3 additions & 0 deletions ext/ROCExt.jl
@@ -108,6 +108,9 @@ struct AllocateUndef{S} end
(::AllocateUndef{S})(T, dims::Dims{N}) where {S,N} = ROCArray{S,N}(undef, dims)
Dagger.allocate_array_func(::ROCArrayDeviceProc, ::Dagger.AllocateUndef{S}) where S = AllocateUndef{S}()

# Indexing
Base.getindex(arr::ROCArray, d::Dagger.ArrayDomain) = arr[Dagger.indexes(d)...]

# In-place
# N.B. These methods assume that later operations will implicitly or
# explicitly synchronize with their associated stream
2 changes: 1 addition & 1 deletion src/Dagger.jl
@@ -7,7 +7,7 @@ import SparseArrays: sprand, SparseMatrixCSC
import MemPool
import MemPool: DRef, FileRef, poolget, poolset

import Base: collect, reduce
import Base: collect, reduce, view

import LinearAlgebra
import LinearAlgebra: Adjoint, BLAS, Diagonal, Bidiagonal, Tridiagonal, LAPACK, LowerTriangular, PosDefException, Transpose, UpperTriangular, UnitLowerTriangular, UnitUpperTriangular, diagind, ishermitian, issymmetric
3 changes: 2 additions & 1 deletion src/array/alloc.jl
@@ -167,7 +167,8 @@ function Base.zero(x::DArray{T,N}) where {T,N}
return _to_darray(a)
end

function Base.view(A::AbstractArray{T,N}, p::Blocks{N}) where {T,N}
Base.view(A::AbstractArray{T,N}, p::Blocks{N}) where {T,N} = _view(A, p)
function _view(A::AbstractArray{T,N}, p::Blocks{N}) where {T,N}
d = ArrayDomain(Base.index_shape(A))
dc = partition(p, d)
# N.B. We use `tochunk` because we only want to take the view locally, and
97 changes: 95 additions & 2 deletions src/array/lu.jl
@@ -7,8 +7,9 @@ function LinearAlgebra.lu!(A::DMatrix{T}, ::LinearAlgebra.NoPivot; check::Bool=t
mzone = -one(T)
Ac = A.chunks
mt, nt = size(Ac)
iscomplex = T <: Complex
trans = iscomplex ? 'C' : 'T'
mb, nb = A.partitioning.blocksize

mb != nb && throw(ArgumentError("Unequal block sizes are not supported: mb = $mb, nb = $nb"))

Dagger.spawn_datadeps() do
for k in range(1, min(mt, nt))
@@ -29,3 +30,95 @@

return LinearAlgebra.LU{T,DMatrix{T},DVector{Int}}(A, ipiv, 0)
end

function searchmax_pivot!(piv_idx::AbstractVector{Int}, piv_val::AbstractVector{T}, A::AbstractMatrix{T}, offset::Int=0) where T
max_idx = LinearAlgebra.BLAS.iamax(A[:])
piv_idx[1] = offset+max_idx
piv_val[1] = A[max_idx]
end

function update_ipiv!(ipivl::AbstractVector{Int}, piv_idx::AbstractVector{Int}, piv_val::AbstractVector{T}, k::Int, nb::Int) where T
max_piv_idx = LinearAlgebra.BLAS.iamax(piv_val)
max_piv_val = piv_val[max_piv_idx]
abs_max_piv_val = max_piv_val isa Real ? abs(max_piv_val) : abs(real(max_piv_val)) + abs(imag(max_piv_val))
isapprox(abs_max_piv_val, zero(T); atol=eps(real(T))) && throw(LinearAlgebra.SingularException(k))
ipivl[1] = (max_piv_idx+k-2)*nb + piv_idx[max_piv_idx]
end

function swaprows_panel!(A::AbstractMatrix{T}, M::AbstractMatrix{T}, ipivl::AbstractVector{Int}, m::Int, p::Int, nb::Int) where T
q = div(ipivl[1]-1,nb) + 1
r = (ipivl[1]-1)%nb+1
if m == q
A[p,:], M[r,:] = M[r,:], A[p,:]
end
end

function update_panel!(M::AbstractMatrix{T}, A::AbstractMatrix{T}, p::Int) where T
Acinv = one(T) / A[p,p]
LinearAlgebra.BLAS.scal!(Acinv, view(M, :, p))
LinearAlgebra.BLAS.ger!(-one(T), view(M, :, p), conj.(view(A, p, p+1:size(A,2))), view(M, :, p+1:size(M,2)))
end

function swaprows_trail!(A::AbstractMatrix{T}, M::AbstractMatrix{T}, ipiv::AbstractVector{Int}, m::Int, nb::Int) where T
for p in eachindex(ipiv)
q = div(ipiv[p]-1,nb) + 1
r = (ipiv[p]-1)%nb+1
if m == q
A[p,:], M[r,:] = M[r,:], A[p,:]
end
end
end

function LinearAlgebra.lu(A::DMatrix{T}, ::LinearAlgebra.RowMaximum; check::Bool=true) where T
A_copy = LinearAlgebra._lucopy(A, LinearAlgebra.lutype(T))
return LinearAlgebra.lu!(A_copy, LinearAlgebra.RowMaximum(); check=check)
end
function LinearAlgebra.lu!(A::DMatrix{T}, ::LinearAlgebra.RowMaximum; check::Bool=true) where T
zone = one(T)
mzone = -one(T)

Ac = A.chunks
mt, nt = size(Ac)
m, n = size(A)
mb, nb = A.partitioning.blocksize

mb != nb && throw(ArgumentError("Unequal block sizes are not supported: mb = $mb, nb = $nb"))

ipiv = DVector(collect(1:min(m, n)), Blocks(mb))
ipivc = ipiv.chunks

max_piv_idx = zeros(Int,mt)
max_piv_val = zeros(T, mt)

Dagger.spawn_datadeps() do
for k in 1:min(mt, nt)
for p in 1:min(nb, m-(k-1)*nb, n-(k-1)*nb)
Dagger.@spawn searchmax_pivot!(Out(view(max_piv_idx, k:k)), Out(view(max_piv_val, k:k)), In(view(Ac[k,k],p:min(nb,m-(k-1)*nb),p:p)), p-1)
for i in k+1:mt
Dagger.@spawn searchmax_pivot!(Out(view(max_piv_idx, i:i)), Out(view(max_piv_val, i:i)), In(view(Ac[i,k],:,p:p)))
end
Dagger.@spawn update_ipiv!(InOut(view(ipivc[k],p:p)), In(view(max_piv_idx, k:mt)), In(view(max_piv_val, k:mt)), k, nb)
for i in k:mt
Dagger.@spawn swaprows_panel!(InOut(Ac[k, k]), InOut(Ac[i, k]), In(view(ipivc[k],p:p)), i, p, nb)
end
Dagger.@spawn update_panel!(InOut(view(Ac[k,k],p+1:min(nb,m-(k-1)*nb),:)), In(Ac[k,k]), p)
for i in k+1:mt
Dagger.@spawn update_panel!(InOut(Ac[i, k]), In(Ac[k,k]), p)
end
end
for j in Iterators.flatten((1:k-1, k+1:nt))
for i in k:mt
Dagger.@spawn swaprows_trail!(InOut(Ac[k, j]), InOut(Ac[i, j]), In(ipivc[k]), i, mb)
end
end
for j in k+1:nt
Dagger.@spawn BLAS.trsm!('L', 'L', 'N', 'U', zone, In(Ac[k, k]), InOut(Ac[k, j]))
for i in k+1:mt
Dagger.@spawn BLAS.gemm!('N', 'N', mzone, In(Ac[i, k]), In(Ac[k, j]), zone, InOut(Ac[i, j]))
end
end
end
end

return LinearAlgebra.LU{T,DMatrix{T},DVector{Int}}(A, ipiv, 0)
end
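The pivot bookkeeping above packs a (block-row, local-row) pair into a single integer: `update_ipiv!` stores `(q - 1)*nb + r`, and `swaprows_panel!`/`swaprows_trail!` decode it with `div` and `%`. A small worked example with arbitrary values:

```julia
nb = 4               # block size (mb == nb is enforced above)
q, r = 3, 2          # pivot found in block-row 3, local row 2

global_idx = (q - 1)*nb + r            # encoding used in update_ipiv!: 10

# decoding, exactly as swaprows_panel!/swaprows_trail! do it:
@assert div(global_idx - 1, nb) + 1 == q
@assert (global_idx - 1) % nb + 1 == r
```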
8 changes: 2 additions & 6 deletions src/array/mul.jl
@@ -384,9 +384,7 @@ end
m, n = size(A)
C = B'

for i = 1:m, j = 1:n
A[i, j] = C[i, j]
end
A[1:m, 1:n] .= view(C, 1:m, 1:n)
end

@inline function copydiagtile!(A, uplo)
@@ -401,7 +399,5 @@ end
C[diagind(C)] .= A[diagind(A)]
end

for i = 1:m, j = 1:n
A[i, j] = C[i, j]
end
A[1:m, 1:n] .= view(C, 1:m, 1:n)
end
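A rough illustration of why the broadcasted copy above is GPU-friendly (assumes CUDA.jl; not part of this diff): elementwise loops require scalar indexing, which GPU array types disallow by default outside interactive sessions, while a broadcasted copy of a view runs as a single kernel:

```julia
using CUDA

A = CUDA.zeros(Float32, 4, 4)
C = CUDA.rand(Float32, 4, 4)

# for i in 1:4, j in 1:4        # would error: scalar indexing on a CuArray
#     A[i, j] = C[i, j]
# end

A[1:4, 1:4] .= view(C, 1:4, 1:4)   # one broadcast kernel, no scalar indexing
```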
97 changes: 54 additions & 43 deletions test/array/linalg/cholesky.jl
@@ -1,49 +1,60 @@
@testset "$T" for T in (Float32, Float64, ComplexF32, ComplexF64)
D = rand(Blocks(4, 4), T, 32, 32)
if !(T <: Complex)
@test !issymmetric(D)
end
@test !ishermitian(D)
function test_cholesky(AT)
@testset "$T" for T in (Float32, Float64, ComplexF32, ComplexF64)
D = rand(Blocks(4, 4), T, 32, 32)
if !(T <: Complex)
@test !issymmetric(D)
end
@test !ishermitian(D)

A = rand(T, 128, 128)
A = A * A'
A[diagind(A)] .+= size(A, 1)
B = copy(A)
DA = view(A, Blocks(32, 32))
if !(T <: Complex)
@test issymmetric(DA)
end
@test ishermitian(DA)
A = AT(rand(T, 128, 128))
A = A * A'
A[diagind(A)] .+= size(A, 1)
B = copy(A)
DA = view(A, Blocks(32, 32))
if !(T <: Complex)
@test issymmetric(DA)
end
@test ishermitian(DA)

# Out-of-place
chol_A = cholesky(A)
chol_DA = cholesky(DA)
@test chol_DA isa Cholesky
@test chol_A.L ≈ chol_DA.L
@test chol_A.U ≈ chol_DA.U
# Check that cholesky did not modify A or DA
@test A ≈ DA ≈ B
# Out-of-place
chol_A = cholesky(A)
chol_DA = cholesky(DA)
@test chol_DA isa Cholesky
@test chol_A.L ≈ chol_DA.L
@test chol_A.U ≈ chol_DA.U
# Check that cholesky did not modify A or DA
@test A ≈ DA ≈ B

# In-place
A_copy = copy(A)
chol_A = cholesky!(A_copy)
chol_DA = cholesky!(DA)
@test chol_DA isa Cholesky
@test chol_A.L ≈ chol_DA.L
@test chol_A.U ≈ chol_DA.U
# Check that changes propagated to A
@test UpperTriangular(collect(DA)) ≈ UpperTriangular(collect(A))
# In-place
A_copy = copy(A)
chol_A = cholesky!(A_copy)
chol_DA = cholesky!(DA)
@test chol_DA isa Cholesky
@test chol_A.L ≈ chol_DA.L
@test chol_A.U ≈ chol_DA.U
# Check that changes propagated to A
@test UpperTriangular(collect(DA)) ≈ UpperTriangular(collect(A))

# Non-PosDef matrix
A = rand(T, 128, 128)
A = A * A'
A[diagind(A)] .+= size(A, 1)
A[1, 1] = -100
DA = view(A, Blocks(32, 32))
if !(T <: Complex)
@test issymmetric(DA)
# Non-PosDef matrix
A = AT(rand(T, 128, 128))
A = A * A'
A[diagind(A)] .+= size(A, 1)
A[1, 1] = -100
DA = view(A, Blocks(32, 32))
if !(T <: Complex)
@test issymmetric(DA)
end
@test ishermitian(DA)
@test_broken cholesky(DA).U == 42 # This should throw PosDefException
#@test_throws_unwrap PosDefException cholesky(DA).U
end
@test ishermitian(DA)
@test_broken cholesky(DA).U == 42 # This should throw PosDefException
#@test_throws_unwrap PosDefException cholesky(DA).U
end

for (kind, AT, scope) in ALL_SCOPES
(kind == :oneAPI || kind == :Metal || kind == :OpenCL) && continue
@testset "$kind" begin
Dagger.with_options(;scope) do
test_cholesky(AT)
end
end
end
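`ALL_SCOPES` is defined in the shared test harness rather than in this file; a hypothetical sketch of its assumed shape, purely for orientation (the entries and scope constructors below are assumptions). Each entry pairs a backend kind, the array type used to build inputs, and the scope that pins tasks to that backend:

```julia
ALL_SCOPES = [
    (:CPU,   Array,            Dagger.scope(worker=1)),
    # (:CUDA, CUDA.CuArray,    Dagger.scope(cuda_gpu=1)),   # when a CUDA GPU is available
    # (:ROCm, AMDGPU.ROCArray, Dagger.scope(rocm_gpu=1)),   # when an AMD GPU is available
]
```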
27 changes: 19 additions & 8 deletions test/array/linalg/core.jl
@@ -1,11 +1,22 @@
@testset "isapprox" begin
A = rand(16, 16)
function test_linalg_core(AT)
@testset "isapprox" begin
A = AT(rand(16, 16))

U1 = UpperTriangular(DArray(A, Blocks(16, 16)))
U2 = UpperTriangular(DArray(A, Blocks(16, 16)))
@test isapprox(U1, U2)
U1 = UpperTriangular(DArray(A, Blocks(16, 16)))
U2 = UpperTriangular(DArray(A, Blocks(16, 16)))
@test isapprox(U1, U2)

L1 = LowerTriangular(DArray(A, Blocks(16, 16)))
L2 = LowerTriangular(DArray(A, Blocks(16, 16)))
@test isapprox(L1, L2)
L1 = LowerTriangular(DArray(A, Blocks(16, 16)))
L2 = LowerTriangular(DArray(A, Blocks(16, 16)))
@test isapprox(L1, L2)
end
end

for (kind, AT, scope) in ALL_SCOPES
(kind == :oneAPI || kind == :Metal || kind == :OpenCL) && continue
@testset "$kind" begin
Dagger.with_options(;scope) do
test_linalg_core(AT)
end
end
end