
Commit 387e8a0

q10 authored and facebook-github-bot committed

Migrate jagged tensor kernels to FBGEMM_LAUNCH_KERNEL, pt 3

Summary:
X-link: facebookresearch/FBGEMM#1482

- Migrate jagged tensor kernels to `FBGEMM_LAUNCH_KERNEL`, pt 3

Reviewed By: r-barnes

Differential Revision: D75104796

fbshipit-source-id: 510b73cf6c57387ff12df93157f359550e807888

1 parent d522d92 commit 387e8a0
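For context, the migration pattern in the diffs below replaces each raw triple-chevron kernel launch followed by C10_CUDA_KERNEL_LAUNCH_CHECK() with a single FBGEMM_LAUNCH_KERNEL(...) call that takes the kernel, grid and block dimensions, dynamic shared-memory size, stream, and kernel arguments. The sketch that follows only illustrates that launch-plus-check shape; LAUNCH_SKETCH, scale_kernel, and scale are made-up names for illustration, and the real FBGEMM_LAUNCH_KERNEL also takes over the memcheck naming that the deleted func_name strings used to provide.

#include <cuda_runtime.h>
#include <cstdio>

// Illustrative stand-in only; not the fbgemm_gpu macro.
#define LAUNCH_SKETCH(kernel, grid, block, smem, stream, ...)           \
  do {                                                                  \
    kernel<<<(grid), (block), (smem), (stream)>>>(__VA_ARGS__);         \
    const cudaError_t err_ = cudaGetLastError();                        \
    if (err_ != cudaSuccess) {                                          \
      std::fprintf(                                                     \
          stderr, "launch failed: %s\n", cudaGetErrorString(err_));     \
    }                                                                   \
  } while (0)

__global__ void scale_kernel(float* x, float a, int n) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    x[i] *= a;
  }
}

void scale(float* d_x, float a, int n, cudaStream_t stream) {
  const dim3 threads(256, 1, 1);
  const dim3 blocks((n + threads.x - 1) / threads.x, 1, 1);
  // Before: scale_kernel<<<blocks, threads, 0, stream>>>(d_x, a, n);
  //         followed by a separate launch-error check.
  // After: one call site that launches and checks together.
  LAUNCH_SKETCH(scale_kernel, blocks, threads, 0, stream, d_x, a, n);
}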

File tree

3 files changed: +101, −147 lines


fbgemm_gpu/src/jagged_tensor_ops/jagged_dense_dense_elementwise_add_jagged_output_forward.cu

Lines changed: 33 additions & 38 deletions

@@ -134,52 +134,47 @@ void jagged_dense_dense_elementwise_jagged_output_opt_(
 
       const auto threads_bs = dim3(1024, 1, 1);
       const auto blocks_bs = dim3(div_round_up(nnz, threads_bs.x), 1, 1);
+      FBGEMM_LAUNCH_KERNEL(
+          (jagged_dense_dense_elementwise_jagged_output_opt_search_kernel_<
+              index_t>),
+          blocks_bs,
+          threads_bs,
+          dynamic_smem_size,
+          at::cuda::getCurrentCUDAStream(),
+          PTA_B((x_offsets[0]), index_t, 1, 32),
+          PTA_B(t_rows_after_bs, int, 1, 32),
+          PTA_B(t_cols_after_bs, int, 1, 32),
+          nnz,
+          B);
 
-#ifdef FBGEMM_GPU_MEMCHECK
-      const auto func_name1 =
-          "jagged_dense_dense_elementwise_jagged_output_opt_search_kernel_";
-#endif
-      jagged_dense_dense_elementwise_jagged_output_opt_search_kernel_<
-          index_t>
-          <<<blocks_bs,
-             threads_bs,
-             dynamic_smem_size,
-             at::cuda::getCurrentCUDAStream()>>>(
-              MAKE_PTA_WITH_NAME(func_name1, x_offsets[0], index_t, 1, 32),
-              MAKE_PTA_WITH_NAME(func_name1, t_rows_after_bs, int, 1, 32),
-              MAKE_PTA_WITH_NAME(func_name1, t_cols_after_bs, int, 1, 32),
-              nnz,
-              B);
-      C10_CUDA_KERNEL_LAUNCH_CHECK();
       // Gather kernel
       dim3 threads = dim3(16, 16, 1);
       dim3 blocks = dim3(1, div_round_up(nnz, threads.y), 1);
       if (blocks.y > 65535) {
         blocks.y = 65535;
       }
+      const auto ff = [f] __device__(
+                          __half x, __half y0, __half y1) -> __half {
+        return f(x, y0, y1);
+      };
 
-#ifdef FBGEMM_GPU_MEMCHECK
-      const auto func_name2 =
-          "jagged_dense_dense_elementwise_jagged_output_opt_gather_kernel_";
-#endif
-      jagged_dense_dense_elementwise_jagged_output_opt_gather_kernel_<
-          index_t>
-          <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
-              MAKE_PTA_WITH_NAME(
-                  func_name2, output_values, c10::Half, 2, 32),
-              MAKE_PTA_WITH_NAME(func_name2, x_values, c10::Half, 2, 32),
-              MAKE_PTA_WITH_NAME(
-                  func_name2, y_0_reshaped, c10::Half, 3, 32),
-              MAKE_PTA_WITH_NAME(
-                  func_name2, y_1_reshaped, c10::Half, 3, 32),
-              MAKE_PTA_WITH_NAME(func_name2, t_rows_after_bs, int, 1, 32),
-              MAKE_PTA_WITH_NAME(func_name2, t_cols_after_bs, int, 1, 32),
-              nnz,
-              E,
-              [f] __device__(__half x, __half y0, __half y1) -> __half {
-                return f(x, y0, y1);
-              });
-      C10_CUDA_KERNEL_LAUNCH_CHECK();
+      FBGEMM_LAUNCH_KERNEL(
+          (jagged_dense_dense_elementwise_jagged_output_opt_gather_kernel_<
+              index_t,
+              decltype(ff)>),
+          blocks,
+          threads,
+          0,
+          at::cuda::getCurrentCUDAStream(),
+          PTA_B(output_values, c10::Half, 2, 32),
+          PTA_B(x_values, c10::Half, 2, 32),
+          PTA_B(y_0_reshaped, c10::Half, 3, 32),
+          PTA_B(y_1_reshaped, c10::Half, 3, 32),
+          PTA_B(t_rows_after_bs, int, 1, 32),
+          PTA_B(t_cols_after_bs, int, 1, 32),
+          nnz,
+          E,
+          ff);
     }); // AT_DISPATCH
   } else {
     JAGGED_TENSOR_DISPATCH_DIMS();
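One detail of the gather-kernel change above: the inline __device__ lambda that used to be passed directly at the launch site is now bound to a local name (ff) so its type can be forwarded to the kernel template as decltype(ff) through the macro. A minimal sketch of that pattern follows, with hypothetical names (apply_kernel, apply_scale) and assuming nvcc's --extended-lambda support; it is not taken from the FBGEMM sources.

#include <cuda_runtime.h>

// Kernel templated on the functor type, mirroring the
// <index_t, decltype(ff)> template argument list in the launch above.
template <typename F>
__global__ void apply_kernel(float* x, int n, F f) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    x[i] = f(x[i]);
  }
}

void apply_scale(float* d_x, int n, float scale, cudaStream_t stream) {
  // Name the extended __device__ lambda first so its type can be spelled out.
  const auto f = [scale] __device__(float v) -> float { return v * scale; };
  const dim3 threads(256);
  const dim3 blocks((n + threads.x - 1) / threads.x);
  // Passing decltype(f) explicitly is what lets a launch macro refer to the
  // fully specialized kernel, as done with decltype(ff) in the diff.
  apply_kernel<decltype(f)><<<blocks, threads, 0, stream>>>(d_x, n, f);
}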

fbgemm_gpu/src/jagged_tensor_ops/jagged_jagged_bmm_forward.cu

Lines changed: 36 additions & 54 deletions

@@ -209,33 +209,24 @@ Tensor jagged_jagged_bmm_forward_cuda(
       offsets.scalar_type(), "jagged_jagged_bmm_kernel_1", [&] {
         FBGEMM_DISPATCH_FLOATING_TYPES(
             x_values.scalar_type(), "jagged_jagged_bmm_kernel_2", [&] {
-
-#ifdef FBGEMM_GPU_MEMCHECK
-              const auto func_name1 = "jagged_jagged_bmm_kernel.1";
-#endif
-
-              jagged_jagged_bmm_kernel<
-                  BLOCK_TILE_M,
-                  BLOCK_TILE_N,
-                  BLOCK_TILE_K,
-                  THREAD_TILE_M,
-                  THREAD_TILE_N,
-                  index_t,
-                  scalar_t>
-                  <<<grid,
-                     THREADS_PER_BLOCK,
-                     0,
-                     at::cuda::getCurrentCUDAStream()>>>(
-                      MAKE_PTA_WITH_NAME(
-                          func_name1, x_values, scalar_t, 2, 32),
-                      MAKE_PTA_WITH_NAME(
-                          func_name1, y_values, scalar_t, 2, 32),
-                      MAKE_PTA_WITH_NAME(
-                          func_name1, offsets, index_t, 1, 32),
-                      MAKE_PTA_WITH_NAME(
-                          func_name1, output, scalar_t, 3, 32),
-                      (int)max_L);
-              C10_CUDA_KERNEL_LAUNCH_CHECK();
+              FBGEMM_LAUNCH_KERNEL(
+                  (jagged_jagged_bmm_kernel<
+                      BLOCK_TILE_M,
+                      BLOCK_TILE_N,
+                      BLOCK_TILE_K,
+                      THREAD_TILE_M,
+                      THREAD_TILE_N,
+                      index_t,
+                      scalar_t>),
+                  grid,
+                  THREADS_PER_BLOCK,
+                  0,
+                  at::cuda::getCurrentCUDAStream(),
+                  PTA_B(x_values, scalar_t, 2, 32),
+                  PTA_B(y_values, scalar_t, 2, 32),
+                  PTA_B(offsets, index_t, 1, 32),
+                  PTA_B(output, scalar_t, 3, 32),
+                  (int)max_L);
             });
       });
   } else {

@@ -265,33 +256,24 @@ Tensor jagged_jagged_bmm_forward_cuda(
       offsets.scalar_type(), "jagged_jagged_bmm_kernel_1", [&] {
         FBGEMM_DISPATCH_FLOATING_TYPES(
             x_values.scalar_type(), "jagged_jagged_bmm_kernel_2", [&] {
-
-#ifdef FBGEMM_GPU_MEMCHECK
-              const auto func_name2 = "jagged_jagged_bmm_kernel.2";
-#endif
-
-              jagged_jagged_bmm_kernel<
-                  BLOCK_TILE_M,
-                  BLOCK_TILE_N,
-                  BLOCK_TILE_K,
-                  THREAD_TILE_M,
-                  THREAD_TILE_N,
-                  index_t,
-                  scalar_t>
-                  <<<grid,
-                     THREADS_PER_BLOCK,
-                     0,
-                     at::cuda::getCurrentCUDAStream()>>>(
-                      MAKE_PTA_WITH_NAME(
-                          func_name2, x_values, scalar_t, 2, 32),
-                      MAKE_PTA_WITH_NAME(
-                          func_name2, y_values, scalar_t, 2, 32),
-                      MAKE_PTA_WITH_NAME(
-                          func_name2, offsets, index_t, 1, 32),
-                      MAKE_PTA_WITH_NAME(
-                          func_name2, output, scalar_t, 3, 32),
-                      (int)max_L);
-              C10_CUDA_KERNEL_LAUNCH_CHECK();
+              FBGEMM_LAUNCH_KERNEL(
+                  (jagged_jagged_bmm_kernel<
+                      BLOCK_TILE_M,
+                      BLOCK_TILE_N,
+                      BLOCK_TILE_K,
+                      THREAD_TILE_M,
+                      THREAD_TILE_N,
+                      index_t,
+                      scalar_t>),
+                  grid,
+                  THREADS_PER_BLOCK,
+                  0,
+                  at::cuda::getCurrentCUDAStream(),
+                  PTA_B(x_values, scalar_t, 2, 32),
+                  PTA_B(y_values, scalar_t, 2, 32),
+                  PTA_B(offsets, index_t, 1, 32),
+                  PTA_B(output, scalar_t, 3, 32),
+                  (int)max_L);
             });
       });
   }
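Note how every templated kernel name handed to FBGEMM_LAUNCH_KERNEL above is wrapped in an extra set of parentheses. That is standard preprocessor hygiene: the commas inside the template argument list would otherwise be treated as macro argument separators. The stand-alone illustration below uses a stand-in LAUNCH macro and an ordinary host function, purely to show the comma problem; the real macro performs the kernel launch itself.

#include <cstdio>

// Stand-in macro for illustration only; not the real FBGEMM_LAUNCH_KERNEL.
#define LAUNCH(fn, ...) fn(__VA_ARGS__)

template <int BLOCK_M, int BLOCK_N>
void tiled_op(int x) {
  std::printf("BLOCK_M=%d BLOCK_N=%d x=%d\n", BLOCK_M, BLOCK_N, x);
}

int main() {
  // LAUNCH(tiled_op<8, 4>, 42);   // breaks: the preprocessor splits the
  //                               // template argument list at the comma.
  LAUNCH((tiled_op<8, 4>), 42);    // parentheses keep it a single argument
  return 0;
}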

fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu

Lines changed: 32 additions & 55 deletions

@@ -197,10 +197,6 @@ class KeyedJaggedIndexSelectDim1GPUOp
       block_sums = at::empty({grid_size}, output_offsets.options());
     }
 
-#ifdef FBGEMM_GPU_MEMCHECK
-    const auto func_name = "index_select_scalar_cumsum_wrapper";
-#endif
-
     // Do index select and cumsum
     AT_DISPATCH_INDEX_TYPES(
         lengths.scalar_type(), "index_select_scalar_cumsum_wrapper_1", [&] {

@@ -214,34 +210,28 @@ class KeyedJaggedIndexSelectDim1GPUOp
                   indices.scalar_type(),
                   "index_select_scalar_cumsum_wrapper_3",
                   [&] {
-                    index_select_scalar_cumsum_kernel<
-                        length_t,
-                        index_t,
-                        offset_t,
+                    FBGEMM_LAUNCH_KERNEL(
+                        (index_select_scalar_cumsum_kernel<
+                            length_t,
+                            index_t,
+                            offset_t,
+                            MAX_CUMSUM_ENTRIES_PER_BLOCK,
+                            MAX_CUMSUM_ENTRIES_PER_BLOCK>),
+                        grid_size,
                         MAX_CUMSUM_ENTRIES_PER_BLOCK,
-                        MAX_CUMSUM_ENTRIES_PER_BLOCK>
-                        <<<grid_size,
-                           MAX_CUMSUM_ENTRIES_PER_BLOCK,
-                           0,
-                           at::cuda::getCurrentCUDAStream()>>>(
-                            MAKE_PTA_WITH_NAME(
-                                func_name, output_lengths, length_t, 1, 32),
-                            MAKE_PTA_WITH_NAME(
-                                func_name, output_offsets, offset_t, 1, 32),
-                            MAKE_PTA_WITH_NAME(
-                                func_name, lengths, length_t, 1, 32),
-                            MAKE_PTA_WITH_NAME(
-                                func_name, indices, index_t, 1, 32),
-                            num_batches,
-                            batch_size,
-                            num_output_lengths -
-                                MAX_CUMSUM_ENTRIES_PER_BLOCK *
-                                    (grid_size - 1),
-                            grid_size > 1 ? block_flags.data_ptr<int>()
-                                          : nullptr,
-                            grid_size > 1 ? block_sums.data_ptr<offset_t>()
-                                          : nullptr);
-                    C10_CUDA_KERNEL_LAUNCH_CHECK();
+                        0,
+                        at::cuda::getCurrentCUDAStream(),
+                        PTA_B(output_lengths, length_t, 1, 32),
+                        PTA_B(output_offsets, offset_t, 1, 32),
+                        PTA_B(lengths, length_t, 1, 32),
+                        PTA_B(indices, index_t, 1, 32),
+                        num_batches,
+                        batch_size,
+                        num_output_lengths -
+                            MAX_CUMSUM_ENTRIES_PER_BLOCK * (grid_size - 1),
+                        grid_size > 1 ? block_flags.data_ptr<int>() : nullptr,
+                        grid_size > 1 ? block_sums.data_ptr<offset_t>()
+                                      : nullptr);
                   });
             });
       });

@@ -285,9 +275,6 @@ class KeyedJaggedIndexSelectDim1GPUOp
         batch_size); \
   }
 
-#ifdef FBGEMM_GPU_MEMCHECK
-  const auto func_name = "keyed_jagged_index_select_dim1";
-#endif
   AT_DISPATCH_ALL_TYPES_AND2(
       at::ScalarType::Half,
       at::ScalarType::BFloat16,

@@ -426,10 +413,6 @@ class KeyedJaggedIndexSelectDim1GPUOp
     // binary_search_range which takes raw pointers as arguments
     const auto grad_offsets_contig = grad_offsets.expect_contiguous();
 
-#ifdef FBGEMM_GPU_MEMCHECK
-    const auto func_name = "keyed_jagged_index_add_dim1";
-#endif
-
     if (grid_size != 0) {
       AT_DISPATCH_ALL_TYPES_AND2(
           at::ScalarType::Half,

@@ -446,28 +429,22 @@ class KeyedJaggedIndexSelectDim1GPUOp
                   indices.scalar_type(),
                   "keyed_jagged_index_add_dim1_wrapper_3",
                   [&] {
-                    keyed_jagged_index_add_dim1_kernel<<<
+                    FBGEMM_LAUNCH_KERNEL(
+                        (keyed_jagged_index_add_dim1_kernel<
+                            scalar_t,
+                            index_t,
+                            offset_t>),
                         grid_size,
                         kMaxThreads,
                         0,
-                        at::cuda::getCurrentCUDAStream()>>>(
-                            MAKE_PTA_WITH_NAME(
-                                func_name, grad_input, scalar_t, 1, 64),
-                            MAKE_PTA_WITH_NAME(
-                                func_name, grad, scalar_t, 1, 64),
-                            MAKE_PTA_WITH_NAME(
-                                func_name,
-                                *grad_offsets_contig,
-                                offset_t,
-                                1,
-                                32),
-                            MAKE_PTA_WITH_NAME(
-                                func_name, indices, index_t, 1, 32),
-                            MAKE_PTA_WITH_NAME(
-                                func_name, output_offsets, offset_t, 1, 32),
+                        at::cuda::getCurrentCUDAStream(),
+                        PTA_B(grad_input, scalar_t, 1, 64),
+                        PTA_B(grad, scalar_t, 1, 64),
+                        PTA_B(*grad_offsets_contig, offset_t, 1, 32),
+                        PTA_B(indices, index_t, 1, 32),
+                        PTA_B(output_offsets, offset_t, 1, 32),
                         num_batches,
                         output_batch_size);
-                    C10_CUDA_KERNEL_LAUNCH_CHECK();
                   });
             });
       });
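The PTA_B(tensor, type, rank, nbits) arguments used throughout these launches describe PyTorch packed tensor accessors of the given scalar type, rank, and index bit width (32 or 64), and the launch macro appears to take over the memcheck naming that the deleted func_name strings used to supply. The sketch below shows the underlying accessor plumbing at the plain ATen level, using the raw launch style for simplicity; copy_offsets_kernel and copy_offsets are hypothetical names, not fbgemm_gpu code.

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>

// Hypothetical kernel: copies a 1-D offsets tensor through 32-bit-indexed
// packed accessors, the accessor shape that PTA_B(tensor, index_t, 1, 32)
// describes in the diff above.
template <typename index_t>
__global__ void copy_offsets_kernel(
    at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> src,
    at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> dst) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < src.size(0)) {
    dst[i] = src[i];
  }
}

void copy_offsets(const at::Tensor& src, at::Tensor& dst) {
  AT_DISPATCH_INDEX_TYPES(src.scalar_type(), "copy_offsets", [&] {
    const dim3 threads(256);
    const dim3 blocks((src.numel() + threads.x - 1) / threads.x);
    // Raw launch + separate check, i.e. the "before" style this commit removes.
    copy_offsets_kernel<index_t>
        <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
            src.packed_accessor32<index_t, 1, at::RestrictPtrTraits>(),
            dst.packed_accessor32<index_t, 1, at::RestrictPtrTraits>());
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  });
}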
