Skip to content

Commit 97c191f

Browse files
q10 authored and facebook-github-bot committed
Migrate jagged tensor kernels to FBGEMM_LAUNCH_KERNEL, pt 2 (pytorch#4350)
Summary: Pull Request resolved: pytorch#4350 X-link: facebookresearch/FBGEMM#1417 - Migrate jagged tensor kernels to `FBGEMM_LAUNCH_KERNEL`, pt 2 Reviewed By: r-barnes Differential Revision: D74974179 fbshipit-source-id: bc9595118bf8de92cadc79e52b83439ea855516b
1 parent 5b855c1 commit 97c191f

File tree

5 files changed

+52
-75
lines changed

5 files changed

+52
-75
lines changed

fbgemm_gpu/src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_forward.cu

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -80,20 +80,16 @@ Tensor batched_dense_vec_jagged_2d_mul_forward(
8080
a_values.scalar_type(),
8181
"dense_vec_jagged_2d_bmm_kernel_2",
8282
[&] {
83-
84-
#ifdef FBGEMM_GPU_MEMCHECK
85-
const auto func_name1 = "dense_vec_jagged_2d_bmm";
86-
#endif
87-
dense_vec_jagged_2d_bmm<index_t, scalar_t>
88-
<<<div_round_up(B * H, block_dim_y),
89-
dim3(block_dim_x, block_dim_y),
90-
0,
91-
at::cuda::getCurrentCUDAStream()>>>(
92-
MAKE_PTA_WITH_NAME(func_name1, v, scalar_t, 2, 32),
93-
MAKE_PTA_WITH_NAME(func_name1, a_values, scalar_t, 2, 32),
94-
MAKE_PTA_WITH_NAME(func_name1, a_offsets, index_t, 1, 32),
95-
MAKE_PTA_WITH_NAME(func_name1, output, scalar_t, 2, 32));
96-
C10_CUDA_KERNEL_LAUNCH_CHECK();
83+
FBGEMM_LAUNCH_KERNEL(
84+
(dense_vec_jagged_2d_bmm<index_t, scalar_t>),
85+
div_round_up(B * H, block_dim_y),
86+
dim3(block_dim_x, block_dim_y),
87+
0,
88+
at::cuda::getCurrentCUDAStream(),
89+
PTA_B(v, scalar_t, 2, 32),
90+
PTA_B(a_values, scalar_t, 2, 32),
91+
PTA_B(a_offsets, index_t, 1, 32),
92+
PTA_B(output, scalar_t, 2, 32));
9793
});
9894
});
9995
}

fbgemm_gpu/src/jagged_tensor_ops/jagged_index_add_2d_forward.cu

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -100,23 +100,18 @@ Tensor jagged_index_add_2d_forward_cuda(
100100
indices.scalar_type(),
101101
"jagged_index_add_2d_kernel_wrapper_2",
102102
[&] {
103-
#ifdef FBGEMM_GPU_MEMCHECK
104-
const auto func_name = "jagged_index_add_2d_kernel";
105-
#endif
106-
jagged_index_add_2d_kernel<<<
103+
FBGEMM_LAUNCH_KERNEL(
104+
(jagged_index_add_2d_kernel<index_t, int64_t, scalar_t>),
107105
dim3(num_blocks),
108106
dim3(num_cols),
109107
0,
110-
at::cuda::getCurrentCUDAStream()>>>(
111-
MAKE_PTA_WITH_NAME(func_name, output, scalar_t, 2, 64),
112-
MAKE_PTA_WITH_NAME(func_name, values, scalar_t, 2, 64),
113-
MAKE_PTA_WITH_NAME(
114-
func_name, (*input_offsets_contig), int64_t, 1, 32),
115-
MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32),
116-
MAKE_PTA_WITH_NAME(
117-
func_name, output_offsets, int64_t, 1, 32),
108+
at::cuda::getCurrentCUDAStream(),
109+
PTA_B(output, scalar_t, 2, 64),
110+
PTA_B(values, scalar_t, 2, 64),
111+
PTA_B((*input_offsets_contig), int64_t, 1, 32),
112+
PTA_B(indices, index_t, 1, 32),
113+
PTA_B(output_offsets, int64_t, 1, 32),
118114
num_dense_input_rows);
119-
C10_CUDA_KERNEL_LAUNCH_CHECK();
120115
});
121116
});
122117
}

fbgemm_gpu/src/jagged_tensor_ops/jagged_index_select_2d_forward.cu

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -96,23 +96,18 @@ Tensor jagged_index_select_2d_forward_cuda(
9696
indices.scalar_type(),
9797
"jagged_index_select_2d_kernel_wrapper_2",
9898
[&] {
99-
#ifdef FBGEMM_GPU_MEMCHECK
100-
const auto func_name = "jagged_index_select_2d_kernel";
101-
#endif
102-
jagged_index_select_2d_kernel<<<
99+
FBGEMM_LAUNCH_KERNEL(
100+
(jagged_index_select_2d_kernel<index_t, int64_t, scalar_t>),
103101
dim3(num_blocks),
104102
dim3(num_cols),
105103
0,
106-
at::cuda::getCurrentCUDAStream()>>>(
107-
MAKE_PTA_WITH_NAME(func_name, output, scalar_t, 2, 64),
108-
MAKE_PTA_WITH_NAME(func_name, values, scalar_t, 2, 64),
109-
MAKE_PTA_WITH_NAME(
110-
func_name, input_offsets, int64_t, 1, 32),
111-
MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32),
112-
MAKE_PTA_WITH_NAME(
113-
func_name, (*output_offsets_contig), int64_t, 1, 32),
104+
at::cuda::getCurrentCUDAStream(),
105+
PTA_B(output, scalar_t, 2, 64),
106+
PTA_B(values, scalar_t, 2, 64),
107+
PTA_B(input_offsets, int64_t, 1, 32),
108+
PTA_B(indices, index_t, 1, 32),
109+
PTA_B(*output_offsets_contig, int64_t, 1, 32),
114110
num_dense_output_rows);
115-
C10_CUDA_KERNEL_LAUNCH_CHECK();
116111
});
117112
});
118113
}

fbgemm_gpu/src/jagged_tensor_ops/jagged_softmax_backward.cu

Lines changed: 13 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -112,27 +112,20 @@ Tensor jagged_softmax_backward_cuda(
112112
grad_output.scalar_type(),
113113
"jagged_softmax_backward_kernel_2",
114114
[&] {
115-
116-
#ifdef FBGEMM_GPU_MEMCHECK
117-
const auto func_name1 = "jagged_softmax_backward_kernel";
118-
#endif
119-
120-
jagged_softmax_backward_kernel<
115+
FBGEMM_LAUNCH_KERNEL(
116+
(jagged_softmax_backward_kernel<
117+
THREADS_PER_BLOCK,
118+
index_t,
119+
scalar_t>),
120+
grid,
121121
THREADS_PER_BLOCK,
122-
index_t,
123-
scalar_t>
124-
<<<grid,
125-
THREADS_PER_BLOCK,
126-
0,
127-
at::cuda::getCurrentCUDAStream()>>>(
128-
MAKE_PTA_WITH_NAME(
129-
func_name1, grad_output, scalar_t, 2, 32),
130-
MAKE_PTA_WITH_NAME(func_name1, output, scalar_t, 2, 32),
131-
MAKE_PTA_WITH_NAME(func_name1, offsets, index_t, 1, 32),
132-
MAKE_PTA_WITH_NAME(
133-
func_name1, grad_input, scalar_t, 2, 32),
134-
(int)max_L);
135-
C10_CUDA_KERNEL_LAUNCH_CHECK();
122+
0,
123+
at::cuda::getCurrentCUDAStream(),
124+
PTA_B(grad_output, scalar_t, 2, 32),
125+
PTA_B(output, scalar_t, 2, 32),
126+
PTA_B(offsets, index_t, 1, 32),
127+
PTA_B(grad_input, scalar_t, 2, 32),
128+
(int)max_L);
136129
});
137130
});
138131
}

fbgemm_gpu/src/jagged_tensor_ops/jagged_softmax_forward.cu

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -133,21 +133,19 @@ Tensor jagged_softmax_forward_cuda(
133133
offsets.scalar_type(), "jagged_softmax_kernel_1", [&] {
134134
FBGEMM_DISPATCH_FLOATING_TYPES(
135135
values.scalar_type(), "jagged_softmax_kernel_2", [&] {
136-
137-
#ifdef FBGEMM_GPU_MEMCHECK
138-
const auto func_name1 = "jagged_softmax_kernel";
139-
#endif
140-
141-
jagged_softmax_kernel<THREADS_PER_BLOCK, index_t, scalar_t>
142-
<<<grid,
143-
THREADS_PER_BLOCK,
144-
0,
145-
at::cuda::getCurrentCUDAStream()>>>(
146-
MAKE_PTA_WITH_NAME(func_name1, values, scalar_t, 2, 32),
147-
MAKE_PTA_WITH_NAME(func_name1, offsets, index_t, 1, 32),
148-
MAKE_PTA_WITH_NAME(func_name1, output, scalar_t, 2, 32),
149-
(int)max_L);
150-
C10_CUDA_KERNEL_LAUNCH_CHECK();
136+
FBGEMM_LAUNCH_KERNEL(
137+
(jagged_softmax_kernel<
138+
THREADS_PER_BLOCK,
139+
index_t,
140+
scalar_t>),
141+
grid,
142+
THREADS_PER_BLOCK,
143+
0,
144+
at::cuda::getCurrentCUDAStream(),
145+
PTA_B(values, scalar_t, 2, 32),
146+
PTA_B(offsets, index_t, 1, 32),
147+
PTA_B(output, scalar_t, 2, 32),
148+
(int)max_L);
151149
});
152150
});
153151
}

0 commit comments

Comments (0)