Skip to content

Migrate Torch DSA kernels to FBGEMM_LAUNCH_DSA_KERNEL #4556

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 18 additions & 19 deletions fbgemm_gpu/src/sparse_ops/sparse_index_select.cu
Original file line number Diff line number Diff line change
Expand Up @@ -73,25 +73,24 @@ DLL_PUBLIC Tensor index_select_cuda(
const auto dummy_orig_indices =
at::empty({0}, at::TensorOptions().dtype(at::kLong));

#define LAUNCH_INDEX_SELECT(INDICES_SORTED) \
{ \
const auto orig_indices_ = \
INDICES_SORTED ? orig_indices : dummy_orig_indices; \
[[maybe_unused]] const auto func_name = "index_select_2d_kernel"; \
TORCH_DSA_KERNEL_LAUNCH( \
(index_select_2d_kernel< \
index_t, \
scalar_t, \
UNROLL_FACTOR, \
INDICES_SORTED>), \
cuda_calc_xblock_count(N, 1), \
std::min(div_round_up(D, UNROLL_FACTOR), kMaxThreads), \
0, \
at::cuda::getCurrentCUDAStream(), \
MAKE_PTA_WITH_NAME(func_name, input_reshaped, scalar_t, 2, 64), \
MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 64), \
MAKE_PTA_WITH_NAME(func_name, orig_indices_, int64_t, 1, 64), \
MAKE_PTA_WITH_NAME(func_name, output, scalar_t, 2, 64)); \
#define LAUNCH_INDEX_SELECT(INDICES_SORTED) \
{ \
const auto orig_indices_ = \
INDICES_SORTED ? orig_indices : dummy_orig_indices; \
FBGEMM_LAUNCH_DSA_KERNEL( \
(index_select_2d_kernel< \
index_t, \
scalar_t, \
UNROLL_FACTOR, \
INDICES_SORTED>), \
cuda_calc_xblock_count(N, 1), \
std::min(div_round_up(D, UNROLL_FACTOR), kMaxThreads), \
0, \
at::cuda::getCurrentCUDAStream(), \
PTA_B(input_reshaped, scalar_t, 2, 64), \
PTA_B(indices, index_t, 1, 64), \
PTA_B(orig_indices_, int64_t, 1, 64), \
PTA_B(output, scalar_t, 2, 64)); \
}

AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "index_add_2d_kernel_1", [&] {
Expand Down
6 changes: 4 additions & 2 deletions fbgemm_gpu/src/sparse_ops/sparse_pack_segments_forward.cu
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,8 @@ DLL_PUBLIC Tensor pack_segments_forward_cuda(
auto* const out_data = packed_tensor.data_ptr<scalar_t>();
const auto num_seq = lengths.size(0);
const auto cell_size = t_in_c.numel() / t_in_c.size(0);
TORCH_DSA_KERNEL_LAUNCH(

FBGEMM_LAUNCH_DSA_KERNEL(
(pack_segments_cuda_kernel<index_t, scalar_t>),
cuda_calc_xblock_count(num_seq * max_length * cell_size, 128),
128,
Expand Down Expand Up @@ -228,7 +229,8 @@ pack_segments_forward_cuda_v2(
auto* const out_data = packed_tensor.data_ptr<scalar_t>();
const auto num_seq = lengths.size(0);
const auto cell_size = t_in_c.numel() / t_in_c.size(0);
TORCH_DSA_KERNEL_LAUNCH(

FBGEMM_LAUNCH_DSA_KERNEL(
(pack_segments_cuda_v2_kernel<index_t, scalar_t>),
cuda_calc_xblock_count(num_seq * max_length * cell_size, 128),
128,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -407,48 +407,45 @@ ssd_cache_populate_actions_cuda(

auto actions_count = at::empty({1}, int_options);
// Find uncached indices
auto
[sorted_cache_sets,
cache_set_sorted_unique_indices,
cache_set_inverse_indices] =
lru_cache_find_uncached_cuda(
unique_indices,
unique_indices_length,
total_hash_size,
lxu_cache_state,
time_stamp,
lru_state,
gather_cache_stats,
ssd_cache_stats_,
lock_cache_line,
lxu_cache_locking_counter_,
/*compute_inverse_indices=*/true);
at::Tensor sorted_cache_sets;
at::Tensor cache_set_sorted_unique_indices;
std::optional<at::Tensor> cache_set_inverse_indices;
std::tie(
sorted_cache_sets,
cache_set_sorted_unique_indices,
cache_set_inverse_indices) =
lru_cache_find_uncached_cuda(
unique_indices,
unique_indices_length,
total_hash_size,
lxu_cache_state,
time_stamp,
lru_state,
gather_cache_stats,
ssd_cache_stats_,
lock_cache_line,
lxu_cache_locking_counter_,
/*compute_inverse_indices=*/true);

TORCH_CHECK(cache_set_inverse_indices.has_value());

#ifdef FBGEMM_GPU_MEMCHECK
const auto func_name = "ssd_cache_actions_insert_kernel";
#endif

TORCH_DSA_KERNEL_LAUNCH(
FBGEMM_LAUNCH_DSA_KERNEL(
ssd_cache_actions_insert_kernel,
div_round_up(N, kMaxThreads / kWarpSize),
dim3(kWarpSize, kMaxThreads / kWarpSize),
0,
at::cuda::getCurrentCUDAStream(),
MAKE_PTA_WITH_NAME(func_name, lxu_cache_state, int64_t, 2, 32),
MAKE_PTA_WITH_NAME(func_name, sorted_cache_sets, int32_t, 1, 32),
MAKE_PTA_WITH_NAME(
func_name, cache_set_sorted_unique_indices, int64_t, 1, 32),
PTA_B(lxu_cache_state, int64_t, 2, 32),
PTA_B(sorted_cache_sets, int32_t, 1, 32),
PTA_B(cache_set_sorted_unique_indices, int64_t, 1, 32),
time_stamp,
prefetch_dist,
MAKE_PTA_WITH_NAME(func_name, lru_state, int64_t, 2, 32),
MAKE_PTA_WITH_NAME(func_name, assigned_cache_slots, int32_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name, evicted_indices, int64_t, 1, 32),
MAKE_PTA_WITH_NAME(func_name, actions_count, int32_t, 1, 32),
PTA_B(lru_state, int64_t, 2, 32),
PTA_B(assigned_cache_slots, int32_t, 1, 32),
PTA_B(evicted_indices, int64_t, 1, 32),
PTA_B(actions_count, int32_t, 1, 32),
lock_cache_line,
MAKE_PTA_WITH_NAME(
func_name, lxu_cache_locking_counter_, int32_t, 2, 32));
PTA_B(lxu_cache_locking_counter_, int32_t, 2, 32));

return std::make_tuple(
cache_set_sorted_unique_indices,
Expand Down
Loading