diff --git a/fbgemm_gpu/src/jagged_tensor_ops/jagged_dense_dense_elementwise_add_jagged_output_forward.cu b/fbgemm_gpu/src/jagged_tensor_ops/jagged_dense_dense_elementwise_add_jagged_output_forward.cu
index 7cb2d47882..1e20ee4b40 100644
--- a/fbgemm_gpu/src/jagged_tensor_ops/jagged_dense_dense_elementwise_add_jagged_output_forward.cu
+++ b/fbgemm_gpu/src/jagged_tensor_ops/jagged_dense_dense_elementwise_add_jagged_output_forward.cu
@@ -134,52 +134,47 @@ void jagged_dense_dense_elementwise_jagged_output_opt_(
           const auto threads_bs = dim3(1024, 1, 1);
           const auto blocks_bs = dim3(div_round_up(nnz, threads_bs.x), 1, 1);
 
+          FBGEMM_LAUNCH_KERNEL(
+              (jagged_dense_dense_elementwise_jagged_output_opt_search_kernel_<
+                  index_t>),
+              blocks_bs,
+              threads_bs,
+              dynamic_smem_size,
+              at::cuda::getCurrentCUDAStream(),
+              PTA_B((x_offsets[0]), index_t, 1, 32),
+              PTA_B(t_rows_after_bs, int, 1, 32),
+              PTA_B(t_cols_after_bs, int, 1, 32),
+              nnz,
+              B);
-#ifdef FBGEMM_GPU_MEMCHECK
-          const auto func_name1 =
-              "jagged_dense_dense_elementwise_jagged_output_opt_search_kernel_";
-#endif
-          jagged_dense_dense_elementwise_jagged_output_opt_search_kernel_<
-              index_t>
-              <<<blocks_bs,
-                 threads_bs,
-                 dynamic_smem_size,
-                 at::cuda::getCurrentCUDAStream()>>>(
-                  MAKE_PTA_WITH_NAME(func_name1, x_offsets[0], index_t, 1, 32),
-                  MAKE_PTA_WITH_NAME(func_name1, t_rows_after_bs, int, 1, 32),
-                  MAKE_PTA_WITH_NAME(func_name1, t_cols_after_bs, int, 1, 32),
-                  nnz,
-                  B);
-          C10_CUDA_KERNEL_LAUNCH_CHECK();
 
           // Gather kernel
           dim3 threads = dim3(16, 16, 1);
           dim3 blocks = dim3(1, div_round_up(nnz, threads.y), 1);
           if (blocks.y > 65535) {
             blocks.y = 65535;
           }
 
+          const auto ff = [f] __device__(
+                              __half x, __half y0, __half y1) -> __half {
+            return f(x, y0, y1);
+          };
-#ifdef FBGEMM_GPU_MEMCHECK
-          const auto func_name2 =
-              "jagged_dense_dense_elementwise_jagged_output_opt_gather_kernel_";
-#endif
-          jagged_dense_dense_elementwise_jagged_output_opt_gather_kernel_<
-              index_t>
-              <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
-                  MAKE_PTA_WITH_NAME(
-                      func_name2, output_values, c10::Half, 2, 32),
-                  MAKE_PTA_WITH_NAME(func_name2, x_values, c10::Half, 2, 32),
-                  MAKE_PTA_WITH_NAME(
-                      func_name2, y_0_reshaped, c10::Half, 3, 32),
-                  MAKE_PTA_WITH_NAME(
-                      func_name2, y_1_reshaped, c10::Half, 3, 32),
-                  MAKE_PTA_WITH_NAME(func_name2, t_rows_after_bs, int, 1, 32),
-                  MAKE_PTA_WITH_NAME(func_name2, t_cols_after_bs, int, 1, 32),
-                  nnz,
-                  E,
-                  [f] __device__(__half x, __half y0, __half y1) -> __half {
-                    return f(x, y0, y1);
-                  });
-          C10_CUDA_KERNEL_LAUNCH_CHECK();
+          FBGEMM_LAUNCH_KERNEL(
+              (jagged_dense_dense_elementwise_jagged_output_opt_gather_kernel_<
+                  index_t,
+                  decltype(ff)>),
+              blocks,
+              threads,
+              0,
+              at::cuda::getCurrentCUDAStream(),
+              PTA_B(output_values, c10::Half, 2, 32),
+              PTA_B(x_values, c10::Half, 2, 32),
+              PTA_B(y_0_reshaped, c10::Half, 3, 32),
+              PTA_B(y_1_reshaped, c10::Half, 3, 32),
+              PTA_B(t_rows_after_bs, int, 1, 32),
+              PTA_B(t_cols_after_bs, int, 1, 32),
+              nnz,
+              E,
+              ff);
         }); // AT_DISPATCH
   } else {
     JAGGED_TENSOR_DISPATCH_DIMS();
diff --git a/fbgemm_gpu/src/jagged_tensor_ops/jagged_jagged_bmm_forward.cu b/fbgemm_gpu/src/jagged_tensor_ops/jagged_jagged_bmm_forward.cu
index 8fbd361a79..c8aa18260a 100644
--- a/fbgemm_gpu/src/jagged_tensor_ops/jagged_jagged_bmm_forward.cu
+++ b/fbgemm_gpu/src/jagged_tensor_ops/jagged_jagged_bmm_forward.cu
@@ -209,33 +209,24 @@ Tensor jagged_jagged_bmm_forward_cuda(
       offsets.scalar_type(), "jagged_jagged_bmm_kernel_1", [&] {
         FBGEMM_DISPATCH_FLOATING_TYPES(
             x_values.scalar_type(), "jagged_jagged_bmm_kernel_2", [&] {
-
-#ifdef FBGEMM_GPU_MEMCHECK
-              const auto func_name1 = "jagged_jagged_bmm_kernel.1";
-#endif
-
-              jagged_jagged_bmm_kernel<
-                  BLOCK_TILE_M,
-                  BLOCK_TILE_N,
-                  BLOCK_TILE_K,
-                  THREAD_TILE_M,
-                  THREAD_TILE_N,
-                  index_t,
-                  scalar_t>
-                  <<<grid,
-                     THREADS_PER_BLOCK,
-                     0,
-                     at::cuda::getCurrentCUDAStream()>>>(
-                      MAKE_PTA_WITH_NAME(
-                          func_name1, x_values, scalar_t, 2, 32),
-                      MAKE_PTA_WITH_NAME(
-                          func_name1, y_values, scalar_t, 2, 32),
-                      MAKE_PTA_WITH_NAME(
-                          func_name1, offsets, index_t, 1, 32),
-                      MAKE_PTA_WITH_NAME(
-                          func_name1, output, scalar_t, 3, 32),
-                      (int)max_L);
-              C10_CUDA_KERNEL_LAUNCH_CHECK();
+              FBGEMM_LAUNCH_KERNEL(
+                  (jagged_jagged_bmm_kernel<
+                      BLOCK_TILE_M,
+                      BLOCK_TILE_N,
+                      BLOCK_TILE_K,
+                      THREAD_TILE_M,
+                      THREAD_TILE_N,
+                      index_t,
+                      scalar_t>),
+                  grid,
+                  THREADS_PER_BLOCK,
+                  0,
+                  at::cuda::getCurrentCUDAStream(),
+                  PTA_B(x_values, scalar_t, 2, 32),
+                  PTA_B(y_values, scalar_t, 2, 32),
+                  PTA_B(offsets, index_t, 1, 32),
+                  PTA_B(output, scalar_t, 3, 32),
+                  (int)max_L);
             });
       });
   } else {
@@ -265,33 +256,24 @@ Tensor jagged_jagged_bmm_forward_cuda(
       offsets.scalar_type(), "jagged_jagged_bmm_kernel_1", [&] {
         FBGEMM_DISPATCH_FLOATING_TYPES(
            x_values.scalar_type(), "jagged_jagged_bmm_kernel_2", [&] {
-
-#ifdef FBGEMM_GPU_MEMCHECK
-              const auto func_name2 = "jagged_jagged_bmm_kernel.2";
-#endif
-
-              jagged_jagged_bmm_kernel<
-                  BLOCK_TILE_M,
-                  BLOCK_TILE_N,
-                  BLOCK_TILE_K,
-                  THREAD_TILE_M,
-                  THREAD_TILE_N,
-                  index_t,
-                  scalar_t>
-                  <<<grid,
-                     THREADS_PER_BLOCK,
-                     0,
-                     at::cuda::getCurrentCUDAStream()>>>(
-                      MAKE_PTA_WITH_NAME(
-                          func_name2, x_values, scalar_t, 2, 32),
-                      MAKE_PTA_WITH_NAME(
-                          func_name2, y_values, scalar_t, 2, 32),
-                      MAKE_PTA_WITH_NAME(
-                          func_name2, offsets, index_t, 1, 32),
-                      MAKE_PTA_WITH_NAME(
-                          func_name2, output, scalar_t, 3, 32),
-                      (int)max_L);
-              C10_CUDA_KERNEL_LAUNCH_CHECK();
+              FBGEMM_LAUNCH_KERNEL(
+                  (jagged_jagged_bmm_kernel<
+                      BLOCK_TILE_M,
+                      BLOCK_TILE_N,
+                      BLOCK_TILE_K,
+                      THREAD_TILE_M,
+                      THREAD_TILE_N,
+                      index_t,
+                      scalar_t>),
+                  grid,
+                  THREADS_PER_BLOCK,
+                  0,
+                  at::cuda::getCurrentCUDAStream(),
+                  PTA_B(x_values, scalar_t, 2, 32),
+                  PTA_B(y_values, scalar_t, 2, 32),
+                  PTA_B(offsets, index_t, 1, 32),
+                  PTA_B(output, scalar_t, 3, 32),
+                  (int)max_L);
             });
       });
   }
diff --git a/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu b/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu
index fd98488b78..bd71d531cb 100644
--- a/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu
+++ b/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu
@@ -197,10 +197,6 @@ class KeyedJaggedIndexSelectDim1GPUOp
       block_sums = at::empty({grid_size}, output_offsets.options());
     }
 
-#ifdef FBGEMM_GPU_MEMCHECK
-    const auto func_name = "index_select_scalar_cumsum_wrapper";
-#endif
-
     // Do index select and cumsum
     AT_DISPATCH_INDEX_TYPES(
         lengths.scalar_type(), "index_select_scalar_cumsum_wrapper_1", [&] {
@@ -214,34 +210,28 @@ class KeyedJaggedIndexSelectDim1GPUOp
                   indices.scalar_type(),
                   "index_select_scalar_cumsum_wrapper_3",
                   [&] {
-                    index_select_scalar_cumsum_kernel<
-                        length_t,
-                        index_t,
-                        offset_t,
+                    FBGEMM_LAUNCH_KERNEL(
+                        (index_select_scalar_cumsum_kernel<
+                            length_t,
+                            index_t,
+                            offset_t,
+                            MAX_CUMSUM_ENTRIES_PER_BLOCK,
+                            MAX_CUMSUM_ENTRIES_PER_BLOCK>),
+                        grid_size,
                         MAX_CUMSUM_ENTRIES_PER_BLOCK,
-                        MAX_CUMSUM_ENTRIES_PER_BLOCK>
-                        <<<grid_size,
-                           MAX_CUMSUM_ENTRIES_PER_BLOCK,
-                           0,
-                           at::cuda::getCurrentCUDAStream()>>>(
-                            MAKE_PTA_WITH_NAME(
-                                func_name, output_lengths, length_t, 1, 32),
-                            MAKE_PTA_WITH_NAME(
-                                func_name, output_offsets, offset_t, 1, 32),
-                            MAKE_PTA_WITH_NAME(
-                                func_name, lengths, length_t, 1, 32),
-                            MAKE_PTA_WITH_NAME(
-                                func_name, indices, index_t, 1, 32),
-                            num_batches,
-                            batch_size,
-                            num_output_lengths -
-                                MAX_CUMSUM_ENTRIES_PER_BLOCK *
-                                    (grid_size - 1),
-                            grid_size > 1 ? block_flags.data_ptr<int>()
-                                          : nullptr,
-                            grid_size > 1 ? block_sums.data_ptr<offset_t>()
-                                          : nullptr);
-                    C10_CUDA_KERNEL_LAUNCH_CHECK();
+                        0,
+                        at::cuda::getCurrentCUDAStream(),
+                        PTA_B(output_lengths, length_t, 1, 32),
+                        PTA_B(output_offsets, offset_t, 1, 32),
+                        PTA_B(lengths, length_t, 1, 32),
+                        PTA_B(indices, index_t, 1, 32),
+                        num_batches,
+                        batch_size,
+                        num_output_lengths -
+                            MAX_CUMSUM_ENTRIES_PER_BLOCK * (grid_size - 1),
+                        grid_size > 1 ? block_flags.data_ptr<int>() : nullptr,
+                        grid_size > 1 ? block_sums.data_ptr<offset_t>()
+                                      : nullptr);
                   });
             });
       });
@@ -285,9 +275,6 @@ class KeyedJaggedIndexSelectDim1GPUOp
       batch_size); \
   }
 
-#ifdef FBGEMM_GPU_MEMCHECK
-  const auto func_name = "keyed_jagged_index_select_dim1";
-#endif
   AT_DISPATCH_ALL_TYPES_AND2(
       at::ScalarType::Half,
       at::ScalarType::BFloat16,
@@ -426,10 +413,6 @@ class KeyedJaggedIndexSelectDim1GPUOp
    // binary_search_range which takes raw pointers as arguments
    const auto grad_offsets_contig = grad_offsets.expect_contiguous();
 
-#ifdef FBGEMM_GPU_MEMCHECK
-    const auto func_name = "keyed_jagged_index_add_dim1";
-#endif
-
    if (grid_size != 0) {
      AT_DISPATCH_ALL_TYPES_AND2(
          at::ScalarType::Half,
@@ -446,28 +429,22 @@ class KeyedJaggedIndexSelectDim1GPUOp
                     indices.scalar_type(),
                     "keyed_jagged_index_add_dim1_wrapper_3",
                     [&] {
-                      keyed_jagged_index_add_dim1_kernel<<<
+                      FBGEMM_LAUNCH_KERNEL(
+                          (keyed_jagged_index_add_dim1_kernel<
+                              scalar_t,
+                              index_t,
+                              offset_t>),
                           grid_size,
                           kMaxThreads,
                           0,
-                          at::cuda::getCurrentCUDAStream()>>>(
-                          MAKE_PTA_WITH_NAME(
-                              func_name, grad_input, scalar_t, 1, 64),
-                          MAKE_PTA_WITH_NAME(
-                              func_name, grad, scalar_t, 1, 64),
-                          MAKE_PTA_WITH_NAME(
-                              func_name,
-                              *grad_offsets_contig,
-                              offset_t,
-                              1,
-                              32),
-                          MAKE_PTA_WITH_NAME(
-                              func_name, indices, index_t, 1, 32),
-                          MAKE_PTA_WITH_NAME(
-                              func_name, output_offsets, offset_t, 1, 32),
+                          at::cuda::getCurrentCUDAStream(),
+                          PTA_B(grad_input, scalar_t, 1, 64),
+                          PTA_B(grad, scalar_t, 1, 64),
+                          PTA_B(*grad_offsets_contig, offset_t, 1, 32),
+                          PTA_B(indices, index_t, 1, 32),
+                          PTA_B(output_offsets, offset_t, 1, 32),
                           num_batches,
                           output_batch_size);
-                      C10_CUDA_KERNEL_LAUNCH_CHECK();
                     });
               });
         });
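Note: all three files follow the same migration pattern, shown below as a minimal sketch. The kernel my_kernel, the tensor out, and the launch configuration are illustrative placeholders, not code from this patch; FBGEMM_LAUNCH_KERNEL and PTA_B are the macros the patch introduces, and MAKE_PTA_WITH_NAME / C10_CUDA_KERNEL_LAUNCH_CHECK are the calls it removes.

    // Before: raw triple-chevron launch. The kernel name is duplicated by
    // hand for FBGEMM_GPU_MEMCHECK, and the launch check is a separate,
    // easy-to-forget statement.
    #ifdef FBGEMM_GPU_MEMCHECK
      const auto func_name = "my_kernel";
    #endif
      my_kernel<index_t>
          <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
              MAKE_PTA_WITH_NAME(func_name, out, index_t, 1, 32), n);
      C10_CUDA_KERNEL_LAUNCH_CHECK();

    // After: one macro owns the launch. The kernel is parenthesized so that
    // commas inside its template argument list survive macro expansion, and
    // PTA_B takes no name argument -- judging by the deleted func_name
    // definitions, the launcher presumably supplies the kernel name for
    // memcheck accessors itself.
      FBGEMM_LAUNCH_KERNEL(
          (my_kernel<index_t>),
          grid,
          block,
          0, // dynamic shared memory bytes
          at::cuda::getCurrentCUDAStream(),
          PTA_B(out, index_t, 1, 32),
          n);

Folding the launch check into the macro means every call site gets the error check for free, which is the design choice that lets the patch delete each trailing C10_CUDA_KERNEL_LAUNCH_CHECK().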