Skip to content

Commit 84cfe04

Browse files
authored
add wrap function which is the safe counterpart to unsafe_wrap. (#52049)
1 parent abeb68f commit 84cfe04

File tree

4 files changed

+54
-0
lines changed

4 files changed

+54
-0
lines changed

NEWS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ New library functions
6262
* `copyuntil(out, io, delim)` and `copyline(out, io)` copy data into an `out::IO` stream ([#48273]).
6363
* `eachrsplit(string, pattern)` iterates split substrings right to left.
6464
* `Sys.username()` can be used to return the current user's username ([#51897]).
65+
* `wrap(Array, m::Union{MemoryRef{T}, Memory{T}}, dims)` which is the safe counterpart to `unsafe_wrap` ([#52049]).
6566

6667
New library features
6768
--------------------

base/array.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3039,3 +3039,36 @@ intersect(r::AbstractRange, v::AbstractVector) = intersect(v, r)
30393039
_getindex(v, i)
30403040
end
30413041
end
3042+
3043+
"""
3044+
wrap(Array, m::Union{Memory{T}, MemoryRef{T}}, dims)
3045+
3046+
Create an array of size `dims` using `m` as the underlying memory. This can be thought of as a safe version
3047+
of [`unsafe_wrap`](@ref) utilizing `Memory` or `MemoryRef` instead of raw pointers.
3048+
"""
3049+
function wrap end
3050+
3051+
@eval @propagate_inbounds function wrap(::Type{Array}, ref::MemoryRef{T}, dims::NTuple{N, Integer}) where {T, N}
3052+
mem = ref.mem
3053+
mem_len = length(mem) + 1 - memoryrefoffset(ref)
3054+
len = Core.checked_dims(dims...)
3055+
@boundscheck mem_len >= len || invalid_wrap_err(mem_len, dims, len)
3056+
if N != 1 && !(ref === GenericMemoryRef(mem) && len === mem_len)
3057+
mem = ccall(:jl_genericmemory_slice, Memory{T}, (Any, Ptr{Cvoid}, Int), mem, ref.ptr_or_offset, len)
3058+
ref = MemoryRef(mem)
3059+
end
3060+
$(Expr(:new, :(Array{T, N}), :ref, :dims))
3061+
end
3062+
3063+
@noinline invalid_wrap_err(len, dims, proddims) = throw(DimensionMismatch(
3064+
"Attempted to wrap a MemoryRef of length $len with an Array of size dims=$dims, which is invalid because prod(dims) = $proddims > $len, so that the array would have more elements than the underlying memory can store."))
3065+
3066+
function wrap(::Type{Array}, m::Memory{T}, dims::NTuple{N, Integer}) where {T, N}
3067+
wrap(Array, MemoryRef(m), dims)
3068+
end
3069+
function wrap(::Type{Array}, m::MemoryRef{T}, l::Integer) where {T}
3070+
wrap(Array, m, (l,))
3071+
end
3072+
function wrap(::Type{Array}, m::Memory{T}, l::Integer) where {T}
3073+
wrap(Array, MemoryRef(m), (l,))
3074+
end

base/exports.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -458,6 +458,7 @@ export
458458
vcat,
459459
vec,
460460
view,
461+
wrap,
461462
zeros,
462463

463464
# search, find, match and related functions

test/arrayops.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3170,3 +3170,22 @@ end
31703170
@test c + zero(c) == c
31713171
end
31723172
end
3173+
3174+
@testset "Wrapping Memory into Arrays" begin
3175+
mem = Memory{Int}(undef, 10) .= 1
3176+
memref = MemoryRef(mem)
3177+
@test_throws DimensionMismatch wrap(Array, mem, (10, 10))
3178+
@test wrap(Array, mem, (5,)) == ones(Int, 5)
3179+
@test wrap(Array, mem, 2) == ones(Int, 2)
3180+
@test wrap(Array, memref, 10) == ones(Int, 10)
3181+
@test wrap(Array, memref, (2,2,2)) == ones(Int,2,2,2)
3182+
@test wrap(Array, mem, (5, 2)) == ones(Int, 5, 2)
3183+
3184+
memref2 = MemoryRef(mem, 3)
3185+
@test wrap(Array, memref2, (5,)) == ones(Int, 5)
3186+
@test wrap(Array, memref2, 2) == ones(Int, 2)
3187+
@test wrap(Array, memref2, (2,2,2)) == ones(Int,2,2,2)
3188+
@test wrap(Array, memref2, (3, 2)) == ones(Int, 3, 2)
3189+
@test_throws DimensionMismatch wrap(Array, memref2, 9)
3190+
@test_throws DimensionMismatch wrap(Array, memref2, 10)
3191+
end

0 commit comments

Comments
 (0)