Skip to content

Make sparsearrays an ext #448

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ version = "7.12.0"
[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"

[weakdeps]
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
Expand All @@ -16,6 +14,7 @@ CUDSS = "45b445bb-4962-46a0-9369-b4df9d0f772e"
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"

Expand All @@ -27,6 +26,7 @@ ArrayInterfaceCUDSSExt = "CUDSS"
ArrayInterfaceChainRulesExt = "ChainRules"
ArrayInterfaceGPUArraysCoreExt = "GPUArraysCore"
ArrayInterfaceReverseDiffExt = "ReverseDiff"
ArrayInterfaceSparseArraysExt = "SparseArrays"
ArrayInterfaceStaticArraysCoreExt = "StaticArraysCore"
ArrayInterfaceTrackerExt = "Tracker"

Expand All @@ -42,7 +42,6 @@ LinearAlgebra = "1.10"
ReverseDiff = "1"
SparseArrays = "1.10"
StaticArraysCore = "1"
SuiteSparse = "1.10"
Tracker = "0.2"
julia = "1.10"

Expand Down
38 changes: 38 additions & 0 deletions ext/ArrayInterfaceSparseArraysExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
module ArrayInterfaceSparseArraysExt

import ArrayInterface: buffer, has_sparsestruct, issingular, findstructralnz, bunchkaufman_instance, DEFAULT_CHOLESKY_PIVOT, cholesky_instance, ldlt_instance, lu_instance, qr_instance
using ArrayInterface.LinearAlgebra
using SparseArrays

buffer(x::SparseMatrixCSC) = getfield(x, :nzval)
buffer(x::SparseVector) = getfield(x, :nzval)
has_sparsestruct(::Type{<:SparseMatrixCSC}) = true
issingular(A::AbstractSparseMatrix) = !issuccess(lu(A, check = false))

function findstructralnz(x::SparseMatrixCSC)
rowind, colind, _ = findnz(x)
(rowind, colind)
end

function bunchkaufman_instance(A::SparseMatrixCSC)
bunchkaufman(sparse(similar(A, 1, 1)), check = false)
end

function cholesky_instance(A::Union{SparseMatrixCSC,Symmetric{<:Number,<:SparseMatrixCSC}}, pivot = DEFAULT_CHOLESKY_PIVOT)
cholesky(sparse(similar(A, 1, 1)), check = false)
end

function ldlt_instance(A::SparseMatrixCSC)
ldlt(sparse(similar(A, 1, 1)), check=false)
end

# Could be optimized but this should work for any real case.
function lu_instance(jac_prototype::SparseMatrixCSC, pivot = DEFAULT_CHOLESKY_PIVOT)
lu(sparse(rand(1,1)))
end

function qr_instance(jac_prototype::SparseMatrixCSC, pivot = DEFAULT_CHOLESKY_PIVOT)
qr(sparse(rand(1,1)))
end

end
34 changes: 2 additions & 32 deletions src/ArrayInterface.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
module ArrayInterface

using LinearAlgebra
using SparseArrays
using SuiteSparse

@static if isdefined(Base, Symbol("@assume_effects"))
using Base: @assume_effects
Expand Down Expand Up @@ -121,8 +119,6 @@ Return the buffer data that `x` points to. Unlike `parent(x::AbstractArray)`, `b
may not return another array type.
"""
buffer(x) = parent(x)
buffer(x::SparseMatrixCSC) = getfield(x, :nzval)
buffer(x::SparseVector) = getfield(x, :nzval)
buffer(@nospecialize x::Union{Base.Slice, Base.IdentityUnitRange}) = getfield(x, :indices)

"""
Expand Down Expand Up @@ -308,7 +304,6 @@ Determine whether `findstructralnz` accepts the parameter `x`.
has_sparsestruct(x) = has_sparsestruct(typeof(x))
has_sparsestruct(::Type) = false
has_sparsestruct(::Type{<:AbstractArray}) = false
has_sparsestruct(::Type{<:SparseMatrixCSC}) = true
has_sparsestruct(::Type{<:Diagonal}) = true
has_sparsestruct(::Type{<:Bidiagonal}) = true
has_sparsestruct(::Type{<:Tridiagonal}) = true
Expand All @@ -320,7 +315,6 @@ has_sparsestruct(::Type{<:SymTridiagonal}) = true
Determine whether a given abstract matrix is singular.
"""
issingular(A::AbstractMatrix) = issingular(Matrix(A))
issingular(A::AbstractSparseMatrix) = !issuccess(lu(A, check = false))
issingular(A::Matrix) = !issuccess(lu(A, check = false))
issingular(A::UniformScaling) = A.λ == 0
issingular(A::Diagonal) = any(iszero, A.diag)
Expand Down Expand Up @@ -359,11 +353,6 @@ function findstructralnz(x::Union{Tridiagonal, SymTridiagonal})
(rowind, colind)
end

function findstructralnz(x::SparseMatrixCSC)
rowind, colind, _ = findnz(x)
(rowind, colind)
end

abstract type ColoringAlgorithm end

"""
Expand Down Expand Up @@ -403,9 +392,6 @@ cheaply.
function bunchkaufman_instance(A::Matrix{T}) where T
return bunchkaufman(similar(A, 0, 0), check = false)
end
function bunchkaufman_instance(A::SparseMatrixCSC)
bunchkaufman(sparse(similar(A, 1, 1)), check = false)
end

"""
bunchkaufman_instance(a::Number) -> a
Expand All @@ -429,14 +415,10 @@ cholesky_instance(A, pivot = LinearAlgebra.RowMaximum()) -> cholesky_factorizati
Returns an instance of the Cholesky factorization object with the correct type
cheaply.
"""
function cholesky_instance(A::Matrix{T}, pivot = DEFAULT_CHOLESKY_PIVOT) where {T}
function cholesky_instance(A::Matrix{T}, pivot = DEFAULT_CHOLESKY_PIVOT) where {T}
return cholesky(similar(A, 0, 0), pivot, check = false)
end

function cholesky_instance(A::Union{SparseMatrixCSC,Symmetric{<:Number,<:SparseMatrixCSC}}, pivot = DEFAULT_CHOLESKY_PIVOT)
cholesky(sparse(similar(A, 1, 1)), check = false)
end

"""
cholesky_instance(a::Number, pivot = LinearAlgebra.RowMaximum()) -> a

Expand All @@ -458,14 +440,10 @@ ldlt_instance(A) -> ldlt_factorization_instance
Returns an instance of the LDLT factorization object with the correct type
cheaply.
"""
function ldlt_instance(A::Matrix{T}) where {T}
function ldlt_instance(A::Matrix{T}) where {T}
return ldlt_instance(SymTridiagonal(similar(A, 0, 0)))
end

function ldlt_instance(A::SparseMatrixCSC)
ldlt(sparse(similar(A, 1, 1)), check=false)
end

function ldlt_instance(A::SymTridiagonal{T,V}) where {T,V}
return LinearAlgebra.LDLt{T,SymTridiagonal{T,V}}(A)
end
Expand Down Expand Up @@ -498,9 +476,6 @@ function lu_instance(A::Matrix{T}) where {T}
info = zero(LinearAlgebra.BlasInt)
return LU{luT}(similar(A, 0, 0), ipiv, info)
end
function lu_instance(jac_prototype::SparseMatrixCSC)
SuiteSparse.UMFPACK.UmfpackLU(similar(jac_prototype, 1, 1))
end

function lu_instance(A::Symmetric{T}) where {T}
noUnitT = typeof(zero(T))
Expand Down Expand Up @@ -557,11 +532,6 @@ function qr_instance(A::Matrix{BigFloat},pivot = DEFAULT_CHOLESKY_PIVOT)
LinearAlgebra.QR(zeros(BigFloat,0,0),zeros(BigFloat,0))
end

# Could be optimized but this should work for any real case.
function qr_instance(jac_prototype::SparseMatrixCSC, pivot = DEFAULT_CHOLESKY_PIVOT)
qr(sparse(rand(1,1)))
end

"""
qr_instance(a::Number) -> a

Expand Down
Loading