Skip to content

Commit 74b8eff

Browse files
THargreaves and maleadt authored
Move strided batch pointer conversion to GPU (#2608)
Co-authored-by: Tim Besard <tim.besard@gmail.com>
1 parent 14ae82d commit 74b8eff

File tree

1 file changed

+21
-4
lines changed

1 file changed

+21
-4
lines changed

lib/cublas/wrappers.jl

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1213,10 +1213,27 @@ end
12131213

12141214
# create a batch of pointers in device memory from a strided device array
12151215
# Build a device vector holding one `CuPtr{T}` per batch entry of `strided`.
#
# The last dimension of `strided` is taken as the batch dimension; the product of all
# leading dimensions is the distance (in elements) between consecutive batch entries.
# Entry `i` of the result is the pointer to linear element `(i - 1) * batch_stride + 1`
# of `strided`.
#
# The returned pointers point into `strided` itself; callers presumably need to keep
# `strided` alive for as long as the pointers are in use (hence "unsafe").
@inline function unsafe_strided_batch(strided::DenseCuArray{T}) where {T}
    batch_size = last(size(strided))
    batch_stride = prod(size(strided)[1:end-1])

    ptrs = CuArray{CuPtr{T}}(undef, batch_size)

    # An empty batch needs no kernel launch. Without this guard `threads` below would
    # be 0 and `cld(batch_size, threads)` would throw a `DivideError`.
    batch_size == 0 && return ptrs

    # fill the array on the GPU to avoid synchronous copies and support larger batch sizes
    function compute_pointers()
        # 1-based global thread index; each thread then advances by the total number
        # of launched threads until every batch entry has been written
        i = (blockIdx().x - 1i32) * blockDim().x + threadIdx().x
        grid_stride = gridDim().x * blockDim().x
        while i <= length(ptrs)
            @inbounds ptrs[i] =
                reinterpret(CuPtr{T}, pointer(strided, (i - 1i32) * batch_stride + 1i32))
            i += grid_stride
        end
        return
    end

    # compile without launching so the launch configuration can be queried, then
    # clamp it so no more threads/blocks are started than there are batch entries
    kernel = @cuda launch = false compute_pointers()
    config = launch_configuration(kernel.fun)
    threads = min(config.threads, batch_size)
    blocks = min(config.blocks, cld(batch_size, threads))
    @cuda threads blocks compute_pointers()
    return ptrs
end
12211238

12221239
## (GE) general matrix-matrix multiplication grouped batched

0 commit comments

Comments
 (0)