@@ -1213,10 +1213,27 @@ end
 
 # create a batch of pointers in device memory from a strided device array
 @inline function unsafe_strided_batch(strided::DenseCuArray{T}) where {T}
-    batchsize = last(size(strided))
-    stride = prod(size(strided)[1:end-1])
-    ptrs = [pointer(strided, (i-1)*stride + 1) for i in 1:batchsize]
-    return CuArray(ptrs)
+    batch_size = last(size(strided))
+    batch_stride = prod(size(strided)[1:end-1])
+    # ptrs = [pointer(strided, (i-1)*batch_stride + 1) for i in 1:batch_size]
+    # fill the array on the GPU to avoid synchronous copies and support larger batch sizes
+    ptrs = CuArray{CuPtr{T}}(undef, batch_size)
+    function compute_pointers()
+        i = (blockIdx().x - 1i32) * blockDim().x + threadIdx().x
+        grid_stride = gridDim().x * blockDim().x
+        while i <= length(ptrs)
+            @inbounds ptrs[i] =
+                reinterpret(CuPtr{T}, pointer(strided, (i - 1i32) * batch_stride + 1i32))
+            i += grid_stride
+        end
+        return
+    end
+    kernel = @cuda launch=false compute_pointers()
+    config = launch_configuration(kernel.fun)
+    threads = min(config.threads, batch_size)
+    blocks = min(config.blocks, cld(batch_size, threads))
+    @cuda threads blocks compute_pointers()
+    return ptrs
 end
 
 ## (GE) general matrix-matrix multiplication grouped batched
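
For context, a minimal sketch of what the change is equivalent to: the old code built the pointer table with a CPU-side comprehension and copied it to the device, while the new code fills it with a grid-stride kernel sized via an occupancy-based launch configuration. The snippet below is illustrative only and not part of the diff; it assumes `unsafe_strided_batch` lives in the `CUBLAS` submodule of CUDA.jl.

```julia
using CUDA

# a strided batch: eight 4x4 matrices stored contiguously in one array
A = CUDA.rand(Float32, 4, 4, 8)

# what the old CPU-side comprehension computed
batch_stride = prod(size(A)[1:end-1])
cpu_ptrs = [pointer(A, (i - 1) * batch_stride + 1) for i in 1:last(size(A))]

# the new version computes the same pointers in a GPU kernel, avoiding the
# synchronous host-to-device copy of the pointer table
gpu_ptrs = Array(CUDA.CUBLAS.unsafe_strided_batch(A))   # assumed module path

@assert cpu_ptrs == gpu_ptrs
```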