Skip to content

Commit 6f0abb0

Browse files
q10facebook-github-bot
authored andcommitted
Remove debug_synchronous from CUB call sites in FBGEMM ops (#1973)
Summary: Pull Request resolved: #1973 - Remove debug_synchronous from CUB call sites in FBGEMM ops Reviewed By: sryap Differential Revision: D48722495 fbshipit-source-id: 47a92dc82e9fe271d719913f8a842f7fa2c8f36f
1 parent 0da4234 commit 6f0abb0

File tree

3 files changed

+38
-13
lines changed

3 files changed

+38
-13
lines changed

fbgemm_gpu/codegen/embedding_backward_split_template.cu

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -524,8 +524,7 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e
524524
linear_indices.numel(),
525525
0,
526526
total_hash_size_bits,
527-
at::cuda::getCurrentCUDAStream(),
528-
false));
527+
at::cuda::getCurrentCUDAStream()));
529528
auto temp_storage = at::empty(
530529
{static_cast<int64_t>(temp_storage_bytes)},
531530
indices.options().dtype(at::kByte));
@@ -539,8 +538,7 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e
539538
linear_indices.numel(),
540539
0,
541540
total_hash_size_bits,
542-
at::cuda::getCurrentCUDAStream(),
543-
false));
541+
at::cuda::getCurrentCUDAStream()));
544542
}
545543
{%- endif %}
546544

@@ -568,8 +566,7 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e
568566
linear_indices.numel(),
569567
0,
570568
total_hash_size_bits,
571-
at::cuda::getCurrentCUDAStream(),
572-
false));
569+
at::cuda::getCurrentCUDAStream()));
573570
auto temp_storage = at::empty(
574571
{static_cast<int64_t>(temp_storage_bytes)},
575572
indices.options().dtype(at::kByte));
@@ -583,8 +580,7 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e
583580
linear_indices.numel(),
584581
0,
585582
total_hash_size_bits,
586-
at::cuda::getCurrentCUDAStream(),
587-
false));
583+
at::cuda::getCurrentCUDAStream()));
588584
}
589585
{%- endif %}
590586

fbgemm_gpu/include/fbgemm_gpu/split_embeddings_utils.cuh

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,7 @@ std::tuple<int32_t, uint32_t> adjust_info_B_num_bits(int32_t B, int32_t T);
6969
int num_items, \
7070
int begin_bit = 0, \
7171
int end_bit = sizeof(KeyT) * 8, \
72-
cudaStream_t stream = 0, \
73-
bool debug_synchronous = false)
72+
cudaStream_t stream = 0)
7473

7574
DECL_RADIX_SORT_PAIRS_FN(int64_t, float);
7675
DECL_RADIX_SORT_PAIRS_FN(int64_t, double);

fbgemm_gpu/src/split_embeddings_utils.cu

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@
2121
#include "fbgemm_gpu/cub_namespace_postfix.cuh"
2222
// clang-format on
2323

24+
#ifdef __HIP_PLATFORM_HCC__
25+
#include <rocm_version.h>
26+
#endif
27+
2428
inline at::Tensor asynchronous_complete_cumsum(at::Tensor t_in) {
2529
at::cuda::OptionalCUDAGuard device_guard;
2630
device_guard.set_index(t_in.get_device());
@@ -442,6 +446,32 @@ DLL_PUBLIC std::tuple<int32_t, uint32_t> adjust_info_B_num_bits(
442446
return {info_B_num_bits, info_B_mask};
443447
}
444448
449+
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12000
450+
#define DEF_RADIX_SORT_PAIRS_FN(KeyT, ValueT) \
451+
DLL_PUBLIC cudaError_t radix_sort_pairs( \
452+
void* d_temp_storage, \
453+
size_t& temp_storage_bytes, \
454+
const KeyT* d_keys_in, \
455+
KeyT* d_keys_out, \
456+
const ValueT* d_values_in, \
457+
ValueT* d_values_out, \
458+
const int num_items, \
459+
const int begin_bit, \
460+
const int end_bit, \
461+
cudaStream_t stream) { \
462+
return FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortPairs( \
463+
d_temp_storage, \
464+
temp_storage_bytes, \
465+
d_keys_in, \
466+
d_keys_out, \
467+
d_values_in, \
468+
d_values_out, \
469+
num_items, \
470+
begin_bit, \
471+
end_bit, \
472+
stream); \
473+
}
474+
#else
445475
#define DEF_RADIX_SORT_PAIRS_FN(KeyT, ValueT) \
446476
DLL_PUBLIC cudaError_t radix_sort_pairs( \
447477
void* d_temp_storage, \
@@ -453,8 +483,7 @@ DLL_PUBLIC std::tuple<int32_t, uint32_t> adjust_info_B_num_bits(
453483
const int num_items, \
454484
const int begin_bit, \
455485
const int end_bit, \
456-
cudaStream_t stream, \
457-
const bool debug_synchronous) { \
486+
cudaStream_t stream) { \
458487
return FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortPairs( \
459488
d_temp_storage, \
460489
temp_storage_bytes, \
@@ -466,8 +495,9 @@ DLL_PUBLIC std::tuple<int32_t, uint32_t> adjust_info_B_num_bits(
466495
begin_bit, \
467496
end_bit, \
468497
stream, \
469-
debug_synchronous); \
498+
false); \
470499
}
500+
#endif
471501
472502
DEF_RADIX_SORT_PAIRS_FN(int64_t, float);
473503
DEF_RADIX_SORT_PAIRS_FN(int64_t, double);

0 commit comments

Comments
 (0)