
Commit 8ba5184

ghq24int authored and facebook-github-bot committed
add optimized reorder_batched_ad_indices_kernel on AMD (#4388)
Summary:
X-link: facebookresearch/FBGEMM#1459

Pull Request resolved: #4388

reorder_batched_ad_indices_kernel_dtypeLong() for dtype Long
reorder_batched_ad_indices_kernel_dtypeFloat() for dtype float
itype makes no difference on performance.

Performance results on the CFR & IFR models can be reproduced through the benchmark in D77066925.

Reviewed By: jianyuh

Differential Revision: D77078971

fbshipit-source-id: 8c92f032d3125ec04e23385c83b799d1f1dd2ea4
1 parent 0432615 commit 8ba5184
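
The heart of the change is the size-dependent copy strategy inside the new reorder_batched_ad_indices_kernel_vec: one warp handles one (table, batch) segment, and segments longer than 64 elements are moved with long2/long4 (or float2/float4) vector loads and stores instead of one element per lane. Below is a minimal, standalone sketch of that vectorized-copy idea; the kernel name copy_segment_vec4, the single-warp launch, and the use of longlong4 for 8-byte elements are illustrative assumptions, not code from this commit.

// Sketch of the vector load/store idea used for large segments (illustration only).
// Assumptions: one warp (32 threads) per segment, 8-byte elements, base pointers
// aligned for 16-byte vector access (cudaMalloc guarantees far more than that);
// longlong4 stands in for the kernel's long4.
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>

__global__ void copy_segment_vec4(
    const int64_t* __restrict__ src,
    int64_t* __restrict__ dst,
    int num_elements) {
  if (num_elements > 128) {
    // Reinterpret as 4-wide vectors so each lane moves 32 bytes per iteration.
    auto dst4 = reinterpret_cast<longlong4*>(dst);
    auto src4 = reinterpret_cast<const longlong4*>(src);
    for (auto i = threadIdx.x; i < num_elements / 4; i += blockDim.x) {
      dst4[i] = src4[i];
    }
    // Scalar tail: the first (num_elements % 4) lanes copy one trailing element each.
    const int remainder = num_elements % 4;
    if (remainder && threadIdx.x < remainder) {
      dst[num_elements - threadIdx.x - 1] = src[num_elements - threadIdx.x - 1];
    }
  } else {
    // Short segments: a plain strided copy is already coalesced.
    for (auto i = threadIdx.x; i < num_elements; i += blockDim.x) {
      dst[i] = src[i];
    }
  }
}

int main() {
  constexpr int kN = 1000;
  int64_t *src = nullptr, *dst = nullptr;
  cudaMallocManaged(&src, kN * sizeof(int64_t));
  cudaMallocManaged(&dst, kN * sizeof(int64_t));
  for (int i = 0; i < kN; ++i) {
    src[i] = i;
  }
  copy_segment_vec4<<<1, 32>>>(src, dst, kN);  // one warp copies one segment
  cudaDeviceSynchronize();
  printf("dst[%d] = %lld\n", kN - 1, static_cast<long long>(dst[kN - 1]));
  cudaFree(src);
  cudaFree(dst);
  return 0;
}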

File tree: 1 file changed (+97, -0 lines)


fbgemm_gpu/src/sparse_ops/sparse_reorder_batched_ad.cu

Lines changed: 97 additions & 0 deletions
@@ -240,6 +240,102 @@ __launch_bounds__(kMaxThreads) void reorder_batched_ad_indices_kernel(
   }
 }
 
+template <typename Dtype, typename index_t = int32_t>
+__global__
+__launch_bounds__(fbgemm_gpu::kMaxThreads) void reorder_batched_ad_indices_kernel_vec(
+    // reorder indices from (ragged) [B x T x #num_ads_b x length_{b, t, a}]
+    // to [T][B][#num_ads_b][length_{b, t, a}], i.e. [sum(length_{b, t, a})],
+    // laid out as [T][B][A][L] (if all lengths were equal).
+
+    // if broadcast_indices is enabled, all the indices will be copies of the
+    // first batch of the cat_ad_indices, this is useful for request-only
+    // broadcast
+    const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
+        cat_ad_offsets,
+    const pta::PackedTensorAccessor32<Dtype, 1, at::RestrictPtrTraits>
+        cat_ad_indices,
+    const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
+        reordered_cat_ad_offsets,
+    pta::PackedTensorAccessor32<Dtype, 1, at::RestrictPtrTraits>
+        reordered_cat_ad_indices,
+    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits>
+        batch_offsets,
+    const int32_t T,
+    const bool broadcast_indices) {
+  using vec2_t =
+      typename std::conditional<sizeof(Dtype) == 8, long2, float2>::type;
+  using vec4_t =
+      typename std::conditional<sizeof(Dtype) == 8, long4, float4>::type;
+  const int32_t B = batch_offsets.size(0) - 1;
+  const int32_t num_ads_in_batch = batch_offsets[B];
+  // warp-per-segment.
+  const auto b_t = blockIdx.x * blockDim.y +
+      threadIdx.y; // can be more efficient through bitwise op
+  const int32_t b = b_t % B;
+  const int32_t t = b_t / B;
+  if (t >= T) {
+    return;
+  }
+
+  const auto num_ads_b = batch_offsets[b + 1] - batch_offsets[b];
+  const auto output_segment_offset_start =
+      t * num_ads_in_batch + batch_offsets[b];
+  const auto output_segment_start =
+      reordered_cat_ad_offsets[output_segment_offset_start];
+  const int32_t input_segment_offset_start =
+      broadcast_indices ? T * b + t : T * batch_offsets[b] + t * num_ads_b;
+  const int32_t input_segment_offset_end = broadcast_indices
+      ? input_segment_offset_start + 1
+      : input_segment_offset_start + num_ads_b;
+  const auto input_segment_start = cat_ad_offsets[input_segment_offset_start];
+  const auto input_segment_end = cat_ad_offsets[input_segment_offset_end];
+  const auto num_elements = input_segment_end - input_segment_start;
+
+  if (broadcast_indices) {
+    for (auto i = threadIdx.x; i < num_ads_b * num_elements; i += blockDim.x) {
+      reordered_cat_ad_indices[output_segment_start + i] =
+          cat_ad_indices[input_segment_start + i % num_elements];
+    }
+  } else {
+    // Idea: we want to copy the entire segment of size sum_a(length_{b, t, a})
+    // from starting point (given by cat_ad_offsets[b, t])
+    // to end point (given by reordered_cat_ad_indices[t][b])
+    if (num_elements <= 64) {
+      for (auto i = threadIdx.x; i < input_segment_end - input_segment_start;
+           i += blockDim.x) {
+        // coalesced global memory access, can be optimized through ILP with the
+        // help of shared memory or vector load/store (if num_ads_b >= 64)
+        reordered_cat_ad_indices[output_segment_start + i] =
+            cat_ad_indices[input_segment_start + i];
+      }
+    } else if (num_elements > 64 && num_elements <= 128) {
+      auto dst =
+          (vec2_t*)(reordered_cat_ad_indices.data() + output_segment_start);
+      auto src = (vec2_t*)(cat_ad_indices.data() + input_segment_start);
+      for (auto i = threadIdx.x; i < num_elements / 2; i += blockDim.x) {
+        dst[i] = src[i];
+      }
+      if ((num_elements % 2) && threadIdx.x == 31) {
+        reordered_cat_ad_indices[output_segment_start + num_elements - 1] =
+            cat_ad_indices[input_segment_start + num_elements - 1];
+      }
+    } else if (num_elements > 128) {
+      auto dst =
+          (vec4_t*)(reordered_cat_ad_indices.data() + output_segment_start);
+      auto src = (vec4_t*)(cat_ad_indices.data() + input_segment_start);
+      for (auto i = threadIdx.x; i < num_elements / 4; i += blockDim.x) {
+        dst[i] = src[i];
+      }
+      int remainder = num_elements % 4;
+      if (remainder && threadIdx.x < remainder) {
+        reordered_cat_ad_indices
+            [output_segment_start + num_elements - threadIdx.x - 1] =
+                cat_ad_indices
+                    [input_segment_start + num_elements - threadIdx.x - 1];
+      }
+    }
+  }
+}
 DLL_PUBLIC Tensor reorder_batched_ad_indices_gpu(
     const Tensor& cat_ad_offsets,
     const Tensor& cat_ad_indices,
@@ -387,6 +483,7 @@ DLL_PUBLIC Tensor reorder_batched_ad_indices_gpu(
       C10_CUDA_KERNEL_LAUNCH_CHECK();
     });
   });
+
   return reordered_cat_ad_indices;
 }
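
For intuition about what the kernel computes, here is a small host-side reference of the non-broadcast reordering that the kernel's comments describe: per-ad index segments stored in [batch][table][ad] order are regrouped into [table][batch][ad] order, with ragged per-segment lengths. The names, shapes, and values below are made up for illustration and are not the FBGEMM host API.

// CPU reference (illustration only) of the non-broadcast reordering.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const int B = 2;  // batches
  const int T = 2;  // tables
  const std::vector<int> num_ads = {2, 1};  // ads per batch
  // lengths[b][t][a]: number of indices in segment (b, t, a)
  const std::vector<std::vector<std::vector<int>>> lengths = {
      {{2, 1}, {1, 2}},  // batch 0: table 0 -> {2, 1}, table 1 -> {1, 2}
      {{3}, {1}}};       // batch 1: table 0 -> {3},    table 1 -> {1}

  // Flatten the input in [B][T][A] segment order, filling with 0, 1, 2, ...
  std::vector<int64_t> cat_ad_indices;
  int64_t v = 0;
  for (int b = 0; b < B; ++b)
    for (int t = 0; t < T; ++t)
      for (int a = 0; a < num_ads[b]; ++a)
        for (int l = 0; l < lengths[b][t][a]; ++l) cat_ad_indices.push_back(v++);

  // Regroup segments into [T][B][A] order (what reordered_cat_ad_indices holds).
  std::vector<int64_t> reordered;
  for (int t = 0; t < T; ++t) {
    for (int b = 0; b < B; ++b) {
      // Input offset of segment (b, t, a = 0): all segments of earlier batches,
      // plus this batch's segments for earlier tables.
      size_t off = 0;
      for (int bb = 0; bb < b; ++bb)
        for (int tt = 0; tt < T; ++tt)
          for (int a = 0; a < num_ads[bb]; ++a) off += lengths[bb][tt][a];
      for (int tt = 0; tt < t; ++tt)
        for (int a = 0; a < num_ads[b]; ++a) off += lengths[b][tt][a];
      // Copy this (batch, table)'s ads back to back.
      for (int a = 0; a < num_ads[b]; ++a)
        for (int l = 0; l < lengths[b][t][a]; ++l)
          reordered.push_back(cat_ad_indices[off++]);
    }
  }

  // Input 0..9 comes out as: 0 1 2 6 7 8 3 4 5 9
  for (auto x : reordered) printf("%lld ", static_cast<long long>(x));
  printf("\n");
  return 0;
}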
