Skip to content

Commit 253f6b3

Browse files
authored
Support unsafe_wrap on arrays of symbols (#2753)
1 parent 3d44e32 commit 253f6b3

File tree

2 files changed

+30
-5
lines changed

2 files changed

+30
-5
lines changed

src/array.jl

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ function explain_eltype(@nospecialize(T), depth=0; maxdepth=10)
2424
msg *= explain_eltype(U, depth+1)
2525
end
2626
end
27-
elseif Base.ismutabletype(T)
27+
elseif Base.ismutabletype(T) && Base.datatype_fieldcount(T) != 0
2828
msg = " "^depth * "$T is a mutable type\n"
2929
elseif hasfieldcount(T)
3030
msg = " "^depth * "$T is a struct that's not allocated inline\n"
@@ -47,9 +47,22 @@ end
4747
# these are stored with a selector at the end (handled by Julia).
4848
# 3. bitstype unions (`Union{Int, Float32}`, etc)
4949
# these are stored contiguously and require a selector array (handled by us)
50-
@inline function check_eltype(name, T)
51-
eltype_is_invalid = !Base.allocatedinline(T) || (hasfieldcount(T) && any(!Base.allocatedinline, fieldtypes(T)))
52-
if eltype_is_invalid
50+
# As well as "mutable singleton" types like `Symbol` that use pointer-identity
51+
52+
function valid_type(@nospecialize(T))
53+
if Base.allocatedinline(T)
54+
if hasfieldcount(T)
55+
return all(valid_type, fieldtypes(T))
56+
end
57+
return true
58+
elseif Base.ismutabletype(T)
59+
return Base.datatype_fieldcount(T) == 0
60+
end
61+
return false
62+
end
63+
64+
@inline function check_eltype(name, T)
65+
if !valid_type(T)
5366
explanation = explain_eltype(T)
5467
error("""
5568
$name only supports element types that are allocated inline.
@@ -234,7 +247,7 @@ end
234247
function Base.unsafe_wrap(::Type{CuArray{T,N,M}},
235248
ptr::CuPtr{T}, dims::NTuple{N,Int};
236249
own::Bool=false, ctx::CuContext=context()) where {T,N,M}
237-
isbitstype(T) || throw(ArgumentError("Can only unsafe_wrap a pointer to a bits type"))
250+
check_eltype("unsafe_wrap(CuArray, ...)", T)
238251
sz = prod(dims) * aligned_sizeof(T)
239252

240253
# create a memory object

test/base/array.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,18 @@ end
173173
cpu_arr = unsafe_wrap(Array, cpu_ptr, 1)
174174
@test cpu_arr == [42]
175175
end
176+
177+
# symbols and tuples thereof
178+
let a = CuArray([:a])
179+
b = unsafe_wrap(CuArray, pointer(a), 1)
180+
@test typeof(b) <: CuArray{Symbol,1}
181+
@test size(b) == (1,)
182+
end
183+
let a = CuArray([(:a,:b)])
184+
b = unsafe_wrap(CuArray, pointer(a), 1)
185+
@test typeof(b) <: CuArray{Tuple{Symbol,Symbol},1}
186+
@test size(b) == (1,)
187+
end
176188
end
177189

178190
@testset "adapt" begin

0 commit comments

Comments
 (0)