Skip to content

Commit daa1580

Browse files
authored
Add reshape for CuDeviceArray (#1561)
1 parent 85ac4ac commit daa1580

File tree

2 files changed

+42
-0
lines changed

2 files changed

+42
-0
lines changed

src/device/array.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,3 +265,22 @@ function Base.reinterpret(::Type{T}, a::CuDeviceArray{S,N,A}) where {T,S,N,A}
265265
osize = tuple(size1, Base.tail(isize)...)
266266
return CuDeviceArray{T,N,A}(osize, reinterpret(LLVMPtr{T,A}, a.ptr), a.maxsize)
267267
end
268+
269+
270+
## reshape
271+
272+
function Base.reshape(a::CuDeviceArray{T,M}, dims::NTuple{N,Int}) where {T,N,M}
273+
if prod(dims) != length(a)
274+
throw(DimensionMismatch("new dimensions (argument `dims`) must be consistent with array size (`size(a)`)"))
275+
end
276+
if N == M && dims == size(a)
277+
return a
278+
end
279+
_derived_array(T, N, a, dims)
280+
end
281+
282+
# create a derived device array (reinterpreted or reshaped) that's still a CuDeviceArray
283+
@inline function _derived_array(::Type{T}, N::Int, a::CuDeviceArray{T,M,A}, osize::Dims) where {T, M, A}
284+
return CuDeviceArray{T,N,A}(osize, a.ptr, a.maxsize)
285+
end
286+

test/device/array.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,29 @@ end
137137
@test array == Array(array_dev)
138138
end
139139

140+
@testset "reshape" begin
141+
function kernel(array)
142+
i = (blockIdx().x-1i32) * blockDim().x + threadIdx().x
143+
j = (blockIdx().y-1i32) * blockDim().y + threadIdx().y
144+
145+
_array2d = reshape(array, 10, 10)
146+
_array2d[i,j] = i + (j-1)*size(_array2d,1)
147+
148+
return
149+
end
150+
151+
array = zeros(Int64, 100)
152+
array_dev = CuArray(array)
153+
154+
array2d = reshape(array, 10, 10)
155+
for i in 1:size(array2d,1), j in 1:size(array2d,2)
156+
array2d[i,j] = i + (j-1)*size(array2d,1)
157+
end
158+
159+
@cuda threads=(10, 10) kernel(array_dev)
160+
@test array == Array(array_dev)
161+
end
162+
140163
@testset "non-Int index to unsafe_load" begin
141164
function kernel(a)
142165
a[UInt64(1)] = 1

0 commit comments

Comments
 (0)