Skip to content

Commit 815076b

Browse files
authored
Improve inferability of shape::Dims for cat (#39294)
`cat` is often called with Varargs or heterogenous inputs, and inference almost always fails. Even when all the arrays are of the same type, if the number of varargs isn't known inference typically fails. The culprit is probably #36454. This reduces the number of failures considerably, by avoiding creation of vararg length tuples in the shape-inference pipeline.
1 parent fb39bdb commit 815076b

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

base/abstractarray.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1580,6 +1580,7 @@ cat_indices(A::AbstractArray, d) = axes(A, d)
15801580
cat_similar(A, ::Type{T}, shape) where T = Array{T}(undef, shape)
15811581
cat_similar(A::AbstractArray, ::Type{T}, shape) where T = similar(A, T, shape)
15821582

1583+
# These are for backwards compatibility (even though internal)
15831584
cat_shape(dims, shape::Tuple{Vararg{Int}}) = shape
15841585
function cat_shape(dims, shapes::Tuple)
15851586
out_shape = ()
@@ -1588,6 +1589,11 @@ function cat_shape(dims, shapes::Tuple)
15881589
end
15891590
return out_shape
15901591
end
1592+
# The new way to compute the shape (more inferrable than combining cat_size & cat_shape, due to Varargs + issue#36454)
1593+
cat_size_shape(dims) = ntuple(zero, Val(length(dims)))
1594+
@inline cat_size_shape(dims, X, tail...) = _cat_size_shape(dims, _cshp(1, dims, (), cat_size(X)), tail...)
1595+
_cat_size_shape(dims, shape) = shape
1596+
@inline _cat_size_shape(dims, shape, X, tail...) = _cat_size_shape(dims, _cshp(1, dims, shape, cat_size(X)), tail...)
15911597

15921598
_cshp(ndim::Int, ::Tuple{}, ::Tuple{}, ::Tuple{}) = ()
15931599
_cshp(ndim::Int, ::Tuple{}, ::Tuple{}, nshape) = nshape
@@ -1631,7 +1637,7 @@ _cat(dims, X...) = cat_t(promote_eltypeof(X...), X...; dims=dims)
16311637
@inline cat_t(::Type{T}, X...; dims) where {T} = _cat_t(dims, T, X...)
16321638
@inline function _cat_t(dims, ::Type{T}, X...) where {T}
16331639
catdims = dims2cat(dims)
1634-
shape = cat_shape(catdims, map(cat_size, X))
1640+
shape = cat_size_shape(catdims, X...)
16351641
A = cat_similar(X[1], T, shape)
16361642
if count(!iszero, catdims)::Int > 1
16371643
fill!(A, zero(T))

test/abstractarray.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -692,6 +692,12 @@ function test_cat(::Type{TestAbstractArray})
692692
# 36041
693693
@test_throws MethodError cat(["a"], ["b"], dims=[1, 2])
694694
@test cat([1], [1], dims=[1, 2]) == I(2)
695+
696+
# inferrability
697+
As = [zeros(2, 2) for _ = 1:2]
698+
@test @inferred(cat(As...; dims=Val(3))) == zeros(2, 2, 2)
699+
cat3v(As) = cat(As...; dims=Val(3))
700+
@test @inferred(cat3v(As)) == zeros(2, 2, 2)
695701
end
696702

697703
function test_ind2sub(::Type{TestAbstractArray})

0 commit comments

Comments
 (0)