Skip to content

Commit 1a302b9

Browse files
committed
Make sparse hvcat inferable
This also requires making sparse `vcat` and `hcat` inferable in the vararg case which in turn requires a different way to determine the resulting index type, now implemented similar to `promote_eltype`.
1 parent e55ee1d commit 1a302b9

File tree

3 files changed

+16
-13
lines changed

3 files changed

+16
-13
lines changed

stdlib/SparseArrays/src/sparsematrix.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3272,6 +3272,10 @@ dropstored!(A::AbstractSparseMatrixCSC, ::Colon) = dropstored!(A, :, :)
32723272

32733273
# Sparse concatenation
32743274

3275+
promote_idxtype(::AbstractSparseMatrixCSC{<:Any, Ti}) where {Ti} = Ti
3276+
promote_idxtype(::AbstractSparseMatrixCSC{<:Any, Ti}, X::AbstractSparseMatrixCSC...) where {Ti} =
3277+
promote_type(Ti, promote_idxtype(X...))
3278+
32753279
function vcat(X::AbstractSparseMatrixCSC...)
32763280
num = length(X)
32773281
mX = Int[ size(x, 1) for x in X ]
@@ -3286,7 +3290,7 @@ function vcat(X::AbstractSparseMatrixCSC...)
32863290
end
32873291

32883292
Tv = promote_eltype(X...)
3289-
Ti = promote_eltype(map(x->rowvals(x), X)...)
3293+
Ti = promote_idxtype(X...)
32903294

32913295
nnzX = Int[ nnz(x) for x in X ]
32923296
nnz_res = sum(nnzX)
@@ -3338,7 +3342,7 @@ function hcat(X::AbstractSparseMatrixCSC...)
33383342
n = sum(nX)
33393343

33403344
Tv = promote_eltype(X...)
3341-
Ti = promote_eltype(map(x->rowvals(x), X)...)
3345+
Ti = promote_idxtype(X...)
33423346

33433347
colptr = Vector{Ti}(undef, n+1)
33443348
nnzX = Int[ nnz(x) for x in X ]

stdlib/SparseArrays/src/sparsevector.jl

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1091,17 +1091,16 @@ function vcat(Xin::_SparseConcatGroup...)
10911091
X = map(x -> SparseMatrixCSC(issparse(x) ? x : sparse(x)), Xin)
10921092
vcat(X...)
10931093
end
1094-
function hvcat(rows::Tuple{Vararg{Int}}, X::_SparseConcatGroup...)
1095-
nbr = length(rows) # number of block rows
1096-
1097-
tmp_rows = Vector{SparseMatrixCSC}(undef, nbr)
1098-
k = 0
1099-
@inbounds for i = 1 : nbr
1100-
tmp_rows[i] = hcat(X[(1 : rows[i]) .+ k]...)
1101-
k += rows[i]
1094+
hvcat(rows::Tuple{Vararg{Int}}, X::_SparseConcatGroup...) =
1095+
vcat(_hvcat_rows(rows, X...)...)
1096+
function _hvcat_rows((row1, rows...)::Tuple{Vararg{Int}}, X::_SparseConcatGroup...)
1097+
if row1 0
1098+
throw(ArgumentError("length of block row must be positive, got $row1"))
11021099
end
1103-
vcat(tmp_rows...)
1100+
# provide X[1] separately to convince inference that we don't call hcat() without arguments
1101+
return (hcat(X[1], X[2 : row1]...), _hvcat_rows(rows, X[row1+1:end]...)...)
11041102
end
1103+
_hvcat_rows(::Tuple{}, X::_SparseConcatGroup...) = ()
11051104

11061105
# make sure UniformScaling objects are converted to sparse matrices for concatenation
11071106
promote_to_array_type(A::Tuple{Vararg{Union{_SparseConcatGroup,UniformScaling}}}) = SparseMatrixCSC

stdlib/SparseArrays/test/sparse.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ end
162162
sz34 = spzeros(3, 4)
163163
se77 = sparse(1.0I, 7, 7)
164164
@testset "h+v concatenation" begin
165-
@test [se44 sz42 sz41; sz34 se33] == se77
165+
@test @inferred(hvcat((3, 2), se44, sz42, sz41, sz34, se33)) == se77 # [se44 sz42 sz41; sz34 se33]
166166
@test length(nonzeros([sp33 0I; 1I 0I])) == 6
167167
end
168168

@@ -2168,7 +2168,7 @@ end
21682168
# Test that concatenations of pairs of sparse matrices yield sparse arrays
21692169
@test issparse(vcat(spmat, spmat))
21702170
@test issparse(hcat(spmat, spmat))
2171-
@test issparse(hvcat((2,), spmat, spmat))
2171+
@test issparse(@inferred(hvcat((2,), spmat, spmat)))
21722172
@test issparse(cat(spmat, spmat; dims=(1,2)))
21732173
# Test that concatenations of a sparse matrice with a dense matrix/vector yield sparse arrays
21742174
@test issparse(vcat(spmat, densemat))

0 commit comments

Comments
 (0)