Skip to content

Commit dd0c14b

Browse files
authored
Fix getindex and setindex! on 0-dimensional reinterpretarray (#43819)
1 parent 75a1d0f commit dd0c14b

File tree

2 files changed

+53
-26
lines changed

2 files changed

+53
-26
lines changed

base/reinterpretarray.jl

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -328,8 +328,15 @@ axes(a::NonReshapedReinterpretArray{T,0}) where {T} = ()
328328
elsize(::Type{<:ReinterpretArray{T}}) where {T} = sizeof(T)
329329
unsafe_convert(::Type{Ptr{T}}, a::ReinterpretArray{T,N,S} where N) where {T,S} = Ptr{T}(unsafe_convert(Ptr{S},a.parent))
330330

331-
@inline @propagate_inbounds getindex(a::NonReshapedReinterpretArray{T,0}) where {T} = reinterpret(T, a.parent[])
332-
@inline @propagate_inbounds getindex(a::ReinterpretArray) = a[1]
331+
@inline @propagate_inbounds function getindex(a::NonReshapedReinterpretArray{T,0,S}) where {T,S}
332+
if isprimitivetype(T) && isprimitivetype(S)
333+
reinterpret(T, a.parent[])
334+
else
335+
a[firstindex(a)]
336+
end
337+
end
338+
339+
@inline @propagate_inbounds getindex(a::ReinterpretArray) = a[firstindex(a)]
333340

334341
@inline @propagate_inbounds function getindex(a::ReinterpretArray{T,N,S}, inds::Vararg{Int, N}) where {T,N,S}
335342
check_readable(a)
@@ -462,8 +469,15 @@ end
462469
return t[][i1]
463470
end
464471

465-
@inline @propagate_inbounds setindex!(a::NonReshapedReinterpretArray{T,0,S} where T, v) where {S} = (a.parent[] = reinterpret(S, v))
466-
@inline @propagate_inbounds setindex!(a::ReinterpretArray, v) = (a[1] = v)
472+
@inline @propagate_inbounds function setindex!(a::NonReshapedReinterpretArray{T,0,S}, v) where {T,S}
473+
if isprimitivetype(S) && isprimitivetype(T)
474+
a.parent[] = reinterpret(S, v)
475+
return a
476+
end
477+
setindex!(a, v, firstindex(a))
478+
end
479+
480+
@inline @propagate_inbounds setindex!(a::ReinterpretArray, v) = setindex!(a, v, firstindex(a))
467481

468482
@inline @propagate_inbounds function setindex!(a::ReinterpretArray{T,N,S}, v, inds::Vararg{Int, N}) where {T,N,S}
469483
check_writable(a)

test/reinterpretarray.jl

Lines changed: 35 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -67,29 +67,29 @@ for (_A, Ar, _B) in ((A, Ars, B), (As, Arss, Bs))
6767
reinterpret(NTuple{3, Int64}, Bc)[2] = (4,5,6)
6868
@test Bc == Complex{Int64}[5+6im, 7+4im, 5+6im]
6969
B2 = reinterpret(NTuple{3, Int64}, Bc)
70-
@test setindex!(B2, (1,2,3), 1) == B2
70+
@test setindex!(B2, (1,2,3), 1) === B2
7171
@test Bc == Complex{Int64}[1+2im, 3+4im, 5+6im]
7272
Bc = copy(_B)
7373
Brrs = reinterpret(reshape, Int64, Bc)
74-
@test setindex!(Brrs, -5, 2, 3) == Brrs
74+
@test setindex!(Brrs, -5, 2, 3) === Brrs
7575
@test Bc == Complex{Int64}[5+6im, 7+8im, 9-5im]
7676
Brrs[last(eachindex(Brrs))] = 22
7777
@test Bc == Complex{Int64}[5+6im, 7+8im, 9+22im]
7878

7979
A1 = reinterpret(Float64, _A)
8080
A2 = reinterpret(ComplexF64, _A)
81-
@test setindex!(A1, 1.0, 1) == A1
81+
@test setindex!(A1, 1.0, 1) === A1
8282
@test real(A2[1]) == 1.0
8383
A1 = reinterpret(reshape, Float64, _A)
84-
A1[1] = 2.5
84+
@test setindex!(A1, 2.5, 1) === A1
8585
@test reinterpret(Float64, _A[1]) == 2.5
8686
A1rs = reinterpret(Float64, Ar)
8787
A2rs = reinterpret(ComplexF64, Ar)
88-
A1rs[1, 1] = 1.0
88+
@test setindex!(A1rs, 1.0, 1, 1) === A1rs
8989
@test real(A2rs[1]) == 1.0
9090
A1rs = reinterpret(reshape, Float64, Ar)
9191
A2rs = reinterpret(reshape, ComplexF64, Ar)
92-
@test setindex!(A1rs, 2.5, 1, 1) == A1rs
92+
@test setindex!(A1rs, 2.5, 1, 1) === A1rs
9393
@test real(A2rs[1]) == 2.5
9494
end
9595
end
@@ -107,14 +107,14 @@ A3r[CartesianIndex(1,2)] = 300+400im
107107
@test A3[2,1,2] == 400
108108

109109
# same-size reinterpret where one of the types is non-primitive
110-
let a = NTuple{4,UInt8}[(0x01,0x02,0x03,0x04)]
111-
@test reinterpret(Float32, a)[1] == reinterpret(Float32, 0x04030201)
112-
reinterpret(Float32, a)[1] = 2.0
110+
let a = NTuple{4,UInt8}[(0x01,0x02,0x03,0x04)], ra = reinterpret(Float32, a)
111+
@test ra[1] == reinterpret(Float32, 0x04030201)
112+
@test setindex!(ra, 2.0) === ra
113113
@test reinterpret(Float32, a)[1] == 2.0
114114
end
115-
let a = NTuple{4,UInt8}[(0x01,0x02,0x03,0x04)]
116-
@test reinterpret(reshape, Float32, a)[1] == reinterpret(Float32, 0x04030201)
117-
reinterpret(reshape, Float32, a)[1] = 2.0
115+
let a = NTuple{4,UInt8}[(0x01,0x02,0x03,0x04)], ra = reinterpret(reshape, Float32, a)
116+
@test ra[1] == reinterpret(Float32, 0x04030201)
117+
@test setindex!(ra, 2.0) === ra
118118
@test reinterpret(reshape, Float32, a)[1] == 2.0
119119
end
120120

@@ -198,7 +198,7 @@ let a = fill(1.0, 5, 3)
198198
@test_throws BoundsError r[badinds...] = -2
199199
end
200200
for goodinds in (1, 15, (1,1), (5,3))
201-
r[goodinds...] = -2
201+
@test setindex!(r, -2, goodinds...) === r
202202
@test r[goodinds...] == -2
203203
end
204204
r = reinterpret(Int32, a)
@@ -211,7 +211,7 @@ let a = fill(1.0, 5, 3)
211211
@test_throws BoundsError r[badinds...] = -3
212212
end
213213
for goodinds in (1, 30, (1,1), (10,3))
214-
r[goodinds...] = -3
214+
@test setindex!(r, -3, goodinds...) === r
215215
@test r[goodinds...] == -3
216216
end
217217
r = reinterpret(Int64, view(a, 1:2:5, :))
@@ -224,7 +224,7 @@ let a = fill(1.0, 5, 3)
224224
@test_throws BoundsError r[badinds...] = -4
225225
end
226226
for goodinds in (1, 9, (1,1), (3,3))
227-
r[goodinds...] = -4
227+
@test setindex!(r, -4, goodinds...) === r
228228
@test r[goodinds...] == -4
229229
end
230230
r = reinterpret(Int32, view(a, 1:2:5, :))
@@ -237,7 +237,7 @@ let a = fill(1.0, 5, 3)
237237
@test_throws BoundsError r[badinds...] = -5
238238
end
239239
for goodinds in (1, 18, (1,1), (6,3))
240-
r[goodinds...] = -5
240+
@test setindex!(r, -5, goodinds...) === r
241241
@test r[goodinds...] == -5
242242
end
243243

@@ -318,14 +318,25 @@ end
318318

319319
# Test 0-dimensional Arrays
320320
A = zeros(UInt32)
321-
B = reinterpret(Int32,A)
322-
Brs = reinterpret(reshape,Int32,A)
323-
@test size(B) == size(Brs) == ()
324-
@test axes(B) == axes(Brs) == ()
325-
B[] = Int32(5)
321+
B = reinterpret(Int32, A)
322+
Brs = reinterpret(reshape,Int32, A)
323+
C = reinterpret(Tuple{UInt32}, A) # non-primitive type
324+
Crs = reinterpret(reshape, Tuple{UInt32}, A) # non-primitive type
325+
@test size(B) == size(Brs) == size(C) == size(Crs) == ()
326+
@test axes(B) == axes(Brs) == axes(C) == axes(Crs) == ()
327+
@test setindex!(B, Int32(5)) === B
326328
@test B[] === Int32(5)
327329
@test Brs[] === Int32(5)
330+
@test C[] === (UInt32(5),)
331+
@test Crs[] === (UInt32(5),)
328332
@test A[] === UInt32(5)
333+
@test setindex!(Brs, Int32(12)) === Brs
334+
@test A[] === UInt32(12)
335+
@test setindex!(C, (UInt32(7),)) === C
336+
@test A[] === UInt32(7)
337+
@test setindex!(Crs, (UInt32(3),)) === Crs
338+
@test A[] === UInt32(3)
339+
329340

330341
a = [(1.0,2.0)]
331342
af = @inferred(reinterpret(reshape, Float64, a))
@@ -413,13 +424,15 @@ end
413424
z = reinterpret(Tuple{}, fill(missing, ()))
414425
@test z == fill((), ())
415426
@test z == reinterpret(reshape, Tuple{}, fill(nothing, ()))
427+
@test z[] == ()
428+
@test setindex!(z, ()) === z
416429
@test_throws BoundsError z[2]
417430
@test_throws BoundsError z[3] = ()
418431
@test_throws ArgumentError reinterpret(UInt8, fill(nothing, ()))
419432
@test_throws ArgumentError reinterpret(Missing, fill(1f0, ()))
420433
@test_throws ArgumentError reinterpret(reshape, Float64, fill(nothing, ()))
421434
@test_throws ArgumentError reinterpret(reshape, Nothing, fill(17, ()))
422-
435+
@test_throws MethodError z[] = nothing
423436

424437
@test @inferred(ndims(reinterpret(reshape, SomeSingleton, t))) == 2
425438
@test @inferred(axes(reinterpret(reshape, Tuple{}, t))) == (Base.OneTo(3),Base.OneTo(5))

0 commit comments

Comments
 (0)