Skip to content

Commit 8492254

Browse files
committed
fix cartesian indexing into ConstCuDeviceArray
1 parent fe652e1 commit 8492254

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

examples/performance.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ end
7979
NVTX.@range "Naive transpose 1, $(block_dim^2)" let
8080
a = CuArray(rand(T, shape))
8181
b = similar(a, shape[2], shape[1])
82-
kernel! = transpose_kernel_naive!(CUDA(), (1, blockdim*block_dim), size(b))
82+
kernel! = transpose_kernel_naive!(CUDA(), (1, block_dim*block_dim), size(b))
8383

8484
event = kernel!(b, a)
8585
wait(event)

src/backends/cuda.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,7 @@ Base.pointer(a::ConstCuDeviceArray, i::Integer) =
211211
Base.elsize(::Type{<:ConstCuDeviceArray{T}}) where {T} = sizeof(T)
212212
Base.size(g::ConstCuDeviceArray) = g.shape
213213
Base.length(g::ConstCuDeviceArray) = prod(g.shape)
214+
Base.IndexStyle(::Type{<:ConstCuDeviceArray}) = Base.IndexLinear()
214215

215216
Base.unsafe_convert(::Type{DevicePtr{T,A}}, a::ConstCuDeviceArray{T,N,A}) where {T,A,N} = pointer(a)
216217

@@ -219,3 +220,9 @@ Base.unsafe_convert(::Type{DevicePtr{T,A}}, a::ConstCuDeviceArray{T,N,A}) where
219220
align = Base.datatype_alignment(T)
220221
CUDAnative.unsafe_cached_load(pointer(A), index, Val(align))::T
221222
end
223+
224+
@inline function Base.unsafe_view(arr::ConstCuDeviceArray{T, 1, A}, I::Vararg{Base.ViewIndex,1}) where {T, A}
225+
ptr = pointer(arr) + (I[1].start-1)*sizeof(T)
226+
len = I[1].stop - I[1].start + 1
227+
return ConstCuDeviceArray{T,1,A}(len, ptr)
228+
end

0 commit comments

Comments
 (0)