Skip to content

Commit beaf74c

Browse files
generatedunixname89002005232357facebook-github-bot
authored andcommitted
Revert D76997485 (#3183)
Summary: Pull Request resolved: #3183 This diff reverts D76997485 https://www.internalfb.com/intern/test/562950177386883?ref_report_id=0 https://www.internalfb.com/intern/test/562950177237448?ref_report_id=0 https://www.internalfb.com/intern/test/844425154071675?ref_report_id=0 and more Depends on D76997485 Reviewed By: PoojaAg18 Differential Revision: D78147010 fbshipit-source-id: 4b7ff779f8dbe509528b69e9020c775888818809
1 parent 5cbc277 commit beaf74c

File tree

2 files changed

+0
-18
lines changed

2 files changed

+0
-18
lines changed

torchrec/sparse/jagged_tensor.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1095,18 +1095,13 @@ def _maybe_compute_stride_kjt(
10951095
lengths: Optional[torch.Tensor],
10961096
offsets: Optional[torch.Tensor],
10971097
stride_per_key_per_rank: Optional[torch.IntTensor],
1098-
inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = None,
10991098
) -> int:
11001099
if stride is None:
11011100
if len(keys) == 0:
11021101
stride = 0
11031102
elif (
11041103
stride_per_key_per_rank is not None and stride_per_key_per_rank.numel() > 0
11051104
):
1106-
# For VBE KJT, batch size should be based on inverse_indices when set.
1107-
if inverse_indices is not None:
1108-
return inverse_indices[1].shape[-1]
1109-
11101105
s = stride_per_key_per_rank.sum(dim=1).max().item()
11111106
if not torch.jit.is_scripting() and is_non_strict_exporting():
11121107
stride = torch.sym_int(s)
@@ -2151,7 +2146,6 @@ def stride(self) -> int:
21512146
self._lengths,
21522147
self._offsets,
21532148
self._stride_per_key_per_rank,
2154-
self._inverse_indices,
21552149
)
21562150
self._stride = stride
21572151
return stride

torchrec/sparse/tests/test_keyed_jagged_tensor.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1017,18 +1017,6 @@ def test_meta_device_compatibility(self) -> None:
10171017
lengths=torch.tensor([], device=torch.device("meta")),
10181018
)
10191019

1020-
def test_vbe_kjt_stride(self) -> None:
1021-
inverse_indices = torch.tensor([[0, 1, 0], [0, 0, 0]])
1022-
kjt = KeyedJaggedTensor(
1023-
keys=["f1", "f2", "f3"],
1024-
values=torch.tensor([5, 6, 7, 1, 2, 3, 0, 1]),
1025-
lengths=torch.tensor([3, 3, 2]),
1026-
stride_per_key_per_rank=[[2], [1]],
1027-
inverse_indices=(["f1", "f2"], inverse_indices),
1028-
)
1029-
1030-
self.assertEqual(kjt.stride(), inverse_indices.shape[-1])
1031-
10321020

10331021
class TestKeyedJaggedTensorScripting(unittest.TestCase):
10341022
def test_scriptable_forward(self) -> None:

0 commit comments

Comments
 (0)