Skip to content

Commit 7b1cc4b

Browse files
authored
Allow reinterpreting singleton types (#43500)
1 parent 14154fc commit 7b1cc4b

File tree

2 files changed

+110
-9
lines changed

2 files changed

+110
-9
lines changed

base/reinterpretarray.jl

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ struct ReinterpretArray{T,N,S,A<:AbstractArray{S},IsReshaped} <: AbstractArray{T
1919
@noinline
2020
throw(ArgumentError("cannot reinterpret a zero-dimensional `$(S)` array to `$(T)` which is of a $msg size"))
2121
end
22+
function throwsingleton(S::Type, T::Type, kind)
23+
@noinline
24+
throw(ArgumentError("cannot reinterpret $kind `$(S)` array to `$(T)` which is a singleton type"))
25+
end
2226

2327
global reinterpret
2428
function reinterpret(::Type{T}, a::A) where {T,N,S,A<:AbstractArray{S, N}}
@@ -39,7 +43,11 @@ struct ReinterpretArray{T,N,S,A<:AbstractArray{S},IsReshaped} <: AbstractArray{T
3943
if N != 0 && sizeof(S) != sizeof(T)
4044
ax1 = axes(a)[1]
4145
dim = length(ax1)
42-
rem(dim*sizeof(S),sizeof(T)) == 0 || thrownonint(S, T, dim)
46+
if Base.issingletontype(T)
47+
dim == 0 || throwsingleton(S, T, "a non-empty")
48+
else
49+
rem(dim*sizeof(S),sizeof(T)) == 0 || thrownonint(S, T, dim)
50+
end
4351
first(ax1) == 1 || throwaxes1(S, T, ax1)
4452
end
4553
readable = array_subpadding(T, S)
@@ -58,14 +66,20 @@ struct ReinterpretArray{T,N,S,A<:AbstractArray{S},IsReshaped} <: AbstractArray{T
5866
@noinline
5967
throw(ArgumentError("`reinterpret(reshape, $T, a)` where `eltype(a)` is $(eltype(a)) requires that `axes(a, 1)` (got $(axes(a, 1))) be equal to 1:$(sizeof(T) ÷ sizeof(eltype(a))) (from the ratio of element sizes)"))
6068
end
69+
function throwfromsingleton(S, T)
70+
@noinline
71+
throw(ArgumentError("`reinterpret(reshape, $T, a)` where `eltype(a)` is $S requires that $T be a singleton type, since $S is one"))
72+
end
6173
isbitstype(T) || throwbits(S, T, T)
6274
isbitstype(S) || throwbits(S, T, S)
6375
if sizeof(S) == sizeof(T)
6476
N = ndims(a)
6577
elseif sizeof(S) > sizeof(T)
78+
Base.issingletontype(T) && throwsingleton(S, T, "with reshape a")
6679
rem(sizeof(S), sizeof(T)) == 0 || throwintmult(S, T)
6780
N = ndims(a) + 1
6881
else
82+
Base.issingletontype(S) && throwfromsingleton(S, T)
6983
rem(sizeof(T), sizeof(S)) == 0 || throwintmult(S, T)
7084
N = ndims(a) - 1
7185
N > -1 || throwsize0(S, T, "larger")
@@ -286,7 +300,7 @@ unaliascopy(a::ReshapedReinterpretArray{T}) where {T} = reinterpret(reshape, T,
286300

287301
function size(a::NonReshapedReinterpretArray{T,N,S} where {N}) where {T,S}
288302
psize = size(a.parent)
289-
size1 = div(psize[1]*sizeof(S), sizeof(T))
303+
size1 = Base.issingletontype(T) ? psize[1] : div(psize[1]*sizeof(S), sizeof(T))
290304
tuple(size1, tail(psize)...)
291305
end
292306
function size(a::ReshapedReinterpretArray{T,N,S} where {N}) where {T,S}
@@ -300,7 +314,7 @@ size(a::NonReshapedReinterpretArray{T,0}) where {T} = ()
300314
function axes(a::NonReshapedReinterpretArray{T,N,S} where {N}) where {T,S}
301315
paxs = axes(a.parent)
302316
f, l = first(paxs[1]), length(paxs[1])
303-
size1 = div(l*sizeof(S), sizeof(T))
317+
size1 = Base.issingletontype(T) ? l : div(l*sizeof(S), sizeof(T))
304318
tuple(oftype(paxs[1], f:f+size1-1), tail(paxs)...)
305319
end
306320
function axes(a::ReshapedReinterpretArray{T,N,S} where {N}) where {T,S}
@@ -351,6 +365,10 @@ end
351365
@inline @propagate_inbounds function _getindex_ra(a::NonReshapedReinterpretArray{T,N,S}, i1::Int, tailinds::TT) where {T,N,S,TT}
352366
# Make sure to match the scalar reinterpret if that is applicable
353367
if sizeof(T) == sizeof(S) && (fieldcount(T) + fieldcount(S)) == 0
368+
if Base.issingletontype(T) # singleton types
369+
@boundscheck checkbounds(a, i1, tailinds...)
370+
return T.instance
371+
end
354372
return reinterpret(T, a.parent[i1, tailinds...])
355373
else
356374
@boundscheck checkbounds(a, i1, tailinds...)
@@ -395,6 +413,10 @@ end
395413
@inline @propagate_inbounds function _getindex_ra(a::ReshapedReinterpretArray{T,N,S}, i1::Int, tailinds::TT) where {T,N,S,TT}
396414
# Make sure to match the scalar reinterpret if that is applicable
397415
if sizeof(T) == sizeof(S) && (fieldcount(T) + fieldcount(S)) == 0
416+
if Base.issingletontype(T) # singleton types
417+
@boundscheck checkbounds(a, i1, tailinds...)
418+
return T.instance
419+
end
398420
return reinterpret(T, a.parent[i1, tailinds...])
399421
end
400422
@boundscheck checkbounds(a, i1, tailinds...)
@@ -475,7 +497,12 @@ end
475497
v = convert(T, v)::T
476498
# Make sure to match the scalar reinterpret if that is applicable
477499
if sizeof(T) == sizeof(S) && (fieldcount(T) + fieldcount(S)) == 0
478-
return setindex!(a.parent, reinterpret(S, v), i1, tailinds...)
500+
if Base.issingletontype(T) # singleton types
501+
@boundscheck checkbounds(a, i1, tailinds...)
502+
# setindex! is a noop except for the index check
503+
else
504+
setindex!(a.parent, reinterpret(S, v), i1, tailinds...)
505+
end
479506
else
480507
@boundscheck checkbounds(a, i1, tailinds...)
481508
ind_start, sidx = divrem((i1-1)*sizeof(T), sizeof(S))
@@ -536,7 +563,12 @@ end
536563
v = convert(T, v)::T
537564
# Make sure to match the scalar reinterpret if that is applicable
538565
if sizeof(T) == sizeof(S) && (fieldcount(T) + fieldcount(S)) == 0
539-
return setindex!(a.parent, reinterpret(S, v), i1, tailinds...)
566+
if Base.issingletontype(T) # singleton types
567+
@boundscheck checkbounds(a, i1, tailinds...)
568+
# setindex! is a noop except for the index check
569+
else
570+
setindex!(a.parent, reinterpret(S, v), i1, tailinds...)
571+
end
540572
end
541573
@boundscheck checkbounds(a, i1, tailinds...)
542574
t = Ref{T}(v)

test/reinterpretarray.jl

Lines changed: 73 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,18 +66,19 @@ for (_A, Ar, _B) in ((A, Ars, B), (As, Arss, Bs))
6666
@test Arsc == [1 -1; 2 -2]
6767
reinterpret(NTuple{3, Int64}, Bc)[2] = (4,5,6)
6868
@test Bc == Complex{Int64}[5+6im, 7+4im, 5+6im]
69-
reinterpret(NTuple{3, Int64}, Bc)[1] = (1,2,3)
69+
B2 = reinterpret(NTuple{3, Int64}, Bc)
70+
@test setindex!(B2, (1,2,3), 1) == B2
7071
@test Bc == Complex{Int64}[1+2im, 3+4im, 5+6im]
7172
Bc = copy(_B)
7273
Brrs = reinterpret(reshape, Int64, Bc)
73-
Brrs[2, 3] = -5
74+
@test setindex!(Brrs, -5, 2, 3) == Brrs
7475
@test Bc == Complex{Int64}[5+6im, 7+8im, 9-5im]
7576
Brrs[last(eachindex(Brrs))] = 22
7677
@test Bc == Complex{Int64}[5+6im, 7+8im, 9+22im]
7778

7879
A1 = reinterpret(Float64, _A)
7980
A2 = reinterpret(ComplexF64, _A)
80-
A1[1] = 1.0
81+
@test setindex!(A1, 1.0, 1) == A1
8182
@test real(A2[1]) == 1.0
8283
A1 = reinterpret(reshape, Float64, _A)
8384
A1[1] = 2.5
@@ -88,7 +89,7 @@ for (_A, Ar, _B) in ((A, Ars, B), (As, Arss, Bs))
8889
@test real(A2rs[1]) == 1.0
8990
A1rs = reinterpret(reshape, Float64, Ar)
9091
A2rs = reinterpret(reshape, ComplexF64, Ar)
91-
A1rs[1, 1] = 2.5
92+
@test setindex!(A1rs, 2.5, 1, 1) == A1rs
9293
@test real(A2rs[1]) == 2.5
9394
end
9495
end
@@ -376,3 +377,71 @@ end
376377
a = reinterpret(reshape, NTuple{4,Float64}, rand(Float64, 4, 4))
377378
@test typeof(Base.unaliascopy(a)) === typeof(a)
378379
end
380+
381+
382+
@testset "singleton types" begin
383+
mutable struct NotASingleton end # not a singleton because it is mutable
384+
struct SomeSingleton
385+
# A singleton type that does not have the internal constructor SomeSingleton()
386+
SomeSingleton(x) = new()
387+
end
388+
389+
@test_throws ErrorException reinterpret(Int, nothing)
390+
@test_throws ErrorException reinterpret(Missing, 3)
391+
@test_throws ErrorException reinterpret(Missing, NotASingleton())
392+
@test_throws ErrorException reinterpret(NotASingleton, ())
393+
394+
@test_throws ArgumentError reinterpret(NotASingleton, fill(nothing, ()))
395+
@test_throws ArgumentError reinterpret(reshape, NotASingleton, fill(missing, 3))
396+
@test_throws ArgumentError reinterpret(Tuple{}, fill(NotASingleton(), 2))
397+
@test_throws ArgumentError reinterpret(reshape, Nothing, fill(NotASingleton(), ()))
398+
399+
t = fill(nothing, 3, 5)
400+
@test reinterpret(SomeSingleton, t) == reinterpret(reshape, SomeSingleton, t)
401+
@test reinterpret(SomeSingleton, t) == [SomeSingleton(i*j) for i in 1:3, j in 1:5]
402+
@test reinterpret(Int, t) == fill(17, 0, 5)
403+
@test_throws ArgumentError reinterpret(reshape, Float64, t)
404+
@test_throws ArgumentError reinterpret(Nothing, 1:6)
405+
@test_throws ArgumentError reinterpret(reshape, Missing, [0.0])
406+
407+
# reintepret of empty array with reshape
408+
@test reinterpret(reshape, Nothing, fill(missing, (0,0,0))) == fill(nothing, (0,0,0))
409+
@test_throws ArgumentError reinterpret(reshape, Nothing, fill(3.2, (0,0)))
410+
@test_throws ArgumentError reinterpret(reshape, Float64, fill(nothing, 0))
411+
412+
# reinterpret of 0-dimensional array
413+
z = reinterpret(Tuple{}, fill(missing, ()))
414+
@test z == fill((), ())
415+
@test z == reinterpret(reshape, Tuple{}, fill(nothing, ()))
416+
@test_throws BoundsError z[2]
417+
@test_throws BoundsError z[3] = ()
418+
@test_throws ArgumentError reinterpret(UInt8, fill(nothing, ()))
419+
@test_throws ArgumentError reinterpret(Missing, fill(1f0, ()))
420+
@test_throws ArgumentError reinterpret(reshape, Float64, fill(nothing, ()))
421+
@test_throws ArgumentError reinterpret(reshape, Nothing, fill(17, ()))
422+
423+
424+
@test @inferred(ndims(reinterpret(reshape, SomeSingleton, t))) == 2
425+
@test @inferred(axes(reinterpret(reshape, Tuple{}, t))) == (Base.OneTo(3),Base.OneTo(5))
426+
@test @inferred(size(reinterpret(reshape, Missing, t))) == (3,5)
427+
428+
x = reinterpret(Tuple{}, t)
429+
@test x == reinterpret(reshape, Tuple{}, t)
430+
@test x[3,5] === ()
431+
x1 = fill((), 3, 5)
432+
@test setindex!(x, (), 1, 1) == x1
433+
@test_throws BoundsError x[17]
434+
@test_throws BoundsError x[4,2]
435+
@test_throws BoundsError x[1,2,3]
436+
@test_throws BoundsError x[18] = ()
437+
@test_throws MethodError x[1,3] = missing
438+
@test x == fill((), (3, 5))
439+
x = reinterpret(reshape, SomeSingleton, t)
440+
@test_throws BoundsError x[19]
441+
@test_throws BoundsError x[2,6] = SomeSingleton(0xa)
442+
@test x[2,3] === SomeSingleton(:x)
443+
x2 = fill(SomeSingleton(0.7), 3, 5)
444+
@test x == x2
445+
@test setindex!(x, SomeSingleton(:), 3, 5) == x2
446+
@test_throws MethodError x[2,4] = nothing
447+
end

0 commit comments

Comments
 (0)