Skip to content

Commit ba14df1

Browse files
q10 authored and facebook-github-bot committed
Migrate metric ops kernels to FBGEMM_LAUNCH_KERNEL (#4453)
Summary: Pull Request resolved: #4453 - Migrate metric ops kernels to `FBGEMM_LAUNCH_KERNEL` Reviewed By: r-barnes Differential Revision: D75119867 fbshipit-source-id: eb7237685e6f0c5020c2b77c626eaf7f56310976
1 parent 92de1de commit ba14df1

File tree

1 file changed

+24
-23
lines changed

1 file changed

+24
-23
lines changed

fbgemm_gpu/src/metric_ops/metric_ops.cu

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "fbgemm_gpu/utils/cuda_prelude.cuh"
1818
#include "fbgemm_gpu/utils/dispatch_macros.h"
1919
#include "fbgemm_gpu/utils/inclusive_sum_scan.cuh"
20+
#include "fbgemm_gpu/utils/kernel_launcher.cuh"
2021
#include "metric_ops.h"
2122

2223
constexpr int MAX_ENTRIES_PER_BLOCK = 512;
@@ -251,28 +252,29 @@ at::Tensor batch_auc(
251252
auto max_smem_size =
252253
at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock;
253254

254-
#define LAUNCH_AUC_KERNEL(pad) \
255-
typedef cub::BlockScan<acc_t, NUM_THREADS_PER_BLOCK> BlockScan; \
256-
TORCH_CHECK( \
257-
sizeof(BlockScan::TempStorage) + \
258-
((MAX_ENTRIES_PER_BLOCK * 2 + 3) * sizeof(acc_t)) <= \
259-
max_smem_size) \
260-
auc_kernel<index_t, label_t, scalar_t, acc_t, pad> \
261-
<<<dim3(grid_size), \
262-
dim3(NUM_THREADS_PER_BLOCK), \
263-
0, \
264-
at::cuda::getCurrentCUDAStream()>>>( \
265-
output.data_ptr<acc_t>(), \
266-
indices.data_ptr<index_t>(), \
267-
labels.data_ptr<label_t>(), \
268-
weights.data_ptr<scalar_t>(), \
269-
num_blocks > 1 ? block_flags.data_ptr<int>() : nullptr, \
270-
num_blocks > 1 ? block_sums.data_ptr<acc_t>() : nullptr, \
271-
num_entries, \
272-
last_block_num_entries, \
273-
padded_num_entries_per_block, \
274-
num_blocks); \
275-
C10_CUDA_KERNEL_LAUNCH_CHECK();
255+
#define LAUNCH_AUC_KERNEL(pad) \
256+
typedef cub::BlockScan<acc_t, NUM_THREADS_PER_BLOCK> BlockScan; \
257+
TORCH_CHECK( \
258+
sizeof(BlockScan::TempStorage) + \
259+
((MAX_ENTRIES_PER_BLOCK * 2 + 3) * sizeof(acc_t)) <= \
260+
max_smem_size) \
261+
\
262+
FBGEMM_LAUNCH_KERNEL( \
263+
(auc_kernel<index_t, label_t, scalar_t, acc_t, pad>), \
264+
dim3(grid_size), \
265+
dim3(NUM_THREADS_PER_BLOCK), \
266+
0, \
267+
at::cuda::getCurrentCUDAStream(), \
268+
output.data_ptr<acc_t>(), \
269+
indices.data_ptr<index_t>(), \
270+
labels.data_ptr<label_t>(), \
271+
weights.data_ptr<scalar_t>(), \
272+
num_blocks > 1 ? block_flags.data_ptr<int>() : nullptr, \
273+
num_blocks > 1 ? block_sums.data_ptr<acc_t>() : nullptr, \
274+
num_entries, \
275+
last_block_num_entries, \
276+
padded_num_entries_per_block, \
277+
num_blocks);
276278

277279
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "auc_wrapper_1", [&] {
278280
FBGEMM_DISPATCH_ALL_TYPES(labels.scalar_type(), "auc_wrapper_2", [&] {
@@ -285,7 +287,6 @@ at::Tensor batch_auc(
285287
} else {
286288
LAUNCH_AUC_KERNEL(2)
287289
}
288-
C10_CUDA_KERNEL_LAUNCH_CHECK();
289290
});
290291
});
291292
});

0 commit comments

Comments (0)