@@ -240,6 +240,102 @@ __launch_bounds__(kMaxThreads) void reorder_batched_ad_indices_kernel(
 }
 }
 
+template <typename Dtype, typename index_t = int32_t>
+__global__
+__launch_bounds__(fbgemm_gpu::kMaxThreads) void reorder_batched_ad_indices_kernel_vec(
+    // reorder indices from (ragged) [B x T x #num_ads_b x length_{b, t, a}]
+    // to [T][B][#num_ads_b][length_{b, t, a}], i.e. [sum(length_{b, t, a})],
+    // laid out as [T][B][A][L] (if all lengths were equal).
+
+    // if broadcast_indices is enabled, all the indices will be copies of the
+    // first batch of the cat_ad_indices; this is useful for request-only
+    // broadcast
+    const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
+        cat_ad_offsets,
+    const pta::PackedTensorAccessor32<Dtype, 1, at::RestrictPtrTraits>
+        cat_ad_indices,
+    const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
+        reordered_cat_ad_offsets,
+    pta::PackedTensorAccessor32<Dtype, 1, at::RestrictPtrTraits>
+        reordered_cat_ad_indices,
+    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+        batch_offsets,
+    const int32_t T,
+    const bool broadcast_indices) {
+  using vec2_t =
+      typename std::conditional<sizeof(Dtype) == 8, long2, float2>::type;
+  using vec4_t =
+      typename std::conditional<sizeof(Dtype) == 8, long4, float4>::type;
+  const int32_t B = batch_offsets.size(0) - 1;
+  const int32_t num_ads_in_batch = batch_offsets[B];
+  // warp-per-segment.
+  const auto b_t = blockIdx.x * blockDim.y +
+      threadIdx.y; // can be more efficient through bitwise op
+  const int32_t b = b_t % B;
+  const int32_t t = b_t / B;
+  if (t >= T) {
+    return;
+  }
+
+  const auto num_ads_b = batch_offsets[b + 1] - batch_offsets[b];
+  const auto output_segment_offset_start =
+      t * num_ads_in_batch + batch_offsets[b];
+  const auto output_segment_start =
+      reordered_cat_ad_offsets[output_segment_offset_start];
+  const int32_t input_segment_offset_start =
+      broadcast_indices ? T * b + t : T * batch_offsets[b] + t * num_ads_b;
+  const int32_t input_segment_offset_end = broadcast_indices
+      ? input_segment_offset_start + 1
+      : input_segment_offset_start + num_ads_b;
+  const auto input_segment_start = cat_ad_offsets[input_segment_offset_start];
+  const auto input_segment_end = cat_ad_offsets[input_segment_offset_end];
+  const auto num_elements = input_segment_end - input_segment_start;
+
+  if (broadcast_indices) {
+    for (auto i = threadIdx.x; i < num_ads_b * num_elements; i += blockDim.x) {
+      reordered_cat_ad_indices[output_segment_start + i] =
+          cat_ad_indices[input_segment_start + i % num_elements];
+    }
+  } else {
+    // Idea: we want to copy the entire segment of size sum_a(length_{b, t, a})
+    // from starting point (given by cat_ad_offsets[b, t])
+    // to end point (given by reordered_cat_ad_indices[t][b])
+    if (num_elements <= 64) {
+      for (auto i = threadIdx.x; i < input_segment_end - input_segment_start;
+           i += blockDim.x) {
+        // coalesced global memory access, can be optimized through ILP with
+        // the help of shared memory or vector load/store (if num_ads_b >= 64)
+        reordered_cat_ad_indices[output_segment_start + i] =
+            cat_ad_indices[input_segment_start + i];
+      }
+    } else if (num_elements > 64 && num_elements <= 128) {
+      auto dst =
+          (vec2_t*)(reordered_cat_ad_indices.data() + output_segment_start);
+      auto src = (vec2_t*)(cat_ad_indices.data() + input_segment_start);
+      for (auto i = threadIdx.x; i < num_elements / 2; i += blockDim.x) {
+        dst[i] = src[i];
+      }
+      if ((num_elements % 2) && threadIdx.x == 31) {
+        reordered_cat_ad_indices[output_segment_start + num_elements - 1] =
+            cat_ad_indices[input_segment_start + num_elements - 1];
+      }
+    } else if (num_elements > 128) {
+      auto dst =
+          (vec4_t*)(reordered_cat_ad_indices.data() + output_segment_start);
+      auto src = (vec4_t*)(cat_ad_indices.data() + input_segment_start);
+      for (auto i = threadIdx.x; i < num_elements / 4; i += blockDim.x) {
+        dst[i] = src[i];
+      }
+      int remainder = num_elements % 4;
+      if (remainder && threadIdx.x < remainder) {
+        reordered_cat_ad_indices
+            [output_segment_start + num_elements - threadIdx.x - 1] =
+                cat_ad_indices
+                    [input_segment_start + num_elements - threadIdx.x - 1];
+      }
+    }
+  }
+}
 DLL_PUBLIC Tensor reorder_batched_ad_indices_gpu(
     const Tensor& cat_ad_offsets,
     const Tensor& cat_ad_indices,
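
Note on launch geometry (not part of the diff): the kernel above is written warp-per-segment, with `threadIdx.x` as the lane inside a 32-wide warp (the `threadIdx.x == 31` tail handling assumes this) and `blockIdx.x * blockDim.y + threadIdx.y` enumerating the (t, b) segments. The actual host-side launch sits outside this hunk; the snippet below is only a sketch of a compatible launch shape, and `kWarpsPerBlock` plus the helper name are assumptions made here for illustration.

```cpp
#include <cuda_runtime.h>
#include <cstdint>
#include <utility>

// Sketch only: a launch shape compatible with the warp-per-segment indexing
// used by reorder_batched_ad_indices_kernel_vec. kWarpsPerBlock and this
// helper's name are illustrative assumptions, not the PR's host code.
inline std::pair<dim3, dim3> make_warp_per_segment_launch(int32_t B, int32_t T) {
  constexpr int kWarpSize = 32;       // the kernel's threadIdx.x == 31 tail check assumes 32 lanes
  constexpr int kWarpsPerBlock = 8;   // assumed: 8 warps (256 threads) per block
  const int32_t num_segments = B * T; // one warp handles one (t, b) segment
  const dim3 threads(kWarpSize, kWarpsPerBlock);
  const dim3 blocks((num_segments + kWarpsPerBlock - 1) / kWarpsPerBlock);
  return {blocks, threads};           // excess warps exit early via the t >= T check
}
```

A launch would then look like `reorder_batched_ad_indices_kernel_vec<scalar_t, index_t><<<blocks, threads, 0, stream>>>(...)` with the packed tensor accessors built by the dispatch code.
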
@@ -387,6 +483,7 @@ DLL_PUBLIC Tensor reorder_batched_ad_indices_gpu(
       C10_CUDA_KERNEL_LAUNCH_CHECK();
     });
   });
+
   return reordered_cat_ad_indices;
 }
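
For reference while reviewing the index arithmetic: in the non-broadcast path, segment (b, t) is read starting at input offset `T * batch_offsets[b] + t * num_ads_b` and written starting at output offset `t * num_ads_in_batch + batch_offsets[b]`. The scalar host-side restatement below mirrors that arithmetic for reading or unit-testing purposes; it is a sketch with invented names and a plain std::vector interface, not code from this PR.

```cpp
#include <cstdint>
#include <vector>

// Scalar reference for the non-broadcast reordering: [B][T][A] segments of
// indices are rewritten in [T][B][A] order. Names and container types are
// local to this sketch.
template <typename Dtype, typename index_t>
std::vector<Dtype> reorder_batched_ad_indices_ref(
    const std::vector<index_t>& cat_ad_offsets,           // segment offsets in [B][T][A] order
    const std::vector<Dtype>& cat_ad_indices,
    const std::vector<index_t>& reordered_cat_ad_offsets, // segment offsets in [T][B][A] order
    const std::vector<int32_t>& batch_offsets,            // size B + 1
    int32_t T) {
  const int32_t B = static_cast<int32_t>(batch_offsets.size()) - 1;
  const int32_t num_ads_in_batch = batch_offsets[B];
  std::vector<Dtype> out(cat_ad_indices.size());
  for (int32_t t = 0; t < T; ++t) {
    for (int32_t b = 0; b < B; ++b) {
      const int32_t num_ads_b = batch_offsets[b + 1] - batch_offsets[b];
      const auto output_segment_start =
          reordered_cat_ad_offsets[t * num_ads_in_batch + batch_offsets[b]];
      const auto input_segment_offset = T * batch_offsets[b] + t * num_ads_b;
      const auto input_segment_start = cat_ad_offsets[input_segment_offset];
      const auto input_segment_end =
          cat_ad_offsets[input_segment_offset + num_ads_b];
      for (auto i = input_segment_start; i < input_segment_end; ++i) {
        out[output_segment_start + (i - input_segment_start)] =
            cat_ad_indices[i];
      }
    }
  }
  return out;
}
```

The broadcast_indices path would instead replicate batch 0's segment num_ads_b times, as the kernel's modulo copy shows.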