Skip to content

Commit f03e138

Browse files
mcabbottMichael Abbott
andauthored
add reinterpret(reshape, T, a) (#722)
Co-authored-by: Michael Abbott <me@escbook>
1 parent e3a642c commit f03e138

File tree

4 files changed

+161
-1
lines changed

4 files changed

+161
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Compat"
22
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
3-
version = "3.17.0"
3+
version = "3.18.0"
44

55
[deps]
66
Base64 = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"

README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,10 @@ changes in `julia`.
5555

5656
## Supported features
5757

58+
* `reinterpret(reshape, T, a::AbstractArray{S})` reinterprets `a` to have eltype `T` while
59+
inserting or consuming the first dimension depending on the ratio of `sizeof(T)` and `sizeof(S)`.
60+
([#37559]). (since Compat 3.18)
61+
5862
* The composition operator `` now returns a `Compat.ComposedFunction` instead of an anonymous function ([#37517]). (since Compat 3.17)
5963

6064
* New function `addenv` for adding environment mappings into a `Cmd` object, returning the new `Cmd` object ([#37244]). (since Compat 3.16)
@@ -203,3 +207,4 @@ Note that you should specify the correct minimum version for `Compat` in the
203207
[#35052]: https://github.com/JuliaLang/julia/pull/35052
204208
[#37244]: https://github.com/JuliaLang/julia/pull/37244
205209
[#37517]: https://github.com/JuliaLang/julia/pull/37517
210+
[#37559]: https://github.com/JuliaLang/julia/pull/37559

src/Compat.jl

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -663,6 +663,82 @@ if VERSION < v"1.6.0-DEV.873" # 18198b1bf85125de6cec266eac404d31ccc2e65c
663663
end
664664
end
665665

666+
667+
# https://github.com/JuliaLang/julia/pull/37559
668+
if VERSION < v"1.6.0-DEV.1083"
669+
"""
670+
reinterpret(reshape, T, A::AbstractArray{S}) -> B
671+
672+
Change the type-interpretation of `A` while consuming or adding a "channel dimension."
673+
674+
If `sizeof(T) = n*sizeof(S)` for `n>1`, `A`'s first dimension must be
675+
of size `n` and `B` lacks `A`'s first dimension. Conversely, if `sizeof(S) = n*sizeof(T)` for `n>1`,
676+
`B` gets a new first dimension of size `n`. The dimensionality is unchanged if `sizeof(T) == sizeof(S)`.
677+
678+
# Examples
679+
680+
```
681+
julia> A = [1 2; 3 4]
682+
2×2 Matrix{$Int}:
683+
1 2
684+
3 4
685+
686+
julia> reinterpret(reshape, Complex{Int}, A) # the result is a vector
687+
2-element reinterpret(reshape, Complex{$Int}, ::Matrix{$Int}):
688+
1 + 3im
689+
2 + 4im
690+
691+
julia> a = [(1,2,3), (4,5,6)]
692+
2-element Vector{Tuple{$Int, $Int, $Int}}:
693+
(1, 2, 3)
694+
(4, 5, 6)
695+
696+
julia> reinterpret(reshape, Int, a) # the result is a matrix
697+
3×2 reinterpret(reshape, $Int, ::Vector{Tuple{$Int, $Int, $Int}}):
698+
1 4
699+
2 5
700+
3 6
701+
```
702+
"""
703+
function Base.reinterpret(::typeof(reshape), ::Type{T}, a::A) where {T,S,A<:AbstractArray{S}}
704+
isbitstype(T) || throwbits(S, T, T)
705+
isbitstype(S) || throwbits(S, T, S)
706+
if sizeof(S) == sizeof(T)
707+
N = ndims(a)
708+
elseif sizeof(S) > sizeof(T)
709+
rem(sizeof(S), sizeof(T)) == 0 || throwintmult(S, T)
710+
N = ndims(a) + 1
711+
else
712+
rem(sizeof(T), sizeof(S)) == 0 || throwintmult(S, T)
713+
N = ndims(a) - 1
714+
N > -1 || throwsize0(S, T, "larger")
715+
axes(a, 1) == Base.OneTo(sizeof(T) ÷ sizeof(S)) || throwsize1(a, T)
716+
end
717+
paxs = axes(a)
718+
new_axes = if sizeof(S) > sizeof(T)
719+
(Base.OneTo(div(sizeof(S), sizeof(T))), paxs...)
720+
elseif sizeof(S) < sizeof(T)
721+
Base.tail(paxs)
722+
else
723+
paxs
724+
end
725+
reshape(reinterpret(T, vec(a)), new_axes)
726+
end
727+
728+
@noinline function throwintmult(S::Type, T::Type)
729+
throw(ArgumentError("`reinterpret(reshape, T, a)` requires that one of `sizeof(T)` (got $(sizeof(T))) and `sizeof(eltype(a))` (got $(sizeof(S))) be an integer multiple of the other"))
730+
end
731+
@noinline function throwsize1(a::AbstractArray, T::Type)
732+
throw(ArgumentError("`reinterpret(reshape, $T, a)` where `eltype(a)` is $(eltype(a)) requires that `axes(a, 1)` (got $(axes(a, 1))) be equal to 1:$(sizeof(T) ÷ sizeof(eltype(a))) (from the ratio of element sizes)"))
733+
end
734+
@noinline function throwbits(S::Type, T::Type, U::Type)
735+
throw(ArgumentError("cannot reinterpret `$(S)` as `$(T)`, type `$(U)` is not a bits type"))
736+
end
737+
@noinline function throwsize0(S::Type, T::Type, msg)
738+
throw(ArgumentError("cannot reinterpret a zero-dimensional `$(S)` array to `$(T)` which is of a $msg size"))
739+
end
740+
end
741+
666742
include("iterators.jl")
667743
include("deprecated.jl")
668744

test/runtests.jl

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -600,6 +600,85 @@ end
600600
@test strip(String(read(cmd))) == "baz bar"
601601
end
602602

603+
# https://github.com/JuliaLang/julia/pull/37559
604+
@testset "reinterpred(reshape, ...)" begin
605+
# simplified from PR
606+
Ar = Int64[1 3; 2 4]
607+
@test @inferred(ndims(reinterpret(reshape, Complex{Int64}, Ar))) == 1
608+
@test @inferred(axes(reinterpret(reshape, Complex{Int64}, Ar))) === (Base.OneTo(2),)
609+
@test @inferred(size(reinterpret(reshape, Complex{Int64}, Ar))) == (2,)
610+
611+
_B = Complex{Int64}[5+6im, 7+8im, 9+10im]
612+
@test @inferred(ndims(reinterpret(reshape, Int64, _B))) == 2
613+
@test @inferred(axes(reinterpret(reshape, Int64, _B))) === (Base.OneTo(2), Base.OneTo(3))
614+
@test @inferred(size(reinterpret(reshape, Int64, _B))) == (2, 3)
615+
@test @inferred(ndims(reinterpret(reshape, Int128, _B))) == 1
616+
@test @inferred(axes(reinterpret(reshape, Int128, _B))) === (Base.OneTo(3),)
617+
@test @inferred(size(reinterpret(reshape, Int128, _B))) == (3,)
618+
619+
A = Int64[1, 2, 3, 4]
620+
Av = [Int32[1,2], Int32[3,4]]
621+
622+
@test_throws ArgumentError reinterpret(Vector{Int64}, A) # ("cannot reinterpret `Int64` as `Vector{Int64}`, type `Vector{Int64}` is not a bits type")
623+
@test_throws ArgumentError reinterpret(Int32, Av) # ("cannot reinterpret `Vector{Int32}` as `Int32`, type `Vector{Int32}` is not a bits type")
624+
@test_throws ArgumentError("cannot reinterpret a zero-dimensional `Int64` array to `Int32` which is of a different size") reinterpret(Int32, reshape([Int64(0)]))
625+
@test_throws ArgumentError("cannot reinterpret a zero-dimensional `Int32` array to `Int64` which is of a different size") reinterpret(Int64, reshape([Int32(0)]))
626+
@test_throws ArgumentError reinterpret(Tuple{Int,Int}, [1,2,3,4,5]) # ("""cannot reinterpret an `$Int` array to `Tuple{$Int, $Int}` whose first dimension has size `5`.
627+
# The resulting array would have non-integral first dimension.
628+
# """)
629+
@test_throws ArgumentError("`reinterpret(reshape, Complex{Int64}, a)` where `eltype(a)` is Int64 requires that `axes(a, 1)` (got Base.OneTo(4)) be equal to 1:2 (from the ratio of element sizes)") reinterpret(reshape, Complex{Int64}, A)
630+
@test_throws ArgumentError("`reinterpret(reshape, T, a)` requires that one of `sizeof(T)` (got 24) and `sizeof(eltype(a))` (got 16) be an integer multiple of the other") reinterpret(reshape, NTuple{3, Int64}, _B)
631+
@test_throws ArgumentError reinterpret(reshape, Vector{Int64}, Ar) # ("cannot reinterpret `Int64` as `Vector{Int64}`, type `Vector{Int64}` is not a bits type")
632+
@test_throws ArgumentError("cannot reinterpret a zero-dimensional `UInt8` array to `UInt16` which is of a larger size") reinterpret(reshape, UInt16, reshape([0x01]))
633+
634+
# getindex
635+
_A = A
636+
@test reinterpret(Complex{Int64}, _A) == [1 + 2im, 3 + 4im]
637+
@test reinterpret(Float64, _A) == reinterpret.(Float64, A)
638+
@test reinterpret(reshape, Float64, _A) == reinterpret.(Float64, A)
639+
640+
Ars = Ar
641+
@test reinterpret(reshape, Complex{Int64}, Ar) == [1 + 2im, 3 + 4im]
642+
@test reinterpret(reshape, Float64, Ar) == reinterpret.(Float64, Ars)
643+
644+
# setindex
645+
A3 = collect(reshape(1:18, 2, 3, 3))
646+
A3r = reinterpret(reshape, Complex{Int}, A3)
647+
@test A3r[4] === A3r[1,2] === A3r[CartesianIndex(1, 2)] === 7+8im
648+
A3r[2,3] = -8-15im
649+
@test A3[1,2,3] == -8
650+
@test A3[2,2,3] == -15
651+
A3r[4] = 100+200im
652+
@test A3[1,1,2] == 100
653+
@test A3[2,1,2] == 200
654+
A3r[CartesianIndex(1,2)] = 300+400im
655+
@test A3[1,1,2] == 300
656+
@test A3[2,1,2] == 400
657+
658+
# Test 0-dimensional Arrays
659+
A = zeros(UInt32)
660+
B = reinterpret(Int32,A)
661+
Brs = reinterpret(reshape,Int32,A)
662+
@test size(B) == size(Brs) == ()
663+
@test axes(B) == axes(Brs) == ()
664+
B[] = Int32(5)
665+
@test B[] === Int32(5)
666+
@test Brs[] === Int32(5)
667+
@test A[] === UInt32(5)
668+
669+
# reductions
670+
a = [(1,2,3), (4,5,6)]
671+
ars = reinterpret(reshape, Int, a)
672+
@test sum(ars) == 21
673+
@test sum(ars; dims=1) == [6 15]
674+
@test sum(ars; dims=2) == reshape([5,7,9], (3, 1))
675+
@test sum(ars; dims=(1,2)) == reshape([21], (1, 1))
676+
# also test large sizes for the pairwise algorithm
677+
a = [(k,k+1,k+2) for k = 1:3:4000]
678+
ars = reinterpret(reshape, Int, a)
679+
@test sum(ars) == 8010003
680+
end
681+
603682
include("iterators.jl")
604683

605684
nothing

0 commit comments

Comments
 (0)