17
17
#include " fbgemm_gpu/utils/cuda_prelude.cuh"
18
18
#include " fbgemm_gpu/utils/dispatch_macros.h"
19
19
#include " fbgemm_gpu/utils/inclusive_sum_scan.cuh"
20
+ #include " fbgemm_gpu/utils/kernel_launcher.cuh"
20
21
#include " metric_ops.h"
21
22
22
23
constexpr int MAX_ENTRIES_PER_BLOCK = 512 ;
@@ -251,28 +252,29 @@ at::Tensor batch_auc(
251
252
auto max_smem_size =
252
253
at::cuda::getCurrentDeviceProperties ()->sharedMemPerBlock ;
253
254
254
- #define LAUNCH_AUC_KERNEL (pad ) \
255
- typedef cub::BlockScan<acc_t , NUM_THREADS_PER_BLOCK> BlockScan; \
256
- TORCH_CHECK ( \
257
- sizeof (BlockScan::TempStorage) + \
258
- ((MAX_ENTRIES_PER_BLOCK * 2 + 3 ) * sizeof (acc_t )) <= \
259
- max_smem_size) \
260
- auc_kernel<index_t , label_t , scalar_t , acc_t , pad> \
261
- <<<dim3 (grid_size), \
262
- dim3 (NUM_THREADS_PER_BLOCK), \
263
- 0, \
264
- at::cuda::getCurrentCUDAStream()>>>( \
265
- output.data_ptr<acc_t >(), \
266
- indices.data_ptr<index_t>(), \
267
- labels.data_ptr<label_t>(), \
268
- weights.data_ptr<scalar_t>(), \
269
- num_blocks > 1 ? block_flags.data_ptr<int>() : nullptr, \
270
- num_blocks > 1 ? block_sums.data_ptr<acc_t>() : nullptr, \
271
- num_entries, \
272
- last_block_num_entries, \
273
- padded_num_entries_per_block, \
274
- num_blocks); \
275
- C10_CUDA_KERNEL_LAUNCH_CHECK ();
255
+ #define LAUNCH_AUC_KERNEL (pad ) \
256
+ typedef cub::BlockScan<acc_t , NUM_THREADS_PER_BLOCK> BlockScan; \
257
+ TORCH_CHECK ( \
258
+ sizeof (BlockScan::TempStorage) + \
259
+ ((MAX_ENTRIES_PER_BLOCK * 2 + 3 ) * sizeof (acc_t )) <= \
260
+ max_smem_size) \
261
+ \
262
+ FBGEMM_LAUNCH_KERNEL ( \
263
+ (auc_kernel<index_t , label_t , scalar_t , acc_t , pad>), \
264
+ dim3 (grid_size), \
265
+ dim3 (NUM_THREADS_PER_BLOCK), \
266
+ 0 , \
267
+ at::cuda::getCurrentCUDAStream (), \
268
+ output.data_ptr <acc_t >(), \
269
+ indices.data_ptr <index_t >(), \
270
+ labels.data_ptr <label_t >(), \
271
+ weights.data_ptr <scalar_t >(), \
272
+ num_blocks > 1 ? block_flags.data_ptr <int >() : nullptr , \
273
+ num_blocks > 1 ? block_sums.data_ptr <acc_t >() : nullptr , \
274
+ num_entries, \
275
+ last_block_num_entries, \
276
+ padded_num_entries_per_block, \
277
+ num_blocks);
276
278
277
279
AT_DISPATCH_INDEX_TYPES (indices.scalar_type (), " auc_wrapper_1" , [&] {
278
280
FBGEMM_DISPATCH_ALL_TYPES (labels.scalar_type (), " auc_wrapper_2" , [&] {
@@ -285,7 +287,6 @@ at::Tensor batch_auc(
285
287
} else {
286
288
LAUNCH_AUC_KERNEL (2 )
287
289
}
288
- C10_CUDA_KERNEL_LAUNCH_CHECK ();
289
290
});
290
291
});
291
292
});
0 commit comments