@@ -978,3 +978,144 @@ end
978
978
@test Core. sizeof (arrayOfUInt48) == 24
979
979
end
980
980
end
981
+
982
+ struct Strider{T,N} <: AbstractArray{T,N}
983
+ data:: Vector{T}
984
+ offset:: Int
985
+ strides:: NTuple{N,Int}
986
+ size:: NTuple{N,Int}
987
+ end
988
+ function Strider {T} (strides:: NTuple{N} , size:: NTuple{N} ) where {T,N}
989
+ offset = 1 - sum (strides .* (strides .< 0 ) .* (size .- 1 ))
990
+ data = Array {T} (undef, sum (abs .(strides) .* (size .- 1 )) + 1 )
991
+ return Strider {T, N, Vector{T}} (data, offset, strides, size)
992
+ end
993
+ function Strider (vec:: AbstractArray{T} , strides:: NTuple{N} , size:: NTuple{N} ) where {T,N}
994
+ offset = 1 - sum (strides .* (strides .< 0 ) .* (size .- 1 ))
995
+ @assert length (vec) >= sum (abs .(strides) .* (size .- 1 )) + 1
996
+ return Strider {T, N} (vec, offset, strides, size)
997
+ end
998
+ Base. size (S:: Strider ) = S. size
999
+ function Base. getindex (S:: Strider{<:Any,N} , I:: Vararg{Int,N} ) where {N}
1000
+ return S. data[sum (S. strides .* (I .- 1 )) + S. offset]
1001
+ end
1002
+ Base. strides (S:: Strider ) = S. strides
1003
+ Base. elsize (:: Type{<:Strider{T}} ) where {T} = Base. elsize (Vector{T})
1004
+ Base. unsafe_convert (:: Type{Ptr{T}} , S:: Strider{T} ) where {T} = pointer (S. data, S. offset)
1005
+
1006
+ @testset " Simple 3d strided views and permutes" for sz in ((5 , 3 , 2 ), (7 , 11 , 13 ))
1007
+ A = collect (reshape (1 : prod (sz), sz))
1008
+ S = Strider (vec (A), strides (A), sz)
1009
+ @test pointer (A) == pointer (S)
1010
+ for i in 1 : prod (sz)
1011
+ @test pointer (A, i) == pointer (S, i)
1012
+ @test A[i] == S[i]
1013
+ end
1014
+ for idxs in ((1 : sz[1 ], 1 : sz[2 ], 1 : sz[3 ]),
1015
+ (1 : sz[1 ], 2 : 2 : sz[2 ], sz[3 ]: - 1 : 1 ),
1016
+ (2 : 2 : sz[1 ]- 1 , sz[2 ]: - 1 : 1 , sz[3 ]: - 2 : 2 ),
1017
+ (sz[1 ]: - 1 : 1 , sz[2 ]: - 1 : 1 , sz[3 ]: - 1 : 1 ),
1018
+ (sz[1 ]- 1 : - 3 : 1 , sz[2 ]: - 2 : 3 , 1 : sz[3 ]),)
1019
+ Ai = A[idxs... ]
1020
+ Av = view (A, idxs... )
1021
+ Sv = view (S, idxs... )
1022
+ Ss = Strider {Int, 3} (vec (A), sum ((first .(idxs).- 1 ). * strides (A))+ 1 , strides (Av), length .(idxs))
1023
+ @test pointer (Av) == pointer (Sv) == pointer (Ss)
1024
+ for i in 1 : length (Av)
1025
+ @test pointer (Av, i) == pointer (Sv, i) == pointer (Ss, i)
1026
+ @test Ai[i] == Av[i] == Sv[i] == Ss[i]
1027
+ end
1028
+ for perm in ((3 , 2 , 1 ), (2 , 1 , 3 ), (3 , 1 , 2 ))
1029
+ P = permutedims (A, perm)
1030
+ Ap = Base. PermutedDimsArray (A, perm)
1031
+ Sp = Base. PermutedDimsArray (S, perm)
1032
+ Ps = Strider {Int, 3} (vec (A), 1 , strides (A)[collect (perm)], sz[collect (perm)])
1033
+ @test pointer (Ap) == pointer (Sp) == pointer (Ps)
1034
+ for i in 1 : length (Ap)
1035
+ # This is intentionally disabled due to ambiguity
1036
+ @test_broken pointer (Ap, i) == pointer (Sp, i) == pointer (Ps, i)
1037
+ @test P[i] == Ap[i] == Sp[i] == Ps[i]
1038
+ end
1039
+ Pv = view (P, idxs[collect (perm)]. .. )
1040
+ Pi = P[idxs[collect (perm)]. .. ]
1041
+ Apv = view (Ap, idxs[collect (perm)]. .. )
1042
+ Spv = view (Sp, idxs[collect (perm)]. .. )
1043
+ Pvs = Strider {Int, 3} (vec (A), sum ((first .(idxs).- 1 ). * strides (A))+ 1 , strides (Apv), size (Apv))
1044
+ @test pointer (Apv) == pointer (Spv) == pointer (Pvs)
1045
+ for i in 1 : length (Apv)
1046
+ @test pointer (Apv, i) == pointer (Spv, i) == pointer (Pvs, i)
1047
+ @test Pi[i] == Pv[i] == Apv[i] == Spv[i] == Pvs[i]
1048
+ end
1049
+ Vp = permutedims (Av, perm)
1050
+ Ip = permutedims (Ai, perm)
1051
+ Avp = Base. PermutedDimsArray (Av, perm)
1052
+ Svp = Base. PermutedDimsArray (Sv, perm)
1053
+ @test pointer (Avp) == pointer (Svp)
1054
+ for i in 1 : length (Avp)
1055
+ # This is intentionally disabled due to ambiguity
1056
+ @test_broken pointer (Avp, i) == pointer (Svp, i)
1057
+ @test Ip[i] == Vp[i] == Avp[i] == Svp[i]
1058
+ end
1059
+ end
1060
+ end
1061
+ end
1062
+
1063
+ @testset " simple 2d strided views, permutes, transposes" for sz in ((5 , 3 ), (7 , 11 ))
1064
+ A = collect (reshape (1 : prod (sz), sz))
1065
+ S = Strider (vec (A), strides (A), sz)
1066
+ @test pointer (A) == pointer (S)
1067
+ for i in 1 : prod (sz)
1068
+ @test pointer (A, i) == pointer (S, i)
1069
+ @test A[i] == S[i]
1070
+ end
1071
+ for idxs in ((1 : sz[1 ], 1 : sz[2 ]),
1072
+ (1 : sz[1 ], 2 : 2 : sz[2 ]),
1073
+ (2 : 2 : sz[1 ]- 1 , sz[2 ]: - 1 : 1 ),
1074
+ (sz[1 ]: - 1 : 1 , sz[2 ]: - 1 : 1 ),
1075
+ (sz[1 ]- 1 : - 3 : 1 , sz[2 ]: - 2 : 3 ),)
1076
+ Av = view (A, idxs... )
1077
+ Sv = view (S, idxs... )
1078
+ Ss = Strider {Int, 2} (vec (A), sum ((first .(idxs).- 1 ). * strides (A))+ 1 , strides (Av), length .(idxs))
1079
+ @test pointer (Av) == pointer (Sv) == pointer (Ss)
1080
+ for i in 1 : length (Av)
1081
+ @test pointer (Av, i) == pointer (Sv, i) == pointer (Ss, i)
1082
+ @test Av[i] == Sv[i] == Ss[i]
1083
+ end
1084
+ perm = (2 , 1 )
1085
+ P = permutedims (A, perm)
1086
+ Ap = Base. PermutedDimsArray (A, perm)
1087
+ At = transpose (A)
1088
+ Aa = adjoint (A)
1089
+ Sp = Base. PermutedDimsArray (S, perm)
1090
+ Ps = Strider {Int, 2} (vec (A), 1 , strides (A)[collect (perm)], sz[collect (perm)])
1091
+ @test pointer (Ap) == pointer (Sp) == pointer (Ps) == pointer (At) == pointer (Aa)
1092
+ for i in 1 : length (Ap)
1093
+ # This is intentionally disabled due to ambiguity
1094
+ @test_broken pointer (Ap, i) == pointer (Sp, i) == pointer (Ps, i) == pointer (At, i) == pointer (Aa, i)
1095
+ @test pointer (Ps, i) == pointer (At, i) == pointer (Aa, i)
1096
+ @test P[i] == Ap[i] == Sp[i] == Ps[i] == At[i] == Aa[i]
1097
+ end
1098
+ Pv = view (P, idxs[collect (perm)]. .. )
1099
+ Apv = view (Ap, idxs[collect (perm)]. .. )
1100
+ Atv = view (At, idxs[collect (perm)]. .. )
1101
+ Ata = view (Aa, idxs[collect (perm)]. .. )
1102
+ Spv = view (Sp, idxs[collect (perm)]. .. )
1103
+ Pvs = Strider {Int, 2} (vec (A), sum ((first .(idxs).- 1 ). * strides (A))+ 1 , strides (Apv), size (Apv))
1104
+ @test pointer (Apv) == pointer (Spv) == pointer (Pvs) == pointer (Atv) == pointer (Ata)
1105
+ for i in 1 : length (Apv)
1106
+ @test pointer (Apv, i) == pointer (Spv, i) == pointer (Pvs, i) == pointer (Atv, i) == pointer (Ata, i)
1107
+ @test Pv[i] == Apv[i] == Spv[i] == Pvs[i] == Atv[i] == Ata[i]
1108
+ end
1109
+ Vp = permutedims (Av, perm)
1110
+ Avp = Base. PermutedDimsArray (Av, perm)
1111
+ Avt = transpose (Av)
1112
+ Ava = adjoint (Av)
1113
+ Svp = Base. PermutedDimsArray (Sv, perm)
1114
+ @test pointer (Avp) == pointer (Svp) == pointer (Avt) == pointer (Ava)
1115
+ for i in 1 : length (Avp)
1116
+ # This is intentionally disabled due to ambiguity
1117
+ @test_broken pointer (Avp, i) == pointer (Svp, i) == pointer (Avt, i) == pointer (Ava, i)
1118
+ @test Vp[i] == Avp[i] == Svp[i] == Avt[i] == Ava[i]
1119
+ end
1120
+ end
1121
+ end
0 commit comments