Skip to content

Commit 59b8dde

Browse files
authored
Fix pointer to no longer assume contiguity (#36405)
* Fix pointer to no longer assume contiguity
1 parent 29e1454 commit 59b8dde

File tree

5 files changed

+163
-22
lines changed

5 files changed

+163
-22
lines changed

base/abstractarray.jl

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1007,7 +1007,14 @@ end
10071007
pointer(x::AbstractArray{T}) where {T} = unsafe_convert(Ptr{T}, x)
10081008
function pointer(x::AbstractArray{T}, i::Integer) where T
10091009
@_inline_meta
1010-
unsafe_convert(Ptr{T}, x) + (i - first(LinearIndices(x)))*elsize(x)
1010+
unsafe_convert(Ptr{T}, x) + _memory_offset(x, i)
1011+
end
1012+
1013+
# The distance from pointer(x) to the element at x[I...] in bytes
1014+
_memory_offset(x::DenseArray, I...) = (_to_linear_index(x, I...) - first(LinearIndices(x)))*elsize(x)
1015+
function _memory_offset(x::AbstractArray, I...)
1016+
J = _to_subscript_indices(x, I...)
1017+
return sum(map((i, s, o)->s*(i-o), J, strides(x), Tuple(first(CartesianIndices(x)))))*elsize(x)
10111018
end
10121019

10131020
## Approach:
@@ -1078,10 +1085,10 @@ function _getindex(::IndexLinear, A::AbstractArray, I::Vararg{Int,M}) where M
10781085
@inbounds r = getindex(A, _to_linear_index(A, I...))
10791086
r
10801087
end
1081-
_to_linear_index(A::AbstractArray, i::Int) = i
1082-
_to_linear_index(A::AbstractVector, i::Int, I::Int...) = i
1088+
_to_linear_index(A::AbstractArray, i::Integer) = i
1089+
_to_linear_index(A::AbstractVector, i::Integer, I::Integer...) = i
10831090
_to_linear_index(A::AbstractArray) = 1
1084-
_to_linear_index(A::AbstractArray, I::Int...) = (@_inline_meta; _sub2ind(A, I...))
1091+
_to_linear_index(A::AbstractArray, I::Integer...) = (@_inline_meta; _sub2ind(A, I...))
10851092

10861093
## IndexCartesian Scalar indexing: Canonical method is full dimensionality of Ints
10871094
function _getindex(::IndexCartesian, A::AbstractArray, I::Vararg{Int,M}) where M
@@ -1094,12 +1101,12 @@ function _getindex(::IndexCartesian, A::AbstractArray{T,N}, I::Vararg{Int, N}) w
10941101
@_propagate_inbounds_meta
10951102
getindex(A, I...)
10961103
end
1097-
_to_subscript_indices(A::AbstractArray, i::Int) = (@_inline_meta; _unsafe_ind2sub(A, i))
1104+
_to_subscript_indices(A::AbstractArray, i::Integer) = (@_inline_meta; _unsafe_ind2sub(A, i))
10981105
_to_subscript_indices(A::AbstractArray{T,N}) where {T,N} = (@_inline_meta; fill_to_length((), 1, Val(N)))
10991106
_to_subscript_indices(A::AbstractArray{T,0}) where {T} = ()
1100-
_to_subscript_indices(A::AbstractArray{T,0}, i::Int) where {T} = ()
1101-
_to_subscript_indices(A::AbstractArray{T,0}, I::Int...) where {T} = ()
1102-
function _to_subscript_indices(A::AbstractArray{T,N}, I::Int...) where {T,N}
1107+
_to_subscript_indices(A::AbstractArray{T,0}, i::Integer) where {T} = ()
1108+
_to_subscript_indices(A::AbstractArray{T,0}, I::Integer...) where {T} = ()
1109+
function _to_subscript_indices(A::AbstractArray{T,N}, I::Integer...) where {T,N}
11031110
@_inline_meta
11041111
J, Jrem = IteratorsMD.split(I, Val(N))
11051112
_to_subscript_indices(A, J, Jrem)

base/permuteddimsarray.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ function Base.strides(A::PermutedDimsArray{T,N,perm}) where {T,N,perm}
6464
s = strides(parent(A))
6565
ntuple(d->s[perm[d]], Val(N))
6666
end
67+
Base.elsize(::Type{<:PermutedDimsArray{<:Any, <:Any, <:Any, <:Any, P}}) where {P} = Base.elsize(P)
6768

6869
@inline function Base.getindex(A::PermutedDimsArray{T,N,perm,iperm}, I::Vararg{Int,N}) where {T,N,perm,iperm}
6970
@boundscheck checkbounds(A, I...)

base/subarray.jl

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -398,23 +398,12 @@ find_extended_inds(::ScalarIndex, I...) = (@_inline_meta; find_extended_inds(I..
398398
find_extended_inds(i1, I...) = (@_inline_meta; (i1, find_extended_inds(I...)...))
399399
find_extended_inds() = ()
400400

401-
unsafe_convert(::Type{Ptr{T}}, V::SubArray{T,N,P,<:Tuple{Vararg{RangeIndex}}}) where {T,N,P} =
402-
unsafe_convert(Ptr{T}, V.parent) + (first_index(V)-1)*sizeof(T)
401+
function unsafe_convert(::Type{Ptr{T}}, V::SubArray{T,N,P,<:Tuple{Vararg{RangeIndex}}}) where {T,N,P}
402+
return unsafe_convert(Ptr{T}, V.parent) + _memory_offset(V.parent, map(first, V.indices)...)
403+
end
403404

404405
pointer(V::FastSubArray, i::Int) = pointer(V.parent, V.offset1 + V.stride1*i)
405406
pointer(V::FastContiguousSubArray, i::Int) = pointer(V.parent, V.offset1 + i)
406-
pointer(V::SubArray, i::Int) = _pointer(V, i)
407-
_pointer(V::SubArray{<:Any,1}, i::Int) = pointer(V, (i,))
408-
_pointer(V::SubArray, i::Int) = pointer(V, Base._ind2sub(axes(V), i))
409-
410-
function pointer(V::SubArray{T,N,<:Array,<:Tuple{Vararg{RangeIndex}}}, is::Tuple{Vararg{Int}}) where {T,N}
411-
index = first_index(V)
412-
strds = strides(V)
413-
for d = 1:length(is)
414-
index += (is[d]-1)*strds[d]
415-
end
416-
return pointer(V.parent, index)
417-
end
418407

419408
# indices are taken from the range/vector
420409
# Since bounds-checking is performance-critical and uses

stdlib/LinearAlgebra/src/adjtrans.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,9 @@ Base.strides(A::Transpose{<:Any, <:StridedMatrix}) = reverse(strides(A.parent))
208208
Base.unsafe_convert(::Type{Ptr{T}}, A::Adjoint{<:Real, <:StridedVecOrMat}) where {T} = Base.unsafe_convert(Ptr{T}, A.parent)
209209
Base.unsafe_convert(::Type{Ptr{T}}, A::Transpose{<:Any, <:StridedVecOrMat}) where {T} = Base.unsafe_convert(Ptr{T}, A.parent)
210210

211+
Base.elsize(::Type{<:Adjoint{<:Real, P}}) where {P<:StridedVecOrMat} = Base.elsize(P)
212+
Base.elsize(::Type{<:Transpose{<:Any, P}}) where {P<:StridedVecOrMat} = Base.elsize(P)
213+
211214
# for vectors, the semantics of the wrapped and unwrapped types differ
212215
# so attempt to maintain both the parent and wrapper type insofar as possible
213216
similar(A::AdjOrTransAbsVec) = wrapperop(A)(similar(A.parent))

test/abstractarray.jl

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -978,3 +978,144 @@ end
978978
@test Core.sizeof(arrayOfUInt48) == 24
979979
end
980980
end
981+
982+
struct Strider{T,N} <: AbstractArray{T,N}
983+
data::Vector{T}
984+
offset::Int
985+
strides::NTuple{N,Int}
986+
size::NTuple{N,Int}
987+
end
988+
function Strider{T}(strides::NTuple{N}, size::NTuple{N}) where {T,N}
989+
offset = 1-sum(strides .* (strides .< 0) .* (size .- 1))
990+
data = Array{T}(undef, sum(abs.(strides) .* (size .- 1)) + 1)
991+
return Strider{T, N, Vector{T}}(data, offset, strides, size)
992+
end
993+
function Strider(vec::AbstractArray{T}, strides::NTuple{N}, size::NTuple{N}) where {T,N}
994+
offset = 1-sum(strides .* (strides .< 0) .* (size .- 1))
995+
@assert length(vec) >= sum(abs.(strides) .* (size .- 1)) + 1
996+
return Strider{T, N}(vec, offset, strides, size)
997+
end
998+
Base.size(S::Strider) = S.size
999+
function Base.getindex(S::Strider{<:Any,N}, I::Vararg{Int,N}) where {N}
1000+
return S.data[sum(S.strides .* (I .- 1)) + S.offset]
1001+
end
1002+
Base.strides(S::Strider) = S.strides
1003+
Base.elsize(::Type{<:Strider{T}}) where {T} = Base.elsize(Vector{T})
1004+
Base.unsafe_convert(::Type{Ptr{T}}, S::Strider{T}) where {T} = pointer(S.data, S.offset)
1005+
1006+
@testset "Simple 3d strided views and permutes" for sz in ((5, 3, 2), (7, 11, 13))
1007+
A = collect(reshape(1:prod(sz), sz))
1008+
S = Strider(vec(A), strides(A), sz)
1009+
@test pointer(A) == pointer(S)
1010+
for i in 1:prod(sz)
1011+
@test pointer(A, i) == pointer(S, i)
1012+
@test A[i] == S[i]
1013+
end
1014+
for idxs in ((1:sz[1], 1:sz[2], 1:sz[3]),
1015+
(1:sz[1], 2:2:sz[2], sz[3]:-1:1),
1016+
(2:2:sz[1]-1, sz[2]:-1:1, sz[3]:-2:2),
1017+
(sz[1]:-1:1, sz[2]:-1:1, sz[3]:-1:1),
1018+
(sz[1]-1:-3:1, sz[2]:-2:3, 1:sz[3]),)
1019+
Ai = A[idxs...]
1020+
Av = view(A, idxs...)
1021+
Sv = view(S, idxs...)
1022+
Ss = Strider{Int, 3}(vec(A), sum((first.(idxs).-1).*strides(A))+1, strides(Av), length.(idxs))
1023+
@test pointer(Av) == pointer(Sv) == pointer(Ss)
1024+
for i in 1:length(Av)
1025+
@test pointer(Av, i) == pointer(Sv, i) == pointer(Ss, i)
1026+
@test Ai[i] == Av[i] == Sv[i] == Ss[i]
1027+
end
1028+
for perm in ((3, 2, 1), (2, 1, 3), (3, 1, 2))
1029+
P = permutedims(A, perm)
1030+
Ap = Base.PermutedDimsArray(A, perm)
1031+
Sp = Base.PermutedDimsArray(S, perm)
1032+
Ps = Strider{Int, 3}(vec(A), 1, strides(A)[collect(perm)], sz[collect(perm)])
1033+
@test pointer(Ap) == pointer(Sp) == pointer(Ps)
1034+
for i in 1:length(Ap)
1035+
# This is intentionally disabled due to ambiguity
1036+
@test_broken pointer(Ap, i) == pointer(Sp, i) == pointer(Ps, i)
1037+
@test P[i] == Ap[i] == Sp[i] == Ps[i]
1038+
end
1039+
Pv = view(P, idxs[collect(perm)]...)
1040+
Pi = P[idxs[collect(perm)]...]
1041+
Apv = view(Ap, idxs[collect(perm)]...)
1042+
Spv = view(Sp, idxs[collect(perm)]...)
1043+
Pvs = Strider{Int, 3}(vec(A), sum((first.(idxs).-1).*strides(A))+1, strides(Apv), size(Apv))
1044+
@test pointer(Apv) == pointer(Spv) == pointer(Pvs)
1045+
for i in 1:length(Apv)
1046+
@test pointer(Apv, i) == pointer(Spv, i) == pointer(Pvs, i)
1047+
@test Pi[i] == Pv[i] == Apv[i] == Spv[i] == Pvs[i]
1048+
end
1049+
Vp = permutedims(Av, perm)
1050+
Ip = permutedims(Ai, perm)
1051+
Avp = Base.PermutedDimsArray(Av, perm)
1052+
Svp = Base.PermutedDimsArray(Sv, perm)
1053+
@test pointer(Avp) == pointer(Svp)
1054+
for i in 1:length(Avp)
1055+
# This is intentionally disabled due to ambiguity
1056+
@test_broken pointer(Avp, i) == pointer(Svp, i)
1057+
@test Ip[i] == Vp[i] == Avp[i] == Svp[i]
1058+
end
1059+
end
1060+
end
1061+
end
1062+
1063+
@testset "simple 2d strided views, permutes, transposes" for sz in ((5, 3), (7, 11))
1064+
A = collect(reshape(1:prod(sz), sz))
1065+
S = Strider(vec(A), strides(A), sz)
1066+
@test pointer(A) == pointer(S)
1067+
for i in 1:prod(sz)
1068+
@test pointer(A, i) == pointer(S, i)
1069+
@test A[i] == S[i]
1070+
end
1071+
for idxs in ((1:sz[1], 1:sz[2]),
1072+
(1:sz[1], 2:2:sz[2]),
1073+
(2:2:sz[1]-1, sz[2]:-1:1),
1074+
(sz[1]:-1:1, sz[2]:-1:1),
1075+
(sz[1]-1:-3:1, sz[2]:-2:3),)
1076+
Av = view(A, idxs...)
1077+
Sv = view(S, idxs...)
1078+
Ss = Strider{Int, 2}(vec(A), sum((first.(idxs).-1).*strides(A))+1, strides(Av), length.(idxs))
1079+
@test pointer(Av) == pointer(Sv) == pointer(Ss)
1080+
for i in 1:length(Av)
1081+
@test pointer(Av, i) == pointer(Sv, i) == pointer(Ss, i)
1082+
@test Av[i] == Sv[i] == Ss[i]
1083+
end
1084+
perm = (2, 1)
1085+
P = permutedims(A, perm)
1086+
Ap = Base.PermutedDimsArray(A, perm)
1087+
At = transpose(A)
1088+
Aa = adjoint(A)
1089+
Sp = Base.PermutedDimsArray(S, perm)
1090+
Ps = Strider{Int, 2}(vec(A), 1, strides(A)[collect(perm)], sz[collect(perm)])
1091+
@test pointer(Ap) == pointer(Sp) == pointer(Ps) == pointer(At) == pointer(Aa)
1092+
for i in 1:length(Ap)
1093+
# This is intentionally disabled due to ambiguity
1094+
@test_broken pointer(Ap, i) == pointer(Sp, i) == pointer(Ps, i) == pointer(At, i) == pointer(Aa, i)
1095+
@test pointer(Ps, i) == pointer(At, i) == pointer(Aa, i)
1096+
@test P[i] == Ap[i] == Sp[i] == Ps[i] == At[i] == Aa[i]
1097+
end
1098+
Pv = view(P, idxs[collect(perm)]...)
1099+
Apv = view(Ap, idxs[collect(perm)]...)
1100+
Atv = view(At, idxs[collect(perm)]...)
1101+
Ata = view(Aa, idxs[collect(perm)]...)
1102+
Spv = view(Sp, idxs[collect(perm)]...)
1103+
Pvs = Strider{Int, 2}(vec(A), sum((first.(idxs).-1).*strides(A))+1, strides(Apv), size(Apv))
1104+
@test pointer(Apv) == pointer(Spv) == pointer(Pvs) == pointer(Atv) == pointer(Ata)
1105+
for i in 1:length(Apv)
1106+
@test pointer(Apv, i) == pointer(Spv, i) == pointer(Pvs, i) == pointer(Atv, i) == pointer(Ata, i)
1107+
@test Pv[i] == Apv[i] == Spv[i] == Pvs[i] == Atv[i] == Ata[i]
1108+
end
1109+
Vp = permutedims(Av, perm)
1110+
Avp = Base.PermutedDimsArray(Av, perm)
1111+
Avt = transpose(Av)
1112+
Ava = adjoint(Av)
1113+
Svp = Base.PermutedDimsArray(Sv, perm)
1114+
@test pointer(Avp) == pointer(Svp) == pointer(Avt) == pointer(Ava)
1115+
for i in 1:length(Avp)
1116+
# This is intentionally disabled due to ambiguity
1117+
@test_broken pointer(Avp, i) == pointer(Svp, i) == pointer(Avt, i) == pointer(Ava, i)
1118+
@test Vp[i] == Avp[i] == Svp[i] == Avt[i] == Ava[i]
1119+
end
1120+
end
1121+
end

0 commit comments

Comments
 (0)