Skip to content

Commit bc2abbe

Browse files
committed
Add reinterpret(reshape, T, a)
This addresses longstanding performance problems with `reinterpret` when `sizeof(eltype(a))` is an integer multiple of `sizeof(T)`. By reshaping the array to have an extra "channel dimension," LLVM can unroll the inner loop thanks to static size information. Conversely, this consumes the initial "channel dimension" if `sizeof(T)` is an integer multiple of `sizeof(eltype(a))`.
1 parent 4f0145b commit bc2abbe

File tree

4 files changed

+237
-27
lines changed

4 files changed

+237
-27
lines changed

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,8 @@ New library features
8383

8484
* The `redirect_*` functions can now be called on `IOContext` objects.
8585
* New constructor `NamedTuple(iterator)` that constructs a named tuple from a key-value pair iterator.
86+
* A new `reinterpret(reshape, T, a::AbstractArray{S})` reinterprets `a` to have eltype `T` while potentially
87+
inserting or consuming the first dimension depending on the ratio of `sizeof(T)` and `sizeof(S)`.
8688

8789
Standard library changes
8890
------------------------

base/reinterpretarray.jl

Lines changed: 180 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,36 +3,39 @@
33
"""
44
Gives a reinterpreted view (of element type T) of the underlying array (of element type S).
55
If the size of `T` differs from the size of `S`, the array will be compressed/expanded in
6-
the first dimension.
6+
the first dimension. The variant `reinterpret(reshape, T, a)` instead adds or consumes the first dimension
7+
depending on the ratio of element sizes.
78
"""
8-
struct ReinterpretArray{T,N,S,A<:AbstractArray{S, N}} <: AbstractArray{T, N}
9+
struct ReinterpretArray{T,N,S,A<:AbstractArray{S},IsReshaped} <: AbstractArray{T, N}
910
parent::A
1011
readable::Bool
1112
writable::Bool
13+
14+
function throwbits(S::Type, T::Type, U::Type)
15+
@_noinline_meta
16+
throw(ArgumentError("cannot reinterpret `$(S)` as `$(T)`, type `$(U)` is not a bits type"))
17+
end
18+
function throwsize0(S::Type, T::Type, msg)
19+
@_noinline_meta
20+
throw(ArgumentError("cannot reinterpret a zero-dimensional `$(S)` array to `$(T)` which is of a $msg size"))
21+
end
22+
1223
global reinterpret
1324
function reinterpret(::Type{T}, a::A) where {T,N,S,A<:AbstractArray{S, N}}
14-
function throwbits(::Type{S}, ::Type{T}, ::Type{U}) where {S,T,U}
15-
@_noinline_meta
16-
throw(ArgumentError("cannot reinterpret `$(S)` `$(T)`, type `$(U)` is not a bits type"))
17-
end
18-
function throwsize0(::Type{S}, ::Type{T})
19-
@_noinline_meta
20-
throw(ArgumentError("cannot reinterpret a zero-dimensional `$(S)` array to `$(T)` which is of a different size"))
21-
end
22-
function thrownonint(::Type{S}, ::Type{T}, dim)
25+
function thrownonint(S::Type, T::Type, dim)
2326
@_noinline_meta
2427
throw(ArgumentError("""
2528
cannot reinterpret an `$(S)` array to `$(T)` whose first dimension has size `$(dim)`.
2629
The resulting array would have non-integral first dimension.
2730
"""))
2831
end
29-
function throwaxes1(::Type{S}, ::Type{T}, ax1)
32+
function throwaxes1(S::Type, T::Type, ax1)
3033
@_noinline_meta
3134
throw(ArgumentError("cannot reinterpret a `$(S)` array to `$(T)` when the first axis is $ax1. Try reshaping first."))
3235
end
3336
isbitstype(T) || throwbits(S, T, T)
3437
isbitstype(S) || throwbits(S, T, S)
35-
(N != 0 || sizeof(T) == sizeof(S)) || throwsize0(S, T)
38+
(N != 0 || sizeof(T) == sizeof(S)) || throwsize0(S, T, "different")
3639
if N != 0 && sizeof(S) != sizeof(T)
3740
ax1 = axes(a)[1]
3841
dim = length(ax1)
@@ -41,15 +44,82 @@ struct ReinterpretArray{T,N,S,A<:AbstractArray{S, N}} <: AbstractArray{T, N}
4144
end
4245
readable = array_subpadding(T, S)
4346
writable = array_subpadding(S, T)
44-
new{T, N, S, A}(a, readable, writable)
47+
new{T, N, S, A, false}(a, readable, writable)
48+
end
49+
50+
# With reshaping
51+
function reinterpret(::typeof(reshape), ::Type{T}, a::A) where {T,S,A<:AbstractArray{S}}
52+
function throwintmult(S::Type, T::Type)
53+
@_noinline_meta
54+
throw(ArgumentError("`reinterpret(reshape, T, a)` requires that one of `sizeof(T)` (got $(sizeof(T))) and `sizeof(eltype(a))` (got $(sizeof(S))) be an integer multiple of the other"))
55+
end
56+
function throwsize1(a::AbstractArray, T::Type)
57+
@_noinline_meta
58+
throw(ArgumentError("`reinterpret(reshape, $T, a)` where `eltype(a)` is $(eltype(a)) requires that `axes(a, 1)` (got $(axes(a, 1))) be equal to 1:$(sizeof(T) ÷ sizeof(eltype(a))) (from the ratio of element sizes)"))
59+
end
60+
isbitstype(T) || throwbits(S, T, T)
61+
isbitstype(S) || throwbits(S, T, S)
62+
if sizeof(S) == sizeof(T)
63+
N = ndims(a)
64+
elseif sizeof(S) > sizeof(T)
65+
rem(sizeof(S), sizeof(T)) == 0 || throwintmult(S, T)
66+
N = ndims(a) + 1
67+
else
68+
rem(sizeof(T), sizeof(S)) == 0 || throwintmult(S, T)
69+
N = ndims(a) - 1
70+
N > -1 || throwsize0(S, T, "larger")
71+
axes(a, 1) == Base.OneTo(sizeof(T) ÷ sizeof(S)) || throwsize1(a, T)
72+
end
73+
readable = array_subpadding(T, S)
74+
writable = array_subpadding(S, T)
75+
new{T, N, S, A, true}(a, readable, writable)
4576
end
4677
end
4778

79+
ReshapedReinterpretArray{T,N,S,A<:AbstractArray{S}} = ReinterpretArray{T,N,S,A,true}
80+
NonReshapedReinterpretArray{T,N,S,A<:AbstractArray{S, N}} = ReinterpretArray{T,N,S,A,false}
81+
82+
"""
83+
reinterpret(reshape, T, A::AbstractArray{S}) -> B
84+
85+
Change the type-interpretation of `A` while consuming or adding a "channel dimension."
86+
87+
If `sizeof(T) = n*sizeof(S)` for `n>1`, `A`'s first dimension must be
88+
of size `n` and `B` lacks `A`'s first dimension. Conversely, if `sizeof(S) = n*sizeof(T)` for `n>1`,
89+
`B` gets a new first dimension of size `n`. The dimensionality is unchanged if `sizeof(T) == sizeof(S)`.
90+
91+
# Examples
92+
93+
```jldoctest
94+
julia> A = [1 2; 3 4]
95+
2×2 Matrix{$Int}:
96+
1 2
97+
3 4
98+
99+
julia> reinterpret(reshape, Complex{Int}, A) # the result is a vector
100+
2-element reinterpret(reshape, Complex{$Int}, ::Matrix{$Int}):
101+
1 + 3im
102+
2 + 4im
103+
104+
julia> a = [(1,2,3), (4,5,6)]
105+
2-element Vector{Tuple{$Int, $Int, $Int}}:
106+
(1, 2, 3)
107+
(4, 5, 6)
108+
109+
julia> reinterpret(reshape, Int, a) # the result is a matrix
110+
3×2 reinterpret(reshape, $Int, ::Vector{Tuple{$Int, $Int, $Int}}):
111+
1 4
112+
2 5
113+
3 6
114+
```
115+
"""
116+
reinterpret(::typeof(reshape), T::Type, a::AbstractArray)
117+
48118
reinterpret(::Type{T}, a::ReinterpretArray) where {T} = reinterpret(T, a.parent)
49119

50120
# Definition of StridedArray
51121
StridedFastContiguousSubArray{T,N,A<:DenseArray} = FastContiguousSubArray{T,N,A}
52-
StridedReinterpretArray{T,N,A<:Union{DenseArray,StridedFastContiguousSubArray}} = ReinterpretArray{T,N,S,A} where S
122+
StridedReinterpretArray{T,N,A<:Union{DenseArray,StridedFastContiguousSubArray},IsReshaped} = ReinterpretArray{T,N,S,A,IsReshaped} where S
53123
StridedReshapedArray{T,N,A<:Union{DenseArray,StridedFastContiguousSubArray,StridedReinterpretArray}} = ReshapedArray{T,N,A}
54124
StridedSubArray{T,N,A<:Union{DenseArray,StridedReshapedArray,StridedReinterpretArray},
55125
I<:Tuple{Vararg{Union{RangeIndex, ReshapedUnitRange, AbstractCartesianIndex}}}} = SubArray{T,N,A,I}
@@ -106,30 +176,43 @@ function check_writable(a::ReinterpretArray{T, N, S} where N) where {T,S}
106176
end
107177

108178
IndexStyle(a::ReinterpretArray) = IndexStyle(a.parent)
179+
IndexStyle(a::ReshapedReinterpretArray{T, N, S}) where {T, N, S} = sizeof(T) < sizeof(S) ? IndexCartesian() : IndexStyle(a.parent)
109180

110181
parent(a::ReinterpretArray) = a.parent
111182
dataids(a::ReinterpretArray) = dataids(a.parent)
112183
unaliascopy(a::ReinterpretArray{T}) where {T} = reinterpret(T, unaliascopy(a.parent))
113184

114-
function size(a::ReinterpretArray{T,N,S} where {N}) where {T,S}
185+
function size(a::NonReshapedReinterpretArray{T,N,S} where {N}) where {T,S}
115186
psize = size(a.parent)
116187
size1 = div(psize[1]*sizeof(S), sizeof(T))
117188
tuple(size1, tail(psize)...)
118189
end
119-
size(a::ReinterpretArray{T,0}) where {T} = ()
190+
function size(a::ReshapedReinterpretArray{T,N,S} where {N}) where {T,S}
191+
psize = size(a.parent)
192+
sizeof(S) > sizeof(T) && return (div(sizeof(S), sizeof(T)), psize...)
193+
sizeof(S) < sizeof(T) && return Base.tail(psize)
194+
return psize
195+
end
196+
size(a::NonReshapedReinterpretArray{T,0}) where {T} = ()
120197

121-
function axes(a::ReinterpretArray{T,N,S} where {N}) where {T,S}
198+
function axes(a::NonReshapedReinterpretArray{T,N,S} where {N}) where {T,S}
122199
paxs = axes(a.parent)
123200
f, l = first(paxs[1]), length(paxs[1])
124201
size1 = div(l*sizeof(S), sizeof(T))
125202
tuple(oftype(paxs[1], f:f+size1-1), tail(paxs)...)
126203
end
127-
axes(a::ReinterpretArray{T,0}) where {T} = ()
204+
function axes(a::ReshapedReinterpretArray{T,N,S} where {N}) where {T,S}
205+
paxs = axes(a.parent)
206+
sizeof(S) > sizeof(T) && return (Base.OneTo(div(sizeof(S), sizeof(T))), paxs...)
207+
sizeof(S) < sizeof(T) && return Base.tail(paxs)
208+
return paxs
209+
end
210+
axes(a::NonReshapedReinterpretArray{T,0}) where {T} = ()
128211

129212
elsize(::Type{<:ReinterpretArray{T}}) where {T} = sizeof(T)
130213
unsafe_convert(::Type{Ptr{T}}, a::ReinterpretArray{T,N,S} where N) where {T,S} = Ptr{T}(unsafe_convert(Ptr{S},a.parent))
131214

132-
@inline @propagate_inbounds getindex(a::ReinterpretArray{T,0}) where {T} = reinterpret(T, a.parent[])
215+
@inline @propagate_inbounds getindex(a::NonReshapedReinterpretArray{T,0}) where {T} = reinterpret(T, a.parent[])
133216
@inline @propagate_inbounds getindex(a::ReinterpretArray) = a[1]
134217

135218
@inline @propagate_inbounds function getindex(a::ReinterpretArray{T,N,S}, inds::Vararg{Int, N}) where {T,N,S}
@@ -150,7 +233,7 @@ end
150233

151234
@inline _memcpy!(dst, src, n) = ccall(:memcpy, Cvoid, (Ptr{UInt8}, Ptr{UInt8}, Csize_t), dst, src, n)
152235

153-
@inline @propagate_inbounds function _getindex_ra(a::ReinterpretArray{T,N,S}, i1::Int, tailinds::TT) where {T,N,S,TT}
236+
@inline @propagate_inbounds function _getindex_ra(a::NonReshapedReinterpretArray{T,N,S}, i1::Int, tailinds::TT) where {T,N,S,TT}
154237
# Make sure to match the scalar reinterpret if that is applicable
155238
if sizeof(T) == sizeof(S) && (fieldcount(T) + fieldcount(S)) == 0
156239
return reinterpret(T, a.parent[i1, tailinds...])
@@ -194,8 +277,47 @@ end
194277
end
195278
end
196279

280+
@inline @propagate_inbounds function _getindex_ra(a::ReshapedReinterpretArray{T,N,S}, i1::Int, tailinds::TT) where {T,N,S,TT}
281+
# Make sure to match the scalar reinterpret if that is applicable
282+
if sizeof(T) == sizeof(S) && (fieldcount(T) + fieldcount(S)) == 0
283+
return reinterpret(T, a.parent[i1, tailinds...])
284+
end
285+
@boundscheck checkbounds(a, i1, tailinds...)
286+
if sizeof(T) >= sizeof(S)
287+
t = Ref{T}()
288+
s = Ref{S}()
289+
GC.@preserve t s begin
290+
tptr = Ptr{UInt8}(unsafe_convert(Ref{T}, t))
291+
sptr = Ptr{UInt8}(unsafe_convert(Ref{S}, s))
292+
if sizeof(T) > sizeof(S)
293+
# Extra dimension in the parent array
294+
n = sizeof(T) ÷ sizeof(S)
295+
for i = 1:n
296+
s[] = a.parent[i, i1, tailinds...]
297+
_memcpy!(tptr + (i-1)*sizeof(S), sptr, sizeof(S))
298+
end
299+
else
300+
# No extra dimension
301+
s[] = a.parent[i1, tailinds...]
302+
_memcpy!(tptr, sptr, sizeof(S))
303+
end
304+
end
305+
return t[]
306+
end
307+
# S is bigger than T and contains an integer number of them
308+
n = sizeof(S) ÷ sizeof(T)
309+
t = Ref{NTuple{n,T}}()
310+
s = Ref{S}()
311+
GC.@preserve t s begin
312+
tptr = Ptr{UInt8}(unsafe_convert(Ref{T}, t))
313+
sptr = Ptr{UInt8}(unsafe_convert(Ref{S}, s))
314+
s[] = a.parent[tailinds...]
315+
_memcpy!(tptr, sptr, sizeof(S))
316+
end
317+
return t[][i1]
318+
end
197319

198-
@inline @propagate_inbounds setindex!(a::ReinterpretArray{T,0,S} where T, v) where {S} = (a.parent[] = reinterpret(S, v))
320+
@inline @propagate_inbounds setindex!(a::NonReshapedReinterpretArray{T,0,S} where T, v) where {S} = (a.parent[] = reinterpret(S, v))
199321
@inline @propagate_inbounds setindex!(a::ReinterpretArray, v) = (a[1] = v)
200322

201323
@inline @propagate_inbounds function setindex!(a::ReinterpretArray{T,N,S}, v, inds::Vararg{Int, N}) where {T,N,S}
@@ -212,7 +334,7 @@ end
212334
_setindex_ra!(a, v, inds[1], tail(inds))
213335
end
214336

215-
@inline @propagate_inbounds function _setindex_ra!(a::ReinterpretArray{T,N,S}, v, i1::Int, tailinds::TT) where {T,N,S,TT}
337+
@inline @propagate_inbounds function _setindex_ra!(a::NonReshapedReinterpretArray{T,N,S}, v, i1::Int, tailinds::TT) where {T,N,S,TT}
216338
v = convert(T, v)::T
217339
# Make sure to match the scalar reinterpret if that is applicable
218340
if sizeof(T) == sizeof(S) && (fieldcount(T) + fieldcount(S)) == 0
@@ -273,6 +395,41 @@ end
273395
return a
274396
end
275397

398+
@inline @propagate_inbounds function _setindex_ra!(a::ReshapedReinterpretArray{T,N,S}, v, i1::Int, tailinds::TT) where {T,N,S,TT}
399+
v = convert(T, v)::T
400+
# Make sure to match the scalar reinterpret if that is applicable
401+
if sizeof(T) == sizeof(S) && (fieldcount(T) + fieldcount(S)) == 0
402+
return setindex!(a.parent, reinterpret(S, v), i1, tailinds...)
403+
end
404+
@boundscheck checkbounds(a, i1, tailinds...)
405+
t = Ref{T}(v)
406+
s = Ref{S}()
407+
GC.@preserve t s begin
408+
tptr = Ptr{UInt8}(unsafe_convert(Ref{T}, t))
409+
sptr = Ptr{UInt8}(unsafe_convert(Ref{S}, s))
410+
if sizeof(T) >= sizeof(S) == 0
411+
if sizeof(T) > sizeof(S)
412+
# Extra dimension in the parent array
413+
n = sizeof(T) ÷ sizeof(S)
414+
for i = 1:n
415+
_memcpy!(sptr, tptr + (i-1)*sizeof(S), sizeof(S))
416+
a.parent[i, i1, tailinds...] = s[]
417+
end
418+
else
419+
# No extra dimension
420+
_memcpy!(sptr, tptr, sizeof(S))
421+
a.parent[i1, tailinds...] = s[]
422+
end
423+
else
424+
# S is bigger than T and contains an integer number of them
425+
s[] = a.parent[tailinds...]
426+
_memcpy!(sptr + (i1-1)*sizeof(T), tptr, sizeof(T))
427+
a.parent[tailinds...] = s[]
428+
end
429+
end
430+
return a
431+
end
432+
276433
# Padding
277434
struct Padding
278435
offset::Int

base/show.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2562,12 +2562,18 @@ function showarg(io::IO, r::ReshapedArray, toplevel)
25622562
toplevel && print(io, " with eltype ", eltype(r))
25632563
end
25642564

2565-
function showarg(io::IO, r::ReinterpretArray{T}, toplevel) where {T}
2565+
function showarg(io::IO, r::NonReshapedReinterpretArray{T}, toplevel) where {T}
25662566
print(io, "reinterpret(", T, ", ")
25672567
showarg(io, parent(r), false)
25682568
print(io, ')')
25692569
end
25702570

2571+
function showarg(io::IO, r::ReshapedReinterpretArray{T}, toplevel) where {T}
2572+
print(io, "reinterpret(reshape, ", T, ", ")
2573+
showarg(io, parent(r), false)
2574+
print(io, ')')
2575+
end
2576+
25712577
# printing iterators from Base.Iterators
25722578

25732579
function show(io::IO, e::Iterators.Enumerate)

0 commit comments

Comments
 (0)