Skip to content

Commit 69a82b5

Browse files
committed
Quasi-linear indexing for reshaped reinterpretarray
1 parent bc2abbe commit 69a82b5

File tree

3 files changed

+299
-45
lines changed

3 files changed

+299
-45
lines changed

base/reinterpretarray.jl

Lines changed: 183 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -175,8 +175,96 @@ function check_writable(a::ReinterpretArray{T, N, S} where N) where {T,S}
175175
end
176176
end
177177

178-
IndexStyle(a::ReinterpretArray) = IndexStyle(a.parent)
179-
IndexStyle(a::ReshapedReinterpretArray{T, N, S}) where {T, N, S} = sizeof(T) < sizeof(S) ? IndexCartesian() : IndexStyle(a.parent)
178+
## IndexStyle specializations
179+
180+
# For `reinterpret(reshape, T, a)` where we're adding a channel dimension and with
181+
# `IndexStyle(a) == IndexLinear()`, it's advantageous to retain pseudo-linear indexing.
182+
# Static-sized 2d cartesian index style; K = sizeof(S) ÷ sizeof(T).
struct IndexSCartesian2{K,N} <: IndexStyle
end
183+
184+
# Non-reshaped reinterpret arrays simply inherit the index style of the parent type.
function IndexStyle(::Type{ReinterpretArray{T,N,S,A,false}}) where {T,N,S,A<:AbstractArray{S,N}}
    return IndexStyle(A)
end
185+
# Index style for reshaped reinterpret arrays.
# When `T` is smaller than `S`, a linear-indexed parent yields `IndexSCartesian2`
# (with K = sizeof(S) ÷ sizeof(T)) and any other parent falls back to
# `IndexCartesian`. Otherwise the parent's style is passed through.
function IndexStyle(::Type{ReinterpretArray{T,N,S,A,true}}) where {T,N,S,A<:AbstractArray{S}}
    parentstyle = IndexStyle(A)
    sizeof(T) < sizeof(S) || return parentstyle
    if parentstyle === IndexLinear()
        return IndexSCartesian2{sizeof(S) ÷ sizeof(T),N}()
    end
    return IndexCartesian()
end
192+
# Combining two identical `IndexSCartesian2` styles keeps that style.
function IndexStyle(::IndexSCartesian2{K,N}, ::IndexSCartesian2{K,N}) where {K,N}
    return IndexSCartesian2{K,N}()
end
193+
194+
# Two-component index used with `IndexSCartesian2`.
struct SCartesianIndex2{K,N} <: AbstractCartesianIndex{N}
    i::Int  # position along the first (length-K) dimension
    j::Int  # index into the parent array
end
198+
# An `SCartesianIndex2` is already a valid index; return it unchanged.
function to_index(i::SCartesianIndex2)
    return i
end
199+
200+
# Matrix-shaped iterator over `SCartesianIndex2{K,N}` values: K rows by
# `length(indices2)` columns, where `indices2` is the range of second-component
# (`j`) values.
#
# The element type is declared with both parameters (`SCartesianIndex2{K,N}`) so
# that `eltype` is concrete and agrees with what `getindex`, `iterate`, `first`
# and `last` actually produce. Declaring only `SCartesianIndex2{K}` would leave a
# non-concrete (UnionAll) element type.
struct SCartesianIndices2{K,N,R<:AbstractUnitRange{Int}} <: AbstractMatrix{SCartesianIndex2{K,N}}
    indices2::R
end
203+
# Convenience constructor that infers the range type parameter from `indices2`.
# K must be an `Int` greater than 1.
function SCartesianIndices2{K,N}(indices2::AbstractUnitRange{Int}) where {K,N}
    @assert K::Int > 1
    return SCartesianIndices2{K,N,typeof(indices2)}(indices2)
end
204+
205+
# Indices for `IndexSCartesian2`: pair 1:K with the linear indices of `parent(A)`.
function eachindex(::IndexSCartesian2{K,N}, A::AbstractArray) where {K,N}
    parentinds = eachindex(IndexLinear(), parent(A))
    return SCartesianIndices2{K,N}(parentinds)
end
206+
207+
# K rows, one column per element of `indices2`.
function size(iter::SCartesianIndices2{K}) where K
    return (K, length(iter.indices2))
end
208+
# First axis is 1:K; the second axis is `indices2` itself.
function axes(iter::SCartesianIndices2{K}) where K
    return (Base.OneTo(K), iter.indices2)
end
209+
210+
# First index: row 1 of the first column.
function first(iter::SCartesianIndices2{K,N}) where {K,N}
    return SCartesianIndex2{K,N}(1, first(iter.indices2))
end
211+
# Last index: row K of the last column.
function last(iter::SCartesianIndices2{K,N}) where {K,N}
    return SCartesianIndex2{K,N}(K, last(iter.indices2))
end
212+
213+
# Look up the index at row `i`, column `j` of the iterator (bounds-checked unless
# the caller elides the check).
@inline function getindex(iter::SCartesianIndices2{K,N}, i::Int, j::Int) where {K,N}
    @boundscheck checkbounds(iter, i, j)
    col = iter.indices2[j]
    return SCartesianIndex2{K,N}(i, col)
end
217+
218+
# Start iteration: row 1 of the first column, or `nothing` when there are no columns.
# The state is `(row, column value, state of the column iterator)`.
function iterate(iter::SCartesianIndices2{K,N}) where {K,N}
    next = iterate(iter.indices2)
    if next === nothing
        return nothing
    end
    col, colstate = next
    return SCartesianIndex2{K,N}(1, col), (1, col, colstate)
end
224+
225+
# Advance iteration: walk rows 1:K within the current column, then move on to
# the next column (restarting at row 1) until the columns are exhausted.
function iterate(iter::SCartesianIndices2{K,N}, state) where {K,N}
    row, col, colstate = state
    if row >= K
        # Current column finished; step the column iterator.
        next = iterate(iter.indices2, colstate)
        if next === nothing
            return nothing
        end
        newcol, newstate = next
        return SCartesianIndex2{K,N}(1, newcol), (1, newcol, newstate)
    end
    nextrow = row + 1
    return SCartesianIndex2{K,N}(nextrow, col), (nextrow, col, colstate)
end
235+
236+
# The outer `@simd` loop runs over the columns (`indices2`).
function SimdLoop.simd_outer_range(iter::SCartesianIndices2)
    return iter.indices2
end
237+
# The inner `@simd` loop always has K iterations.
function SimdLoop.simd_inner_length(::SCartesianIndices2{K}, ::Any) where K
    return K
end
238+
# Build the index for the zero-based inner counter `I1` within outer item `Ilast`.
@inline function SimdLoop.simd_index(::SCartesianIndices2{K,N}, Ilast::Int, I1::Int) where {K,N}
    row = I1 + 1
    return SCartesianIndex2{K,N}(row, Ilast)
end
241+
242+
# No reshaping is needed for `IndexSCartesian2`; hand back the array itself.
function _maybe_reshape(::IndexSCartesian2, A::ReshapedReinterpretArray, I...)
    return A
end
243+
244+
# fallbacks
245+
# Full set of N integer indices: forward straight to `getindex`.
function _getindex(::IndexSCartesian2, A::AbstractArray{T,N}, I::Vararg{Int, N}) where {T,N}
    @_propagate_inbounds_meta
    return A[I...]
end
249+
# Full set of N integer indices: forward straight to `setindex!`.
function _setindex!(::IndexSCartesian2, A::AbstractArray{T,N}, v, I::Vararg{Int, N}) where {T,N}
    @_propagate_inbounds_meta
    return setindex!(A, v, I...)
end
253+
# fallbacks for array types that use "pass-through" indexing (e.g., `IndexStyle(A) = IndexStyle(parent(A))`)
254+
# but which don't handle SCartesianIndex2
255+
# Read through an `SCartesianIndex2` on an array that has no dedicated method:
# expand `(ind.i, ind.j)` into subscripts for `A`, then index with them.
function _getindex(::IndexSCartesian2, A::AbstractArray{T,N}, ind::SCartesianIndex2) where {T,N}
    @_propagate_inbounds_meta
    subs = _to_subscript_indices(A, ind.i, ind.j)
    return A[subs...]
end
260+
# Write through an `SCartesianIndex2` on an array that has no dedicated method:
# expand `(ind.i, ind.j)` into subscripts for `A`, then store with them.
function _setindex!(::IndexSCartesian2, A::AbstractArray{T,N}, v, ind::SCartesianIndex2) where {T,N}
    @_propagate_inbounds_meta
    subs = _to_subscript_indices(A, ind.i, ind.j)
    return setindex!(A, v, subs...)
end
265+
266+
267+
## AbstractArray interface
180268

181269
# The array being reinterpreted.
function parent(a::ReinterpretArray)
    return a.parent
end
182270
# Data ids are those of the parent array.
function dataids(a::ReinterpretArray)
    return dataids(a.parent)
end
@@ -231,6 +319,19 @@ end
231319
_getindex_ra(a, inds[1], tail(inds))
232320
end
233321

322+
# Read one `T` out of a reshaped reinterpret array via an `SCartesianIndex2`:
# fetch parent element `ind.j`, copy its bytes into a tuple of
# `sizeof(S) ÷ sizeof(T)` values of type `T`, and return entry `ind.i`.
@inline @propagate_inbounds function getindex(a::ReshapedReinterpretArray{T,N,S}, ind::SCartesianIndex2) where {T,N,S}
    check_readable(a)
    nchan = sizeof(S) ÷ sizeof(T)
    dst = Ref{NTuple{nchan,T}}()
    src = Ref{S}(a.parent[ind.j])
    # Keep both boxes rooted while raw pointers into them are in use.
    GC.@preserve dst src begin
        dstptr = Ptr{UInt8}(unsafe_convert(Ref{T}, dst))
        srcptr = Ptr{UInt8}(unsafe_convert(Ref{S}, src))
        _memcpy!(dstptr, srcptr, sizeof(S))
    end
    channels = dst[]
    return channels[ind.i]
end
334+
234335
# Copy `n` bytes from `src` to `dst` using C `memcpy`.
@inline function _memcpy!(dst, src, n)
    return ccall(:memcpy, Cvoid, (Ptr{UInt8}, Ptr{UInt8}, Csize_t), dst, src, n)
end
235336

236337
@inline @propagate_inbounds function _getindex_ra(a::NonReshapedReinterpretArray{T,N,S}, i1::Int, tailinds::TT) where {T,N,S,TT}
@@ -292,9 +393,17 @@ end
292393
if sizeof(T) > sizeof(S)
293394
# Extra dimension in the parent array
294395
n = sizeof(T) ÷ sizeof(S)
295-
for i = 1:n
296-
s[] = a.parent[i, i1, tailinds...]
297-
_memcpy!(tptr + (i-1)*sizeof(S), sptr, sizeof(S))
396+
if isempty(tailinds) && IndexStyle(a.parent) === IndexLinear()
397+
offset = n * (i1 - firstindex(a))
398+
for i = 1:n
399+
s[] = a.parent[i + offset]
400+
_memcpy!(tptr + (i-1)*sizeof(S), sptr, sizeof(S))
401+
end
402+
else
403+
for i = 1:n
404+
s[] = a.parent[i, i1, tailinds...]
405+
_memcpy!(tptr + (i-1)*sizeof(S), sptr, sizeof(S))
406+
end
298407
end
299408
else
300409
# No extra dimension
@@ -334,6 +443,20 @@ end
334443
_setindex_ra!(a, v, inds[1], tail(inds))
335444
end
336445

446+
# Write one `T` into a reshaped reinterpret array via an `SCartesianIndex2`:
# load parent element `ind.j`, overwrite the `sizeof(T)` bytes at slot `ind.i`
# with the converted value, and store the element back. Returns `a`.
@inline @propagate_inbounds function setindex!(a::ReshapedReinterpretArray{T,N,S}, v, ind::SCartesianIndex2) where {T,N,S}
    check_writable(a)
    val = convert(T, v)::T
    src = Ref{T}(val)
    dst = Ref{S}(a.parent[ind.j])
    # Keep both boxes rooted while raw pointers into them are in use.
    GC.@preserve src dst begin
        srcptr = Ptr{UInt8}(unsafe_convert(Ref{T}, src))
        dstptr = Ptr{UInt8}(unsafe_convert(Ref{S}, dst))
        byteoffset = (ind.i-1)*sizeof(T)
        _memcpy!(dstptr + byteoffset, srcptr, sizeof(T))
    end
    a.parent[ind.j] = dst[]
    return a
end
459+
337460
@inline @propagate_inbounds function _setindex_ra!(a::NonReshapedReinterpretArray{T,N,S}, v, i1::Int, tailinds::TT) where {T,N,S,TT}
338461
v = convert(T, v)::T
339462
# Make sure to match the scalar reinterpret if that is applicable
@@ -407,13 +530,21 @@ end
407530
GC.@preserve t s begin
408531
tptr = Ptr{UInt8}(unsafe_convert(Ref{T}, t))
409532
sptr = Ptr{UInt8}(unsafe_convert(Ref{S}, s))
410-
if sizeof(T) >= sizeof(S) == 0
533+
if sizeof(T) >= sizeof(S)
411534
if sizeof(T) > sizeof(S)
412535
# Extra dimension in the parent array
413536
n = sizeof(T) ÷ sizeof(S)
414-
for i = 1:n
415-
_memcpy!(sptr, tptr + (i-1)*sizeof(S), sizeof(S))
416-
a.parent[i, i1, tailinds...] = s[]
537+
if isempty(tailinds) && IndexStyle(a.parent) === IndexLinear()
538+
offset = n * (i1 - firstindex(a))
539+
for i = 1:n
540+
_memcpy!(sptr, tptr + (i-1)*sizeof(S), sizeof(S))
541+
a.parent[i + offset] = s[]
542+
end
543+
else
544+
for i = 1:n
545+
_memcpy!(sptr, tptr + (i-1)*sizeof(S), sizeof(S))
546+
a.parent[i, i1, tailinds...] = s[]
547+
end
417548
end
418549
else
419550
# No extra dimension
@@ -523,3 +654,46 @@ using .Iterators: Stateful
523654
end
524655
return true
525656
end
657+
658+
# Reductions with IndexSCartesian2
659+
660+
# `mapreduce` entry point for `IndexSCartesian2` arrays: an empty index set is
# delegated to `mapreduce_empty_iter`; otherwise the blocked implementation is
# run from the first to the last index.
function _mapreduce(f::F, op::OP, style::IndexSCartesian2{K}, A::AbstractArrayOrBroadcasted) where {F,OP,K}
    inds = eachindex(style, A)
    ncols = size(inds, 2)
    ncols == 0 && return mapreduce_empty_iter(f, op, A, IteratorEltype(A))
    return mapreduce_impl(f, op, A, first(inds), last(inds))
end
669+
670+
# Blocked pairwise `mapreduce` over the index range `ifirst` to `ilast`.
#
# When the range spans fewer than `blksize` columns past `ifirst.j`, it is
# reduced sequentially: the first column, then every remaining column in full
# (rows 1:K). Otherwise the column range is split at its midpoint, each half is
# reduced recursively, and the two results are combined with `op`.
#
# NOTE(review): the sequential branch reads its second element at row 2 and then
# continues from row `ifirst.i + 2`, so it relies on `ifirst.i == 1`. Both call
# sites visible in this file pass row 1.
@noinline function mapreduce_impl(f::F, op::OP, A::AbstractArrayOrBroadcasted,
                                  ifirst::SCI, ilast::SCI, blksize::Int) where {F,OP,SCI<:SCartesianIndex2{K}} where K
    # Compare via the difference instead of `ifirst.j + blksize > ilast.j`,
    # which can overflow for indices close to `typemax(Int)`.
    if ilast.j - ifirst.j < blksize
        # sequential portion
        @inbounds a1 = A[ifirst]
        @inbounds a2 = A[SCI(2,ifirst.j)]
        v = op(f(a1), f(a2))
        @simd for i = ifirst.i + 2 : K
            @inbounds ai = A[SCI(i,ifirst.j)]
            v = op(v, f(ai))
        end
        # Remaining columns
        for j = ifirst.j+1 : ilast.j
            @simd for i = 1:K
                @inbounds ai = A[SCI(i,j)]
                v = op(v, f(ai))
            end
        end
        return v
    else
        # pairwise portion
        # Overflow-safe midpoint; equal to `(ifirst.j + ilast.j) >> 1` whenever
        # that sum does not overflow.
        jmid = ifirst.j + ((ilast.j - ifirst.j) >> 1)
        v1 = mapreduce_impl(f, op, A, ifirst, SCI(K,jmid), blksize)
        v2 = mapreduce_impl(f, op, A, SCI(1,jmid+1), ilast, blksize)
        return op(v1, v2)
    end
end
697+
698+
# Convenience method: use the block size chosen by `pairwise_blocksize(f, op)`.
function mapreduce_impl(f::F, op::OP, A::AbstractArrayOrBroadcasted, ifirst::SCartesianIndex2, ilast::SCartesianIndex2) where {F,OP}
    blksize = pairwise_blocksize(f, op)
    return mapreduce_impl(f, op, A, ifirst, ilast, blksize)
end

0 commit comments

Comments
 (0)