Skip to content

Commit 9bd0892

Browse files
sryap and facebook-github-bot
authored and committed
Support prefetch pipeline in bounds_check_indices (#4312)
Summary: Pull Request resolved: #4312. X-link: facebookresearch/FBGEMM#1174. Frontend of D72343128. This diff reduces the grid dimension of the bounds_check_indices kernel when pipeline prefetching is used (in embedding memory offloading). We need to use the v2 kernel since v1 does not support grid dimension reduction. Reviewed By: jwfromm. Differential Revision: D72365505. fbshipit-source-id: 3c57ad3d9c64a0d0f81df70abae1f50b43a1d0fa
1 parent 048376f commit 9bd0892

File tree

1 file changed

+14
-2
lines changed

1 file changed

+14
-2
lines changed

fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1947,6 +1947,7 @@ def forward( # noqa: C901
19471947
per_sample_weights,
19481948
batch_size_per_feature_per_rank,
19491949
force_cast_input_types=True,
1950+
prefetch_pipeline=False,
19501951
)
19511952

19521953
# Print input stats if enable (for debugging purpose only)
@@ -2478,6 +2479,7 @@ def prefetch(
24782479
per_sample_weights=None,
24792480
batch_size_per_feature_per_rank=batch_size_per_feature_per_rank,
24802481
force_cast_input_types=False,
2482+
prefetch_pipeline=self.prefetch_pipeline,
24812483
)
24822484

24832485
with self._recording_to_timer(
@@ -3543,6 +3545,7 @@ def prepare_inputs(
35433545
per_sample_weights: Optional[Tensor] = None,
35443546
batch_size_per_feature_per_rank: Optional[List[List[int]]] = None,
35453547
force_cast_input_types: bool = True,
3548+
prefetch_pipeline: bool = False,
35463549
) -> Tuple[Tensor, Tensor, Optional[Tensor], invokers.lookup_args.VBEMetadata]:
35473550
"""
35483551
Prepare TBE inputs as follows:
@@ -3613,9 +3616,17 @@ def prepare_inputs(
36133616
per_sample_weights = per_sample_weights.float()
36143617

36153618
if self.bounds_check_mode_int != BoundsCheckMode.NONE.value:
3619+
# Override the bounds check version based on prefetch_pipeline
3620+
use_bounds_check_v2 = self.bounds_check_version == 2 or prefetch_pipeline
3621+
bounds_check_version = (
3622+
2 if use_bounds_check_v2 else self.bounds_check_version
3623+
)
3624+
3625+
vbe = vbe_metadata.B_offsets is not None
3626+
36163627
# Compute B info and VBE metadata for bounds_check_indices only if
36173628
# VBE and bounds check indices v2 are used
3618-
if vbe and self.bounds_check_version == 2:
3629+
if vbe and use_bounds_check_v2:
36193630
B_offsets = vbe_metadata.B_offsets
36203631
B_offsets_rank_per_feature = vbe_metadata.B_offsets_rank_per_feature
36213632
output_offsets_feature_rank = vbe_metadata.output_offsets_feature_rank
@@ -3653,7 +3664,8 @@ def prepare_inputs(
36533664
b_t_map=b_t_map,
36543665
info_B_num_bits=self.info_B_num_bits,
36553666
info_B_mask=self.info_B_mask,
3656-
bounds_check_version=self.bounds_check_version,
3667+
bounds_check_version=bounds_check_version,
3668+
prefetch_pipeline=prefetch_pipeline,
36573669
)
36583670

36593671
return indices, offsets, per_sample_weights, vbe_metadata

0 commit comments

Comments
 (0)