Skip to content

Commit 78d55e2

Browse files
authored
Use lispy tuples in cat (fixes #21673) (#39314)
The `cat` pipeline has long had poor inferrability. Together with #39292 and #39294, this should basically put an end to that problem. Together, at least in simple cases these make the performance of `cat` essentially equivalent to the manual version. In other words, the `test1` and `test2` of #21673 benchmark very similarly.
1 parent 6813340 commit 78d55e2

File tree

1 file changed

+22
-21
lines changed

1 file changed

+22
-21
lines changed

base/abstractarray.jl

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1645,28 +1645,29 @@ _cat(dims, X...) = cat_t(promote_eltypeof(X...), X...; dims=dims)
16451645
return __cat(A, shape, catdims, X...)
16461646
end
16471647

1648-
function __cat(A, shape::NTuple{M}, catdims, X...) where M
1649-
N = M::Int
1650-
offsets = zeros(Int, N)
1651-
inds = Vector{UnitRange{Int}}(undef, N)
1652-
concat = copyto!(zeros(Bool, N), catdims)
1653-
for x in X
1654-
for i = 1:N
1655-
if concat[i]
1656-
inds[i] = offsets[i] .+ cat_indices(x, i)
1657-
offsets[i] += cat_size(x, i)
1658-
else
1659-
inds[i] = 1:shape[i]
1660-
end
1661-
end
1662-
I::NTuple{N, UnitRange{Int}} = (inds...,)
1663-
if x isa AbstractArray
1664-
A[I...] = x
1665-
else
1666-
fill!(view(A, I...), x)
1667-
end
1648+
# Why isn't this called `__cat!`?
1649+
__cat(A, shape, catdims, X...) = __cat_offset!(A, shape, catdims, ntuple(zero, length(shape)), X...)
1650+
1651+
function __cat_offset!(A, shape, catdims, offsets, x, X...)
1652+
# splitting the "work" on x from X... may reduce latency (fewer costly specializations)
1653+
newoffsets = __cat_offset1!(A, shape, catdims, offsets, x)
1654+
return __cat_offset!(A, shape, catdims, newoffsets, X...)
1655+
end
1656+
__cat_offset!(A, shape, catdims, offsets) = A
1657+
1658+
function __cat_offset1!(A, shape, catdims, offsets, x)
1659+
inds = ntuple(length(offsets)) do i
1660+
(i <= length(catdims) && catdims[i]) ? offsets[i] .+ cat_indices(x, i) : 1:shape[i]
1661+
end
1662+
if x isa AbstractArray
1663+
A[inds...] = x
1664+
else
1665+
fill!(view(A, inds...), x)
1666+
end
1667+
newoffsets = ntuple(length(offsets)) do i
1668+
(i <= length(catdims) && catdims[i]) ? offsets[i] + cat_size(x, i) : offsets[i]
16681669
end
1669-
return A
1670+
return newoffsets
16701671
end
16711672

16721673
"""

0 commit comments

Comments
 (0)