
Commit d797031

generatedunixname89002005232357 authored and facebook-github-bot committed
Revert D74295924 (#3184)
Summary:
Pull Request resolved: #3184

This diff reverts D74295924
Depends on D78147010
https://www.internalfb.com/intern/test/562950177386883?ref_report_id=0
https://www.internalfb.com/intern/test/562950177237448?ref_report_id=0
https://www.internalfb.com/intern/test/844425154071675?ref_report_id=0
and more

Depends on D74295924

Reviewed By: PoojaAg18

Differential Revision: D78147034

fbshipit-source-id: b1597ffe66639563cf4a13f7b07d1c127d0d7149
1 parent beaf74c · commit d797031

File tree

2 files changed: +39 -79 lines changed


torchrec/ir/tests/test_serializer.py

Lines changed: 35 additions & 45 deletions
@@ -207,14 +207,8 @@ def forward(
             num_embeddings=10,
             feature_names=["f2"],
         )
-        config3 = EmbeddingBagConfig(
-            name="t3",
-            embedding_dim=5,
-            num_embeddings=10,
-            feature_names=["f3"],
-        )
         ebc = EmbeddingBagCollection(
-            tables=[config1, config2, config3],
+            tables=[config1, config2],
             is_weighted=False,
         )
@@ -299,60 +293,42 @@ def test_serialize_deserialize_ebc(self) -> None:
             self.assertEqual(deserialized.shape, orginal.shape)
             self.assertTrue(torch.allclose(deserialized, orginal))
 
+    @unittest.skip("Adding test for demonstrating VBE KJT flattening issue for now.")
     def test_serialize_deserialize_ebc_with_vbe_kjt(self) -> None:
         model = self.generate_model_for_vbe_kjt()
-        kjt_1 = KeyedJaggedTensor(
-            keys=["f1", "f2", "f3"],
-            values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
-            lengths=torch.tensor([1, 2, 3, 2, 1, 1]),
-            stride_per_key_per_rank=torch.tensor([[3], [2], [1]]),
-            inverse_indices=(
-                ["f1", "f2", "f3"],
-                torch.tensor([[0, 1, 2], [0, 1, 0], [0, 0, 0]]),
-            ),
-        )
-        kjt_2 = KeyedJaggedTensor(
-            keys=["f1", "f2", "f3"],
-            values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
-            lengths=torch.tensor([1, 2, 3, 2, 1, 1]),
-            stride_per_key_per_rank=torch.tensor([[1], [2], [3]]),
-            inverse_indices=(
-                ["f1", "f2", "f3"],
-                torch.tensor([[0, 0, 0], [0, 1, 0], [0, 1, 2]]),
-            ),
+        id_list_features = KeyedJaggedTensor(
+            keys=["f1", "f2"],
+            values=torch.tensor([5, 6, 7, 1, 2, 3, 0, 1]),
+            lengths=torch.tensor([3, 3, 2]),
+            stride_per_key_per_rank=[[2], [1]],
+            inverse_indices=(["f1", "f2"], torch.tensor([[0, 1, 0], [0, 0, 0]])),
         )
 
-        eager_out = model(kjt_1)
-        eager_out_2 = model(kjt_2)
+        eager_out = model(id_list_features)
 
         # Serialize EBC
         model, sparse_fqns = encapsulate_ir_modules(model, JsonSerializer)
         ep = torch.export.export(
             model,
-            (kjt_1,),
+            (id_list_features,),
             {},
             strict=False,
             # Allows KJT to not be unflattened and run a forward on unflattened EP
             preserve_module_call_signature=(tuple(sparse_fqns)),
         )
 
         # Run forward on ExportedProgram
-        ep_output = ep.module()(kjt_1)
-        ep_output_2 = ep.module()(kjt_2)
+        ep_output = ep.module()(id_list_features)
 
-        self.assertEqual(len(ep_output), len(kjt_1.keys()))
-        self.assertEqual(len(ep_output_2), len(kjt_2.keys()))
         for i, tensor in enumerate(ep_output):
-            self.assertEqual(eager_out[i].shape[1], tensor.shape[1])
-        for i, tensor in enumerate(ep_output_2):
-            self.assertEqual(eager_out_2[i].shape[1], tensor.shape[1])
+            self.assertEqual(eager_out[i].shape, tensor.shape)
 
         # Deserialize EBC
         unflatten_ep = torch.export.unflatten(ep)
         deserialized_model = decapsulate_ir_modules(unflatten_ep, JsonSerializer)
 
         # check EBC config
-        for i in range(1):
+        for i in range(5):
             ebc_name = f"ebc{i + 1}"
             self.assertIsInstance(
                 getattr(deserialized_model, ebc_name), EmbeddingBagCollection
@@ -367,22 +343,36 @@ def test_serialize_deserialize_ebc_with_vbe_kjt(self) -> None:
                 self.assertEqual(deserialized.num_embeddings, orginal.num_embeddings)
                 self.assertEqual(deserialized.feature_names, orginal.feature_names)
 
+        # check FPEBC config
+        for i in range(2):
+            fpebc_name = f"fpebc{i + 1}"
+            assert isinstance(
+                getattr(deserialized_model, fpebc_name),
+                FeatureProcessedEmbeddingBagCollection,
+            )
+
+            for deserialized, orginal in zip(
+                getattr(
+                    deserialized_model, fpebc_name
+                )._embedding_bag_collection.embedding_bag_configs(),
+                getattr(
+                    model, fpebc_name
+                )._embedding_bag_collection.embedding_bag_configs(),
+            ):
+                self.assertEqual(deserialized.name, orginal.name)
+                self.assertEqual(deserialized.embedding_dim, orginal.embedding_dim)
+                self.assertEqual(deserialized.num_embeddings, orginal.num_embeddings)
+                self.assertEqual(deserialized.feature_names, orginal.feature_names)
+
         # Run forward on deserialized model and compare the output
         deserialized_model.load_state_dict(model.state_dict())
-        deserialized_out = deserialized_model(kjt_1)
+        deserialized_out = deserialized_model(id_list_features)
 
         self.assertEqual(len(deserialized_out), len(eager_out))
         for deserialized, orginal in zip(deserialized_out, eager_out):
             self.assertEqual(deserialized.shape, orginal.shape)
             self.assertTrue(torch.allclose(deserialized, orginal))
 
-        deserialized_out_2 = deserialized_model(kjt_2)
-
-        self.assertEqual(len(deserialized_out_2), len(eager_out_2))
-        for deserialized, orginal in zip(deserialized_out_2, eager_out_2):
-            self.assertEqual(deserialized.shape, orginal.shape)
-            self.assertTrue(torch.allclose(deserialized, orginal))
-
     def test_dynamic_shape_ebc_disabled_in_oss_compatibility(self) -> None:
         model = self.generate_model()
         feature1 = KeyedJaggedTensor.from_offsets_sync(
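
For orientation, the restored test drives the model with a variable-batch ("VBE") KeyedJaggedTensor. The sketch below, assuming torchrec is installed, rebuilds the test's input and annotates what each argument encodes; all values and shapes are taken directly from the diff above.

import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

id_list_features = KeyedJaggedTensor(
    keys=["f1", "f2"],
    # 8 ids total: "f1" contributes 3 + 3 values, "f2" contributes 2.
    values=torch.tensor([5, 6, 7, 1, 2, 3, 0, 1]),
    # Per-sample lengths, concatenated key by key: f1 -> [3, 3], f2 -> [2].
    lengths=torch.tensor([3, 3, 2]),
    # Variable batch size per key (one list entry per rank): f1 has 2 samples, f2 has 1.
    stride_per_key_per_rank=[[2], [1]],
    # Maps each logical batch row back to the deduplicated samples above.
    inverse_indices=(["f1", "f2"], torch.tensor([[0, 1, 0], [0, 0, 0]])),
)

print(id_list_features.variable_stride_per_key())  # True
print(id_list_features.stride_per_key_per_rank())  # [[2], [1]]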

torchrec/sparse/jagged_tensor.py

Lines changed: 4 additions & 34 deletions
@@ -1728,8 +1728,6 @@ class KeyedJaggedTensor(Pipelineable, metaclass=JaggedTensorMeta):
         "_weights",
         "_lengths",
         "_offsets",
-        "_stride_per_key_per_rank",
-        "_inverse_indices",
     ]
 
     def __init__(
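
The list trimmed here is `KeyedJaggedTensor._fields`, which the pytree helpers further down iterate with `getattr` when flattening a KJT. A minimal illustration of the post-revert invariant; the leading "_values" entry sits above this hunk's context, so it is an assumption rather than something shown in the diff:

from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Post-revert, the pytree helpers below capture exactly these four tensor
# attributes ("_values" assumed from surrounding context, not shown in the hunk).
assert KeyedJaggedTensor._fields == ["_values", "_weights", "_lengths", "_offsets"]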
@@ -3018,26 +3016,7 @@ def dist_init(
 def _kjt_flatten(
     t: KeyedJaggedTensor,
 ) -> Tuple[List[Optional[torch.Tensor]], List[str]]:
-    """
-    Used by PyTorch's pytree utilities for serialization and processing.
-    Extracts tensor attributes of a KeyedJaggedTensor and returns them
-    as a flat list, along with the necessary metadata to reconstruct the KeyedJaggedTensor.
-
-    Component tensors are returned as dynamic attributes.
-    KJT metadata are added as static specs.
-
-    Returns:
-        Tuple containing:
-        - List[Optional[torch.Tensor]]: All tensor attributes (_values, _weights, _lengths,
-          _offsets, _stride_per_key_per_rank, and the tensor part of _inverse_indices if present)
-        - Tuple[List[str], List[str]]: Metadata needed for reconstruction:
-            - List of keys from the original KeyedJaggedTensor
-            - List of inverse indices keys (if present, otherwise empty list)
-    """
-    values = [getattr(t, a) for a in KeyedJaggedTensor._fields[:-1]]
-    values.append(t._inverse_indices[1] if t._inverse_indices is not None else None)
-
-    return values, t._keys
+    return [getattr(t, a) for a in KeyedJaggedTensor._fields], t._keys
 
 
 def _kjt_flatten_with_keys(
@@ -3051,24 +3030,15 @@ def _kjt_flatten_with_keys(
 
 
 def _kjt_unflatten(
-    values: List[Optional[torch.Tensor]],
-    context: List[str],  # context is _keys
+    values: List[Optional[torch.Tensor]], context: List[str]  # context is the _keys
 ) -> KeyedJaggedTensor:
-    return KeyedJaggedTensor(
-        context,
-        *values[:-2],
-        stride_per_key_per_rank=values[-2],
-        inverse_indices=(context, values[-1]) if values[-1] is not None else None,
-    )
+    return KeyedJaggedTensor(context, *values)
 
 
 def _kjt_flatten_spec(
     t: KeyedJaggedTensor, spec: TreeSpec
 ) -> List[Optional[torch.Tensor]]:
-    values = [getattr(t, a) for a in KeyedJaggedTensor._fields[:-1]]
-    values.append(t._inverse_indices[1] if t._inverse_indices is not None else None)
-
-    return values
+    return [getattr(t, a) for a in KeyedJaggedTensor._fields]
 
 
 register_pytree_node(
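
Taken together, these helpers mean a pytree round trip now carries only the four core tensors plus the keys, so the VBE metadata is dropped on the reconstructed KJT, which is exactly the flattening issue the skipped test above demonstrates. A minimal sketch of that round trip, assuming torchrec is installed and the revert applied:

import torch
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

vbe_kjt = KeyedJaggedTensor(
    keys=["f1", "f2"],
    values=torch.tensor([5, 6, 7, 1, 2, 3, 0, 1]),
    lengths=torch.tensor([3, 3, 2]),
    stride_per_key_per_rank=[[2], [1]],
    inverse_indices=(["f1", "f2"], torch.tensor([[0, 1, 0], [0, 0, 0]])),
)

# Flatten captures _values, _weights, _lengths, _offsets; keys ride in the spec.
leaves, spec = tree_flatten(vbe_kjt)

# Unflatten calls KeyedJaggedTensor(keys, *leaves): the rebuilt KJT keeps the
# same values and lengths, but the VBE metadata does not survive the trip.
rebuilt = tree_unflatten(leaves, spec)
print(rebuilt.inverse_indices_or_none())  # None: stride/inverse-indices info is gone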
