Skip to content

Commit a690758

Browse files
authored
Eagerly evaluate indices in eachindex check (#58054)
This reduces TTFX in `eachindex` calls with multiple arguments, with minimal impact on performance. Much of the latency comes from the error path, and specializing it for the common case of 2 arguments helps a lot with reducing latency. In this, I've also unrolled the `join` in the error path, and we recursively generate a `LazyString`s instead. This helps in reducing TTFX for a longer list of arguments. ```julia julia> a = zeros(2,2); julia> @time eachindex(a, a); 0.046902 seconds (128.39 k allocations: 6.652 MiB, 99.93% compilation time) # nightly 0.015368 seconds (19.91 k allocations: 1.048 MiB, 99.79% compilation time) # this PR julia> @Btime eachindex($a, $a, $a, $a, $a, $a, $a, $a); 6.945 ns (0 allocations: 0 bytes) # nightly 6.855 ns (0 allocations: 0 bytes) # this PR ``` This reduces TTFX for a longer list of arguments as well: ```julia julia> @time eachindex(a, a, a, a, a, a, a, a); 0.052552 seconds (196.87 k allocations: 10.068 MiB, 99.53% compilation time) # nightly 0.043401 seconds (69.13 k allocations: 3.454 MiB, 99.34% compilation time) # this PR ``` For Cartesian indexing, ```julia julia> a = zeros(2,2); julia> v = view(a, 1:2, 1:2); julia> @time eachindex(a, v); 0.051333 seconds (171.34 k allocations: 8.921 MiB, 99.94% compilation time) # nightly 0.016340 seconds (26.95 k allocations: 1.405 MiB, 99.79% compilation time) # this PR julia> @Btime eachindex($a, $v, $a, $v, $a, $v, $a, $v); 9.339 ns (0 allocations: 0 bytes) # nightly 9.357 ns (0 allocations: 0 bytes) # this PR ```
1 parent bd193e4 commit a690758

File tree

3 files changed

+15
-14
lines changed

3 files changed

+15
-14
lines changed

base/abstractarray.jl

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -321,11 +321,13 @@ eachindex(itrs...) = keys(itrs...)
321321
eachindex(A::AbstractVector) = (@inline(); axes1(A))
322322

323323

324-
@noinline function throw_eachindex_mismatch_indices(::IndexLinear, inds...)
325-
throw(DimensionMismatch("all inputs to eachindex must have the same indices, got $(join(inds, ", ", " and "))"))
326-
end
327-
@noinline function throw_eachindex_mismatch_indices(::IndexCartesian, inds...)
328-
throw(DimensionMismatch("all inputs to eachindex must have the same axes, got $(join(inds, ", ", " and "))"))
324+
# we unroll the join for easier inference
325+
_join_comma_and(indsA, indsB) = LazyString(indsA, " and ", indsB)
326+
_join_comma_and(indsA, indsB, indsC...) = LazyString(indsA, ", ", _join_comma_and(indsB, indsC...))
327+
@noinline function throw_eachindex_mismatch_indices(indices_str, indsA, indsBs...)
328+
throw(DimensionMismatch(
329+
LazyString("all inputs to eachindex must have the same ", indices_str, ", got ",
330+
_join_comma_and(indsA, indsBs...))))
329331
end
330332

331333
"""
@@ -390,15 +392,11 @@ eachindex(::IndexLinear, A::AbstractVector) = (@inline; axes1(A))
390392
function eachindex(::IndexLinear, A::AbstractArray, B::AbstractArray...)
391393
@inline
392394
indsA = eachindex(IndexLinear(), A)
393-
_all_match_first(X->eachindex(IndexLinear(), X), indsA, B...) ||
394-
throw_eachindex_mismatch_indices(IndexLinear(), eachindex(A), eachindex.(B)...)
395+
indsBs = map(X -> eachindex(IndexLinear(), X), B)
396+
all(==(indsA), indsBs) ||
397+
throw_eachindex_mismatch_indices("indices", indsA, indsBs...)
395398
indsA
396399
end
397-
function _all_match_first(f::F, inds, A, B...) where F<:Function
398-
@inline
399-
(inds == f(A)) & _all_match_first(f, inds, B...)
400-
end
401-
_all_match_first(f::F, inds) where F<:Function = true
402400

403401
# keys with an IndexStyle
404402
keys(s::IndexStyle, A::AbstractArray, B::AbstractArray...) = eachindex(s, A, B...)

base/multidimensional.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -416,7 +416,9 @@ module IteratorsMD
416416

417417
@inline function eachindex(::IndexCartesian, A::AbstractArray, B::AbstractArray...)
418418
axsA = axes(A)
419-
Base._all_match_first(axes, axsA, B...) || Base.throw_eachindex_mismatch_indices(IndexCartesian(), axes(A), axes.(B)...)
419+
axsBs = map(axes, B)
420+
all(==(axsA), axsBs) ||
421+
Base.throw_eachindex_mismatch_indices("axes", axsA, axsBs...)
420422
CartesianIndices(axsA)
421423
end
422424

base/reinterpretarray.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,8 @@ SCartesianIndices2{K}(indices2::AbstractUnitRange{Int}) where {K} = (@assert K::
271271
eachindex(::IndexSCartesian2{K}, A::ReshapedReinterpretArray) where {K} = SCartesianIndices2{K}(eachindex(IndexLinear(), parent(A)))
272272
@inline function eachindex(style::IndexSCartesian2{K}, A::AbstractArray, B::AbstractArray...) where {K}
273273
iter = eachindex(style, A)
274-
_all_match_first(C->eachindex(style, C), iter, B...) || throw_eachindex_mismatch_indices(IndexSCartesian2{K}(), axes(A), axes.(B)...)
274+
itersBs = map(C->eachindex(style, C), B)
275+
all(==(iter), itersBs) || throw_eachindex_mismatch_indices("axes", axes(A), map(axes, B)...)
275276
return iter
276277
end
277278

0 commit comments

Comments
 (0)