@@ -175,8 +175,96 @@ function check_writable(a::ReinterpretArray{T, N, S} where N) where {T,S}
175
175
end
176
176
end
177
177
178
- IndexStyle (a:: ReinterpretArray ) = IndexStyle (a. parent)
179
- IndexStyle (a:: ReshapedReinterpretArray{T, N, S} ) where {T, N, S} = sizeof (T) < sizeof (S) ? IndexCartesian () : IndexStyle (a. parent)
178
## IndexStyle specializations
179
+
180
+ # For `reinterpret(reshape, T, a)` where we're adding a channel dimension and with
181
+ # `IndexStyle(a) == IndexLinear()`, it's advantageous to retain pseudo-linear indexing.
182
# Index style selecting a static-sized 2d cartesian iterator.
# K = sizeof(S) ÷ sizeof(T); N is carried along as a second type parameter.
struct IndexSCartesian2{K,N} <: IndexStyle
end
183
+
184
# When the fifth type parameter is `false` (presumably the non-reshaped case —
# TODO confirm), the index style is that of the parent array type `A`.
function IndexStyle(::Type{ReinterpretArray{T,N,S,A,false}}) where {T,N,S,A<:AbstractArray{S,N}}
    return IndexStyle(A)
end
185
# Index style when the fifth type parameter is `true`.
# If `T` is smaller than `S`, a parent with linear indexing gets the
# pseudo-linear `IndexSCartesian2` style, and any other parent gets
# `IndexCartesian`. Otherwise the parent's style is used unchanged.
function IndexStyle(::Type{ReinterpretArray{T,N,S,A,true}}) where {T,N,S,A<:AbstractArray{S}}
    parentstyle = IndexStyle(A)
    sizeof(T) < sizeof(S) || return parentstyle
    if parentstyle === IndexLinear()
        return IndexSCartesian2{sizeof(S) ÷ sizeof(T),N}()
    end
    return IndexCartesian()
end
192
# Combining two identical `IndexSCartesian2{K,N}` styles yields that same style.
function IndexStyle(::IndexSCartesian2{K,N}, ::IndexSCartesian2{K,N}) where {K,N}
    return IndexSCartesian2{K,N}()
end
193
+
194
# Two-component index `(i, j)` produced by `SCartesianIndices2{K,N}`:
# the iterator there advances `i` up to `K` for each `j` taken from its range.
struct SCartesianIndex2{K,N} <: AbstractCartesianIndex{N}
    i::Int  # first (inner) component
    j::Int  # second (outer) component
end
198
# An `SCartesianIndex2` is already a usable index: pass it through unchanged.
function to_index(i::SCartesianIndex2)
    return i
end
199
+
200
# Matrix-shaped iterator of `SCartesianIndex2{K,N}` values: `K` rows by
# `length(indices2)` columns, where the second index component is drawn from
# `indices2`.
#
# The element type is declared as the concrete `SCartesianIndex2{K,N}`, which
# is exactly what `getindex`, `iterate`, `first`, and `last` produce. Declaring
# only `SCartesianIndex2{K}` would make `eltype` a non-concrete UnionAll.
struct SCartesianIndices2{K,N,R<:AbstractUnitRange{Int}} <: AbstractMatrix{SCartesianIndex2{K,N}}
    indices2::R
end

# Convenience constructor that infers the range type parameter `R`.
function SCartesianIndices2{K,N}(indices2::AbstractUnitRange{Int}) where {K,N}
    # Internal invariant: K must be an `Int` greater than 1.
    @assert K::Int > 1
    return SCartesianIndices2{K,N,typeof(indices2)}(indices2)
end
204
+
205
# Indices for the `IndexSCartesian2` style: an `SCartesianIndices2` whose second
# component runs over the linear indices of `parent(A)`.
function eachindex(::IndexSCartesian2{K,N}, A::AbstractArray) where {K,N}
    linear = eachindex(IndexLinear(), parent(A))
    return SCartesianIndices2{K,N}(linear)
end
206
+
207
# Shape of the iterator: `K` rows, one column per entry of `indices2`.
function size(iter::SCartesianIndices2{K}) where K
    return (K, length(iter.indices2))
end

# Axes of the iterator: `OneTo(K)` for rows, `indices2` itself for columns.
function axes(iter::SCartesianIndices2{K}) where K
    return (Base.OneTo(K), iter.indices2)
end
209
+
210
# First index: row 1 paired with the first entry of `indices2`.
function first(iter::SCartesianIndices2{K,N}) where {K,N}
    return SCartesianIndex2{K,N}(1, first(iter.indices2))
end

# Last index: row `K` paired with the last entry of `indices2`.
function last(iter::SCartesianIndices2{K,N}) where {K,N}
    return SCartesianIndex2{K,N}(K, last(iter.indices2))
end
212
+
213
# Element `(i, j)` of the iterator: row `i` paired with the `j`-th entry of
# `indices2`. Bounds are checked unless elided by the caller.
@inline function getindex(iter::SCartesianIndices2{K,N}, i::Int, j::Int) where {K,N}
    @boundscheck checkbounds(iter, i, j)
    outer = iter.indices2[j]
    return SCartesianIndex2{K,N}(i, outer)
end
217
+
218
# Start iteration: take the first entry of `indices2` and pair it with row 1.
# The state is `(row, current outer item, outer iteration state)`.
function iterate(iter::SCartesianIndices2{K,N}) where {K,N}
    outer = iterate(iter.indices2)
    if outer === nothing
        return nothing
    end
    j, jstate = outer
    return SCartesianIndex2{K,N}(1, j), (1, j, jstate)
end
224
+
225
# Continue iteration: advance the row while it is below `K`; once the row is
# exhausted, step `indices2` and restart at row 1.
function iterate(iter::SCartesianIndices2{K,N}, (state1, item2, state2)) where {K,N}
    if state1 >= K
        # Current column is finished; move on to the next outer item.
        outer = iterate(iter.indices2, state2)
        outer === nothing && return nothing
        j, jstate = outer
        return SCartesianIndex2{K,N}(1, j), (1, j, jstate)
    end
    i = state1 + 1
    return SCartesianIndex2{K,N}(i, item2), (i, item2, state2)
end
235
+
236
# `@simd` support: the outer range is the range of second index components.
function SimdLoop.simd_outer_range(iter::SCartesianIndices2)
    return iter.indices2
end

# `@simd` support: the inner loop always has length `K`.
function SimdLoop.simd_inner_length(::SCartesianIndices2{K}, ::Any) where K
    return K
end

# `@simd` support: build the index from the outer item `Ilast` and the inner
# counter `I1` (shifted by one, so `I1` is presumably zero-based — TODO confirm).
@inline function SimdLoop.simd_index(::SCartesianIndices2{K,N}, Ilast::Int, I1::Int) where {K,N}
    return SCartesianIndex2{K,N}(I1 + 1, Ilast)
end
241
+
242
# With the `IndexSCartesian2` style no reshaping is performed: return `A` as is.
function _maybe_reshape(::IndexSCartesian2, A::ReshapedReinterpretArray, I...)
    return A
end
243
+
244
+ # fallbacks
245
# Fallback: with exactly `N` integer subscripts, forward to ordinary `getindex`.
function _getindex(::IndexSCartesian2, A::AbstractArray{T,N}, I::Vararg{Int, N}) where {T,N}
    @_propagate_inbounds_meta
    return A[I...]
end
249
# Fallback: with exactly `N` integer subscripts, forward to ordinary `setindex!`
# and return whatever it returns.
function _setindex!(::IndexSCartesian2, A::AbstractArray{T,N}, v, I::Vararg{Int, N}) where {T,N}
    @_propagate_inbounds_meta
    return setindex!(A, v, I...)
end
253
+ # fallbacks for array types that use "pass-through" indexing (e.g., `IndexStyle(A) = IndexStyle(parent(A))`)
254
+ # but which don't handle SCartesianIndex2
255
# Fallback read for array types that inherit the `IndexSCartesian2` style but do
# not handle `SCartesianIndex2` themselves.
#
# `ind.i` indexes the first dimension of `A`. `ind.j` comes from
# `eachindex(IndexLinear(), parent(A))`, i.e. it is a *linear* index; it is
# expanded here over the trailing axes of `A` (presumably these match the
# parent's axes for pass-through wrappers — TODO confirm).
# Passing `(ind.i, ind.j)` as plain subscripts would treat `ind.j` as a
# dimension-2 subscript and pin all later dimensions to their first index,
# which is wrong whenever `N > 2`.
function _getindex(::IndexSCartesian2, A::AbstractArray{T,N}, ind::SCartesianIndex2) where {T,N}
    @_propagate_inbounds_meta
    # Two dimensions: `ind.j` already is the second subscript.
    N == 2 && return getindex(A, ind.i, ind.j)
    J = _ind2sub(tail(axes(A)), ind.j)
    return getindex(A, ind.i, J...)
end
260
# Fallback write for array types that inherit the `IndexSCartesian2` style but
# do not handle `SCartesianIndex2` themselves.
#
# `ind.i` indexes the first dimension of `A`. `ind.j` comes from
# `eachindex(IndexLinear(), parent(A))`, i.e. it is a *linear* index; it is
# expanded here over the trailing axes of `A` (presumably these match the
# parent's axes for pass-through wrappers — TODO confirm).
# Passing `(ind.i, ind.j)` as plain subscripts would treat `ind.j` as a
# dimension-2 subscript and pin all later dimensions to their first index,
# which is wrong whenever `N > 2`.
function _setindex!(::IndexSCartesian2, A::AbstractArray{T,N}, v, ind::SCartesianIndex2) where {T,N}
    @_propagate_inbounds_meta
    # Two dimensions: `ind.j` already is the second subscript.
    N == 2 && return setindex!(A, v, ind.i, ind.j)
    J = _ind2sub(tail(axes(A)), ind.j)
    return setindex!(A, v, ind.i, J...)
end
265
+
266
+
267
## AbstractArray interface
180
268
181
269
# The array being reinterpreted.
function parent(a::ReinterpretArray)
    return a.parent
end

# Data identity is delegated to the parent array.
function dataids(a::ReinterpretArray)
    return dataids(a.parent)
end
231
319
_getindex_ra (a, inds[1 ], tail (inds))
232
320
end
233
321
322
# Read one `T` out of a reshaped reinterpret array using an `SCartesianIndex2`.
#
# The parent element at linear index `ind.j` (of type `S`) is copied bytewise
# into a tuple of `n = sizeof(S) ÷ sizeof(T)` values of type `T`, and the
# `ind.i`-th entry of that tuple is returned.
@inline @propagate_inbounds function getindex(a::ReshapedReinterpretArray{T,N,S}, ind::SCartesianIndex2) where {T,N,S}
    check_readable(a)
    n = sizeof(S) ÷ sizeof(T)
    t = Ref{NTuple{n,T}}()
    s = Ref{S}(a.parent[ind.j])
    GC.@preserve t s begin
        # `t` holds an `NTuple{n,T}`, so the pointer must be requested for that
        # exact type; asking for `Ref{T}` does not match the type of `t`.
        tptr = Ptr{UInt8}(unsafe_convert(Ref{NTuple{n,T}}, t))
        sptr = Ptr{UInt8}(unsafe_convert(Ref{S}, s))
        _memcpy!(tptr, sptr, sizeof(S))
    end
    return t[][ind.i]
end
334
+
234
335
# Thin wrapper over C `memcpy`: copy `n` bytes from `src` to `dst`.
@inline function _memcpy!(dst, src, n)
    return ccall(:memcpy, Cvoid, (Ptr{UInt8}, Ptr{UInt8}, Csize_t), dst, src, n)
end
235
336
236
337
@inline @propagate_inbounds function _getindex_ra (a:: NonReshapedReinterpretArray{T,N,S} , i1:: Int , tailinds:: TT ) where {T,N,S,TT}
292
393
if sizeof (T) > sizeof (S)
293
394
# Extra dimension in the parent array
294
395
n = sizeof (T) ÷ sizeof (S)
295
- for i = 1 : n
296
- s[] = a. parent[i, i1, tailinds... ]
297
- _memcpy! (tptr + (i- 1 )* sizeof (S), sptr, sizeof (S))
396
+ if isempty (tailinds) && IndexStyle (a. parent) === IndexLinear ()
397
+ offset = n * (i1 - firstindex (a))
398
+ for i = 1 : n
399
+ s[] = a. parent[i + offset]
400
+ _memcpy! (tptr + (i- 1 )* sizeof (S), sptr, sizeof (S))
401
+ end
402
+ else
403
+ for i = 1 : n
404
+ s[] = a. parent[i, i1, tailinds... ]
405
+ _memcpy! (tptr + (i- 1 )* sizeof (S), sptr, sizeof (S))
406
+ end
298
407
end
299
408
else
300
409
# No extra dimension
334
443
_setindex_ra! (a, v, inds[1 ], tail (inds))
335
444
end
336
445
446
# Write one `T` into a reshaped reinterpret array using an `SCartesianIndex2`.
#
# The parent element at linear index `ind.j` is loaded, the bytes of slot
# `ind.i` inside it are overwritten with the converted value, and the modified
# element is stored back. Returns `a`.
# Note: `ind.i` itself is not range-checked in this method.
@inline @propagate_inbounds function setindex!(a::ReshapedReinterpretArray{T,N,S}, v, ind::SCartesianIndex2) where {T,N,S}
    check_writable(a)
    newval = Ref{T}(convert(T, v)::T)
    packed = Ref{S}(a.parent[ind.j])
    GC.@preserve newval packed begin
        src = Ptr{UInt8}(unsafe_convert(Ref{T}, newval))
        dst = Ptr{UInt8}(unsafe_convert(Ref{S}, packed))
        # Byte offset of slot `ind.i` within the parent element.
        slot = dst + (ind.i - 1) * sizeof(T)
        _memcpy!(slot, src, sizeof(T))
    end
    a.parent[ind.j] = packed[]
    return a
end
459
+
337
460
@inline @propagate_inbounds function _setindex_ra! (a:: NonReshapedReinterpretArray{T,N,S} , v, i1:: Int , tailinds:: TT ) where {T,N,S,TT}
338
461
v = convert (T, v):: T
339
462
# Make sure to match the scalar reinterpret if that is applicable
@@ -407,13 +530,21 @@ end
407
530
GC. @preserve t s begin
408
531
tptr = Ptr {UInt8} (unsafe_convert (Ref{T}, t))
409
532
sptr = Ptr {UInt8} (unsafe_convert (Ref{S}, s))
410
- if sizeof (T) >= sizeof (S) == 0
533
+ if sizeof (T) >= sizeof (S)
411
534
if sizeof (T) > sizeof (S)
412
535
# Extra dimension in the parent array
413
536
n = sizeof (T) ÷ sizeof (S)
414
- for i = 1 : n
415
- _memcpy! (sptr, tptr + (i- 1 )* sizeof (S), sizeof (S))
416
- a. parent[i, i1, tailinds... ] = s[]
537
+ if isempty (tailinds) && IndexStyle (a. parent) === IndexLinear ()
538
+ offset = n * (i1 - firstindex (a))
539
+ for i = 1 : n
540
+ _memcpy! (sptr, tptr + (i- 1 )* sizeof (S), sizeof (S))
541
+ a. parent[i + offset] = s[]
542
+ end
543
+ else
544
+ for i = 1 : n
545
+ _memcpy! (sptr, tptr + (i- 1 )* sizeof (S), sizeof (S))
546
+ a. parent[i, i1, tailinds... ] = s[]
547
+ end
417
548
end
418
549
else
419
550
# No extra dimension
@@ -523,3 +654,46 @@ using .Iterators: Stateful
523
654
end
524
655
return true
525
656
end
657
+
658
+ # Reductions with IndexSCartesian2
659
+
660
# `mapreduce` entry point for arrays with the `IndexSCartesian2` style.
# An iterator with zero columns is handled by `mapreduce_empty_iter`; otherwise
# the reduction runs over the full index range from `first(inds)` to `last(inds)`.
function _mapreduce(f::F, op::OP, style::IndexSCartesian2{K}, A::AbstractArrayOrBroadcasted) where {F,OP,K}
    inds = eachindex(style, A)
    ncols = size(inds)[2]
    ncols == 0 && return mapreduce_empty_iter(f, op, A, IteratorEltype(A))
    return mapreduce_impl(f, op, A, first(inds), last(inds))
end
669
+
670
# Pairwise `mapreduce` over the index range `ifirst` to `ilast` of an array
# indexed by `SCartesianIndex2`.
#
# Ranges spanning fewer than `blksize` columns are reduced sequentially, column
# by column with `K` entries each. Larger ranges are split at the middle column
# and the two halves combined with `op`.
#
# The sequential part seeds the accumulator from `A[ifirst]` and row 2 of the
# same column, so it presumably expects `ifirst.i == 1` and relies on `K >= 2`
# — TODO confirm against callers.
#
# The block-size test and the midpoint are written in difference form so that
# column indices near `typemax(Int)` cannot wrap around.
@noinline function mapreduce_impl(f::F, op::OP, A::AbstractArrayOrBroadcasted,
                                  ifirst::SCI, ilast::SCI, blksize::Int) where {F,OP,SCI<:SCartesianIndex2{K}} where K
    if ilast.j - ifirst.j < blksize
        # sequential portion
        @inbounds a1 = A[ifirst]
        @inbounds a2 = A[SCI(2, ifirst.j)]
        v = op(f(a1), f(a2))
        # Rest of the first column
        @simd for i = ifirst.i + 2 : K
            @inbounds ai = A[SCI(i, ifirst.j)]
            v = op(v, f(ai))
        end
        # Remaining columns
        for j = ifirst.j+1 : ilast.j
            @simd for i = 1:K
                @inbounds ai = A[SCI(i, j)]
                v = op(v, f(ai))
            end
        end
        return v
    else
        # pairwise portion
        jmid = ifirst.j + ((ilast.j - ifirst.j) >> 1)
        v1 = mapreduce_impl(f, op, A, ifirst, SCI(K, jmid), blksize)
        v2 = mapreduce_impl(f, op, A, SCI(1, jmid+1), ilast, blksize)
        return op(v1, v2)
    end
end
697
+
698
# Convenience method: pick the block size from `pairwise_blocksize(f, op)`.
function mapreduce_impl(f::F, op::OP, A::AbstractArrayOrBroadcasted,
                        ifirst::SCartesianIndex2, ilast::SCartesianIndex2) where {F,OP}
    blksize = pairwise_blocksize(f, op)
    return mapreduce_impl(f, op, A, ifirst, ilast, blksize)
end
0 commit comments