Migrate jagged tensor kernels to FBGEMM_LAUNCH_KERNEL, pt 3 #4411

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Status: Open. Wants to merge 1 commit into base: main.
@@ -134,52 +134,47 @@ void jagged_dense_dense_elementwise_jagged_output_opt_(

           const auto threads_bs = dim3(1024, 1, 1);
           const auto blocks_bs = dim3(div_round_up(nnz, threads_bs.x), 1, 1);
-
-#ifdef FBGEMM_GPU_MEMCHECK
-          const auto func_name1 =
-              "jagged_dense_dense_elementwise_jagged_output_opt_search_kernel_";
-#endif
-          jagged_dense_dense_elementwise_jagged_output_opt_search_kernel_<
-              index_t>
-              <<<blocks_bs,
-                 threads_bs,
-                 dynamic_smem_size,
-                 at::cuda::getCurrentCUDAStream()>>>(
-                  MAKE_PTA_WITH_NAME(func_name1, x_offsets[0], index_t, 1, 32),
-                  MAKE_PTA_WITH_NAME(func_name1, t_rows_after_bs, int, 1, 32),
-                  MAKE_PTA_WITH_NAME(func_name1, t_cols_after_bs, int, 1, 32),
-                  nnz,
-                  B);
-          C10_CUDA_KERNEL_LAUNCH_CHECK();
+          FBGEMM_LAUNCH_KERNEL(
+              (jagged_dense_dense_elementwise_jagged_output_opt_search_kernel_<
+                  index_t>),
+              blocks_bs,
+              threads_bs,
+              dynamic_smem_size,
+              at::cuda::getCurrentCUDAStream(),
+              PTA_B((x_offsets[0]), index_t, 1, 32),
+              PTA_B(t_rows_after_bs, int, 1, 32),
+              PTA_B(t_cols_after_bs, int, 1, 32),
+              nnz,
+              B);
+
           // Gather kernel
           dim3 threads = dim3(16, 16, 1);
           dim3 blocks = dim3(1, div_round_up(nnz, threads.y), 1);
           if (blocks.y > 65535) {
             blocks.y = 65535;
           }
-
-#ifdef FBGEMM_GPU_MEMCHECK
-          const auto func_name2 =
-              "jagged_dense_dense_elementwise_jagged_output_opt_gather_kernel_";
-#endif
-          jagged_dense_dense_elementwise_jagged_output_opt_gather_kernel_<
-              index_t>
-              <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
-                  MAKE_PTA_WITH_NAME(
-                      func_name2, output_values, c10::Half, 2, 32),
-                  MAKE_PTA_WITH_NAME(func_name2, x_values, c10::Half, 2, 32),
-                  MAKE_PTA_WITH_NAME(
-                      func_name2, y_0_reshaped, c10::Half, 3, 32),
-                  MAKE_PTA_WITH_NAME(
-                      func_name2, y_1_reshaped, c10::Half, 3, 32),
-                  MAKE_PTA_WITH_NAME(func_name2, t_rows_after_bs, int, 1, 32),
-                  MAKE_PTA_WITH_NAME(func_name2, t_cols_after_bs, int, 1, 32),
-                  nnz,
-                  E,
-                  [f] __device__(__half x, __half y0, __half y1) -> __half {
-                    return f(x, y0, y1);
-                  });
-          C10_CUDA_KERNEL_LAUNCH_CHECK();
+          const auto ff = [f] __device__(
+                              __half x, __half y0, __half y1) -> __half {
+            return f(x, y0, y1);
+          };
+
+          FBGEMM_LAUNCH_KERNEL(
+              (jagged_dense_dense_elementwise_jagged_output_opt_gather_kernel_<
+                  index_t,
+                  decltype(ff)>),
+              blocks,
+              threads,
+              0,
+              at::cuda::getCurrentCUDAStream(),
+              PTA_B(output_values, c10::Half, 2, 32),
+              PTA_B(x_values, c10::Half, 2, 32),
+              PTA_B(y_0_reshaped, c10::Half, 3, 32),
+              PTA_B(y_1_reshaped, c10::Half, 3, 32),
+              PTA_B(t_rows_after_bs, int, 1, 32),
+              PTA_B(t_cols_after_bs, int, 1, 32),
+              nnz,
+              E,
+              ff);
         }); // AT_DISPATCH
       } else {
        JAGGED_TENSOR_DISPATCH_DIMS();
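The pattern in this hunk repeats across the whole PR: each raw triple-chevron launch followed by an explicit C10_CUDA_KERNEL_LAUNCH_CHECK() becomes a single FBGEMM_LAUNCH_KERNEL call, and each MAKE_PTA_WITH_NAME(func_name, ...) accessor becomes PTA_B(...), which lets the hand-maintained FBGEMM_GPU_MEMCHECK name strings be deleted. As a rough mental model only (a minimal sketch with a hypothetical SKETCH_LAUNCH_KERNEL name, not FBGEMM's actual definition, which also validates arguments and records kernel context), the launcher bundles the launch and the error check:

#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical stand-in for illustration; takes the same argument order as
// the calls above: kernel, grid, block, dynamic smem bytes, stream, args...
#define SKETCH_LAUNCH_KERNEL(kernel, grid, block, smem, stream, ...)          \
  do {                                                                        \
    kernel<<<(grid), (block), (smem), (stream)>>>(__VA_ARGS__);               \
    const cudaError_t err = cudaGetLastError();                               \
    if (err != cudaSuccess) {                                                 \
      std::fprintf(stderr, "launch failed: %s\n", cudaGetErrorString(err));   \
    }                                                                         \
  } while (0)

__global__ void scale_kernel(float* x, float a, int n) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    x[i] *= a;
  }
}

void scale(float* x, float a, int n, cudaStream_t stream) {
  const dim3 block(256);
  const dim3 grid((n + 255) / 256);
  SKETCH_LAUNCH_KERNEL(scale_kernel, grid, block, 0, stream, x, a, n);
}

One wrinkle specific to the gather kernel: the macro needs the kernel's complete type as an argument, including the functor type, so the previously inline device lambda is hoisted into the named variable ff and the kernel is spelled ..._gather_kernel_<index_t, decltype(ff)>; an unnamed lambda's type cannot be written in a template argument list.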
fbgemm_gpu/src/jagged_tensor_ops/jagged_jagged_bmm_forward.cu (36 additions, 54 deletions)
@@ -209,33 +209,24 @@ Tensor jagged_jagged_bmm_forward_cuda(
         offsets.scalar_type(), "jagged_jagged_bmm_kernel_1", [&] {
           FBGEMM_DISPATCH_FLOATING_TYPES(
               x_values.scalar_type(), "jagged_jagged_bmm_kernel_2", [&] {
-
-#ifdef FBGEMM_GPU_MEMCHECK
-                const auto func_name1 = "jagged_jagged_bmm_kernel.1";
-#endif
-
-                jagged_jagged_bmm_kernel<
-                    BLOCK_TILE_M,
-                    BLOCK_TILE_N,
-                    BLOCK_TILE_K,
-                    THREAD_TILE_M,
-                    THREAD_TILE_N,
-                    index_t,
-                    scalar_t>
-                    <<<grid,
-                       THREADS_PER_BLOCK,
-                       0,
-                       at::cuda::getCurrentCUDAStream()>>>(
-                        MAKE_PTA_WITH_NAME(
-                            func_name1, x_values, scalar_t, 2, 32),
-                        MAKE_PTA_WITH_NAME(
-                            func_name1, y_values, scalar_t, 2, 32),
-                        MAKE_PTA_WITH_NAME(
-                            func_name1, offsets, index_t, 1, 32),
-                        MAKE_PTA_WITH_NAME(
-                            func_name1, output, scalar_t, 3, 32),
-                        (int)max_L);
-                C10_CUDA_KERNEL_LAUNCH_CHECK();
+                FBGEMM_LAUNCH_KERNEL(
+                    (jagged_jagged_bmm_kernel<
+                        BLOCK_TILE_M,
+                        BLOCK_TILE_N,
+                        BLOCK_TILE_K,
+                        THREAD_TILE_M,
+                        THREAD_TILE_N,
+                        index_t,
+                        scalar_t>),
+                    grid,
+                    THREADS_PER_BLOCK,
+                    0,
+                    at::cuda::getCurrentCUDAStream(),
+                    PTA_B(x_values, scalar_t, 2, 32),
+                    PTA_B(y_values, scalar_t, 2, 32),
+                    PTA_B(offsets, index_t, 1, 32),
+                    PTA_B(output, scalar_t, 3, 32),
+                    (int)max_L);
               });
         });
   } else {
@@ -265,33 +256,24 @@ Tensor jagged_jagged_bmm_forward_cuda(
         offsets.scalar_type(), "jagged_jagged_bmm_kernel_1", [&] {
           FBGEMM_DISPATCH_FLOATING_TYPES(
               x_values.scalar_type(), "jagged_jagged_bmm_kernel_2", [&] {
-
-#ifdef FBGEMM_GPU_MEMCHECK
-                const auto func_name2 = "jagged_jagged_bmm_kernel.2";
-#endif
-
-                jagged_jagged_bmm_kernel<
-                    BLOCK_TILE_M,
-                    BLOCK_TILE_N,
-                    BLOCK_TILE_K,
-                    THREAD_TILE_M,
-                    THREAD_TILE_N,
-                    index_t,
-                    scalar_t>
-                    <<<grid,
-                       THREADS_PER_BLOCK,
-                       0,
-                       at::cuda::getCurrentCUDAStream()>>>(
-                        MAKE_PTA_WITH_NAME(
-                            func_name2, x_values, scalar_t, 2, 32),
-                        MAKE_PTA_WITH_NAME(
-                            func_name2, y_values, scalar_t, 2, 32),
-                        MAKE_PTA_WITH_NAME(
-                            func_name2, offsets, index_t, 1, 32),
-                        MAKE_PTA_WITH_NAME(
-                            func_name2, output, scalar_t, 3, 32),
-                        (int)max_L);
-                C10_CUDA_KERNEL_LAUNCH_CHECK();
+                FBGEMM_LAUNCH_KERNEL(
+                    (jagged_jagged_bmm_kernel<
+                        BLOCK_TILE_M,
+                        BLOCK_TILE_N,
+                        BLOCK_TILE_K,
+                        THREAD_TILE_M,
+                        THREAD_TILE_N,
+                        index_t,
+                        scalar_t>),
+                    grid,
+                    THREADS_PER_BLOCK,
+                    0,
+                    at::cuda::getCurrentCUDAStream(),
+                    PTA_B(x_values, scalar_t, 2, 32),
+                    PTA_B(y_values, scalar_t, 2, 32),
+                    PTA_B(offsets, index_t, 1, 32),
+                    PTA_B(output, scalar_t, 3, 32),
+                    (int)max_L);
               });
         });
   }
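Both type-dispatch branches of jagged_jagged_bmm_forward_cuda receive the same mechanical rewrite. Note the parentheses around the template-id handed to the macro: the preprocessor splits macro arguments on top-level commas, so jagged_jagged_bmm_kernel<BLOCK_TILE_M, ...> would otherwise be torn apart at each comma in its template argument list. A standalone illustration of the pitfall, using a toy CALL macro (not an FBGEMM macro):

#include <cstdio>

template <typename A, typename B>
void f(int v) {
  std::printf("%d\n", v);
}

// Toy macro that forwards one callable and one argument.
#define CALL(fn, arg) fn(arg)

int main() {
  // CALL(f<int, float>, 1);  // fails to preprocess as intended: the comma
  //                          // inside <...> yields three macro arguments.
  CALL((f<int, float>), 1);   // the parentheses keep the template-id whole.
  return 0;
}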
fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu (32 additions, 55 deletions)
@@ -197,10 +197,6 @@ class KeyedJaggedIndexSelectDim1GPUOp
       block_sums = at::empty({grid_size}, output_offsets.options());
     }
-
-#ifdef FBGEMM_GPU_MEMCHECK
-    const auto func_name = "index_select_scalar_cumsum_wrapper";
-#endif
-
     // Do index select and cumsum
     AT_DISPATCH_INDEX_TYPES(
         lengths.scalar_type(), "index_select_scalar_cumsum_wrapper_1", [&] {
@@ -214,34 +210,28 @@
                   indices.scalar_type(),
                   "index_select_scalar_cumsum_wrapper_3",
                   [&] {
-                    index_select_scalar_cumsum_kernel<
-                        length_t,
-                        index_t,
-                        offset_t,
-                        MAX_CUMSUM_ENTRIES_PER_BLOCK,
-                        MAX_CUMSUM_ENTRIES_PER_BLOCK>
-                        <<<grid_size,
-                           MAX_CUMSUM_ENTRIES_PER_BLOCK,
-                           0,
-                           at::cuda::getCurrentCUDAStream()>>>(
-                            MAKE_PTA_WITH_NAME(
-                                func_name, output_lengths, length_t, 1, 32),
-                            MAKE_PTA_WITH_NAME(
-                                func_name, output_offsets, offset_t, 1, 32),
-                            MAKE_PTA_WITH_NAME(
-                                func_name, lengths, length_t, 1, 32),
-                            MAKE_PTA_WITH_NAME(
-                                func_name, indices, index_t, 1, 32),
-                            num_batches,
-                            batch_size,
-                            num_output_lengths -
-                                MAX_CUMSUM_ENTRIES_PER_BLOCK *
-                                    (grid_size - 1),
-                            grid_size > 1 ? block_flags.data_ptr<int>()
-                                          : nullptr,
-                            grid_size > 1 ? block_sums.data_ptr<offset_t>()
-                                          : nullptr);
-                    C10_CUDA_KERNEL_LAUNCH_CHECK();
+                    FBGEMM_LAUNCH_KERNEL(
+                        (index_select_scalar_cumsum_kernel<
+                            length_t,
+                            index_t,
+                            offset_t,
+                            MAX_CUMSUM_ENTRIES_PER_BLOCK,
+                            MAX_CUMSUM_ENTRIES_PER_BLOCK>),
+                        grid_size,
+                        MAX_CUMSUM_ENTRIES_PER_BLOCK,
+                        0,
+                        at::cuda::getCurrentCUDAStream(),
+                        PTA_B(output_lengths, length_t, 1, 32),
+                        PTA_B(output_offsets, offset_t, 1, 32),
+                        PTA_B(lengths, length_t, 1, 32),
+                        PTA_B(indices, index_t, 1, 32),
+                        num_batches,
+                        batch_size,
+                        num_output_lengths -
+                            MAX_CUMSUM_ENTRIES_PER_BLOCK * (grid_size - 1),
+                        grid_size > 1 ? block_flags.data_ptr<int>() : nullptr,
+                        grid_size > 1 ? block_sums.data_ptr<offset_t>()
+                                      : nullptr);
                   });
             });
       });
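The block_flags and block_sums arguments, passed only when grid_size > 1, indicate that the cumsum may span several blocks, with per-block partial sums reconciled across blocks; the kernel body itself is outside this diff. For orientation, the single-block building block of such a cumsum is a plain shared-memory inclusive scan. A minimal sketch (my own illustration, not index_select_scalar_cumsum_kernel; assumes n <= blockDim.x, a power-of-two block size, and blockDim.x * sizeof(int) bytes of dynamic shared memory):

#include <cuda_runtime.h>

// Hillis-Steele inclusive scan over a single block.
__global__ void inclusive_scan_block(const int* in, int* out, int n) {
  extern __shared__ int buf[];
  const int i = threadIdx.x;
  buf[i] = (i < n) ? in[i] : 0;
  __syncthreads();
  for (int offset = 1; offset < blockDim.x; offset <<= 1) {
    // Read the partner value before any thread overwrites it, then add.
    const int v = (i >= offset) ? buf[i - offset] : 0;
    __syncthreads();
    buf[i] += v;
    __syncthreads();
  }
  if (i < n) {
    out[i] = buf[i];
  }
}

// Example launch for up to 256 entries:
//   inclusive_scan_block<<<1, 256, 256 * sizeof(int), stream>>>(in, out, n);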
@@ -285,9 +275,6 @@
             batch_size); \
   }
-
-#ifdef FBGEMM_GPU_MEMCHECK
-  const auto func_name = "keyed_jagged_index_select_dim1";
-#endif
   AT_DISPATCH_ALL_TYPES_AND2(
       at::ScalarType::Half,
       at::ScalarType::BFloat16,
@@ -426,10 +413,6 @@
     // binary_search_range which takes raw pointers as arguments
     const auto grad_offsets_contig = grad_offsets.expect_contiguous();
-
-#ifdef FBGEMM_GPU_MEMCHECK
-    const auto func_name = "keyed_jagged_index_add_dim1";
-#endif
-
     if (grid_size != 0) {
       AT_DISPATCH_ALL_TYPES_AND2(
           at::ScalarType::Half,
@@ -446,28 +429,22 @@
                   indices.scalar_type(),
                   "keyed_jagged_index_add_dim1_wrapper_3",
                   [&] {
-                    keyed_jagged_index_add_dim1_kernel<<<
+                    FBGEMM_LAUNCH_KERNEL(
+                        (keyed_jagged_index_add_dim1_kernel<
+                            scalar_t,
+                            index_t,
+                            offset_t>),
                         grid_size,
                         kMaxThreads,
                         0,
-                        at::cuda::getCurrentCUDAStream()>>>(
-                        MAKE_PTA_WITH_NAME(
-                            func_name, grad_input, scalar_t, 1, 64),
-                        MAKE_PTA_WITH_NAME(
-                            func_name, grad, scalar_t, 1, 64),
-                        MAKE_PTA_WITH_NAME(
-                            func_name,
-                            *grad_offsets_contig,
-                            offset_t,
-                            1,
-                            32),
-                        MAKE_PTA_WITH_NAME(
-                            func_name, indices, index_t, 1, 32),
-                        MAKE_PTA_WITH_NAME(
-                            func_name, output_offsets, offset_t, 1, 32),
+                        at::cuda::getCurrentCUDAStream(),
+                        PTA_B(grad_input, scalar_t, 1, 64),
+                        PTA_B(grad, scalar_t, 1, 64),
+                        PTA_B(*grad_offsets_contig, offset_t, 1, 32),
+                        PTA_B(indices, index_t, 1, 32),
+                        PTA_B(output_offsets, offset_t, 1, 32),
                         num_batches,
                         output_batch_size);
-                    C10_CUDA_KERNEL_LAUNCH_CHECK();
                   });
             });
       });
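A second effect visible in this file: two hunks consist of nothing but deleted #ifdef FBGEMM_GPU_MEMCHECK blocks. MAKE_PTA_WITH_NAME needed a func_name string maintained by hand at every call site so that memcheck diagnostics could name the offending kernel; FBGEMM_LAUNCH_KERNEL already knows which kernel it is launching, so PTA_B drops the name argument. Both forms ultimately produce a packed tensor accessor that is passed to the kernel by value; a plain ATen analogue for one of the arguments above (a sketch of the concept, not the macros' real expansion, which also selects 32- vs 64-bit indexing from the last parameter) would be:

#include <ATen/ATen.h>

// Roughly what a PTA_B(grad, scalar_t, 1, 64)-style argument boils down to:
// a packed accessor over a 1-D tensor with 64-bit indexing and restrict
// pointer traits, suitable for passing by value into a __global__ kernel.
void example(const at::Tensor& grad) {
  auto acc = grad.packed_accessor64<float, 1, at::RestrictPtrTraits>();
  (void)acc; // would be forwarded as a kernel argument
}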