Skip to content

Commit 2e2c16a

Browse files
authored
Fix @simd for non 1 step CartesianPartition (#42736)
1 parent 9af12d3 commit 2e2c16a

File tree

3 files changed

+48
-35
lines changed

3 files changed

+48
-35
lines changed

base/broadcast.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -973,14 +973,14 @@ end
973973
destc = dest.chunks
974974
cind = 1
975975
bc′ = preprocess(dest, bc)
976-
for P in Iterators.partition(eachindex(bc′), bitcache_size)
976+
@inbounds for P in Iterators.partition(eachindex(bc′), bitcache_size)
977977
ind = 1
978978
@simd for I in P
979-
@inbounds tmp[ind] = bc′[I]
979+
tmp[ind] = bc′[I]
980980
ind += 1
981981
end
982982
@simd for i in ind:bitcache_size
983-
@inbounds tmp[i] = false
983+
tmp[i] = false
984984
end
985985
dumpbitcache(destc, cind, tmp)
986986
cind += bitcache_chunks

base/multidimensional.jl

Lines changed: 37 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -477,9 +477,8 @@ module IteratorsMD
477477
simd_inner_length(iter::CartesianIndices, I::CartesianIndex) = Base.length(iter.indices[1])
478478

479479
simd_index(iter::CartesianIndices{0}, ::CartesianIndex, I1::Int) = first(iter)
480-
@propagate_inbounds function simd_index(iter::CartesianIndices, Ilast::CartesianIndex, I1::Int)
481-
CartesianIndex(getindex(iter.indices[1], I1+first(Base.axes1(iter.indices[1]))), Ilast.I...)
482-
end
480+
@propagate_inbounds simd_index(iter::CartesianIndices, Ilast::CartesianIndex, I1::Int) =
481+
CartesianIndex(iter.indices[1][I1+firstindex(iter.indices[1])], Ilast)
483482

484483
# Split out the first N elements of a tuple
485484
@inline function split(t, V::Val)
@@ -585,7 +584,7 @@ module IteratorsMD
585584
CartesianIndices(intersect.(a.indices, b.indices))
586585

587586
# Views of reshaped CartesianIndices are used for partitions — ensure these are fast
588-
const CartesianPartition{T<:CartesianIndex, P<:CartesianIndices, R<:ReshapedArray{T,1,P}} = SubArray{T,1,R,Tuple{UnitRange{Int}},false}
587+
const CartesianPartition{T<:CartesianIndex, P<:CartesianIndices, R<:ReshapedArray{T,1,P}} = SubArray{T,1,R,<:Tuple{AbstractUnitRange{Int}},false}
589588
eltype(::Type{PartitionIterator{T}}) where {T<:ReshapedArrayLF} = SubArray{eltype(T), 1, T, Tuple{UnitRange{Int}}, true}
590589
eltype(::Type{PartitionIterator{T}}) where {T<:ReshapedArray} = SubArray{eltype(T), 1, T, Tuple{UnitRange{Int}}, false}
591590
Iterators.IteratorEltype(::Type{<:PartitionIterator{T}}) where {T<:ReshapedArray} = Iterators.IteratorEltype(T)
@@ -594,7 +593,6 @@ module IteratorsMD
594593
eltype(::Type{PartitionIterator{T}}) where {T<:Union{UnitRange, StepRange, StepRangeLen, LinRange}} = T
595594
Iterators.IteratorEltype(::Type{<:PartitionIterator{T}}) where {T<:Union{OneTo, UnitRange, StepRange, StepRangeLen, LinRange}} = Iterators.IteratorEltype(T)
596595

597-
598596
@inline function iterate(iter::CartesianPartition)
599597
isempty(iter) && return nothing
600598
f = first(iter)
@@ -610,33 +608,45 @@ module IteratorsMD
610608
# In general, the Cartesian Partition might start and stop in the middle of the outer
611609
# dimensions — thus the outer range of a CartesianPartition is itself a
612610
# CartesianPartition.
613-
t = tail(iter.parent.parent.indices)
614-
ci = CartesianIndices(t)
615-
li = LinearIndices(t)
616-
return @inbounds view(ci, li[tail(iter[1].I)...]:li[tail(iter[end].I)...])
611+
mi = iter.parent.mi
612+
ci = iter.parent.parent
613+
ax, ax1 = axes(ci), Base.axes1(ci)
614+
subs = Base.ind2sub_rs(ax, mi, first(iter.indices[1]))
615+
vl, fl = Base._sub2ind(tail(ax), tail(subs)...), subs[1]
616+
vr, fr = divrem(last(iter.indices[1]) - 1, mi[end]) .+ (1, first(ax1))
617+
oci = CartesianIndices(tail(ci.indices))
618+
# A fake CartesianPartition to reuse the outer iterate fallback
619+
outer = @inbounds view(ReshapedArray(oci, (length(oci),), mi), vl:vr)
620+
init = @inbounds dec(oci[tail(subs)...].I, oci.indices) # real init state
621+
# Use Generator to make inner loop branchless
622+
@inline function skip_len_I(i::Int, I::CartesianIndex)
623+
l = i == 1 ? fl : first(ax1)
624+
r = i == length(outer) ? fr : last(ax1)
625+
l - first(ax1), r - l + 1, I
626+
end
627+
(skip_len_I(i, I) for (i, I) in Iterators.enumerate(Iterators.rest(outer, (init, 0))))
617628
end
618-
function simd_outer_range(iter::CartesianPartition{CartesianIndex{2}})
629+
@inline function simd_outer_range(iter::CartesianPartition{CartesianIndex{2}})
619630
# But for two-dimensional Partitions the above is just a simple one-dimensional range
620631
# over the second dimension; we don't need to worry about non-rectangular staggers in
621632
# higher dimensions.
622-
return @inbounds CartesianIndices((iter[1][2]:iter[end][2],))
623-
end
624-
@inline function simd_inner_length(iter::CartesianPartition, I::CartesianIndex)
625-
inner = iter.parent.parent.indices[1]
626-
@inbounds fi = iter[1].I
627-
@inbounds li = iter[end].I
628-
inner_start = I.I == tail(fi) ? fi[1] : first(inner)
629-
inner_end = I.I == tail(li) ? li[1] : last(inner)
630-
return inner_end - inner_start + 1
631-
end
632-
@inline function simd_index(iter::CartesianPartition, Ilast::CartesianIndex, I1::Int)
633-
# I1 is the 0-based distance from the first dimension's offset
634-
offset = first(iter.parent.parent.indices[1]) # (this is 1 for 1-based arrays)
635-
# In the first column we need to also add in the iter's starting point (branchlessly)
636-
f = @inbounds iter[1]
637-
startoffset = (Ilast.I == tail(f.I))*(f[1] - 1)
638-
CartesianIndex((I1 + offset + startoffset, Ilast.I...))
633+
mi = iter.parent.mi
634+
ci = iter.parent.parent
635+
ax, ax1 = axes(ci), Base.axes1(ci)
636+
fl, vl = Base.ind2sub_rs(ax, mi, first(iter.indices[1]))
637+
fr, vr = Base.ind2sub_rs(ax, mi, last(iter.indices[1]))
638+
outer = @inbounds CartesianIndices((ci.indices[2][vl:vr],))
639+
# Use Generator to make inner loop branchless
640+
@inline function skip_len_I(I::CartesianIndex{1})
641+
l = I == first(outer) ? fl : first(ax1)
642+
r = I == last(outer) ? fr : last(ax1)
643+
l - first(ax1), r - l + 1, I
644+
end
645+
(skip_len_I(I) for I in outer)
639646
end
647+
@inline simd_inner_length(iter::CartesianPartition, (_, len, _)::Tuple{Int,Int,CartesianIndex}) = len
648+
@propagate_inbounds simd_index(iter::CartesianPartition, (skip, _, I)::Tuple{Int,Int,CartesianIndex}, n::Int) =
649+
simd_index(iter.parent.parent, I, n + skip)
640650
end # IteratorsMD
641651

642652

test/iterators.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -550,12 +550,15 @@ end
550550
(1,1), (8,8), (11, 13),
551551
(1,1,1), (8, 4, 2), (11, 13, 17)),
552552
part in (1, 7, 8, 11, 63, 64, 65, 142, 143, 144)
553-
P = partition(CartesianIndices(dims), part)
554-
for I in P
555-
@test length(I) == iterate_length(I) == simd_iterate_length(I) == simd_trip_count(I)
556-
@test collect(I) == iterate_elements(I) == simd_iterate_elements(I) == index_elements(I)
553+
for fun in (i -> 1:i, i -> 1:2:2i, i -> Base.IdentityUnitRange(-i:i))
554+
iter = CartesianIndices(map(fun, dims))
555+
P = partition(iter, part)
556+
for I in P
557+
@test length(I) == iterate_length(I) == simd_iterate_length(I) == simd_trip_count(I)
558+
@test collect(I) == iterate_elements(I) == simd_iterate_elements(I) == index_elements(I)
559+
end
560+
@test all(Base.splat(==), zip(Iterators.flatten(map(collect, P)), iter))
557561
end
558-
@test all(Base.splat(==), zip(Iterators.flatten(map(collect, P)), CartesianIndices(dims)))
559562
end
560563
@testset "empty/invalid partitions" begin
561564
@test_throws ArgumentError partition(1:10, 0)

0 commit comments

Comments
 (0)