@@ -207,14 +207,8 @@ def forward(
             num_embeddings=10,
             feature_names=["f2"],
         )
-        config3 = EmbeddingBagConfig(
-            name="t3",
-            embedding_dim=5,
-            num_embeddings=10,
-            feature_names=["f3"],
-        )

         ebc = EmbeddingBagCollection(
-            tables=[config1, config2, config3],
+            tables=[config1, config2],
             is_weighted=False,
         )
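The hunk above trims the VBE fixture from three tables to two. For orientation, here is a minimal, self-contained sketch (not the fixture itself) of building and calling an EmbeddingBagCollection the same way; the table dims and the input KJT are illustrative values, not taken from this commit.

import torch
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Two tables analogous to config1/config2 above (dims are illustrative).
config1 = EmbeddingBagConfig(
    name="t1", embedding_dim=4, num_embeddings=10, feature_names=["f1"]
)
config2 = EmbeddingBagConfig(
    name="t2", embedding_dim=4, num_embeddings=10, feature_names=["f2"]
)
ebc = EmbeddingBagCollection(tables=[config1, config2], is_weighted=False)

# Batch of 2 samples: f1 has lengths [2, 0], f2 has lengths [1, 2].
features = KeyedJaggedTensor.from_lengths_sync(
    keys=["f1", "f2"],
    values=torch.tensor([1, 2, 3, 4, 5]),
    lengths=torch.tensor([2, 0, 1, 2]),
)
pooled = ebc(features)  # KeyedTensor: one pooled embedding per key per sample
print(pooled["f1"].shape)  # torch.Size([2, 4])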
@@ -299,60 +293,42 @@ def test_serialize_deserialize_ebc(self) -> None:
             self.assertEqual(deserialized.shape, orginal.shape)
             self.assertTrue(torch.allclose(deserialized, orginal))

+    @unittest.skip("Adding test for demonstrating VBE KJT flattening issue for now.")
     def test_serialize_deserialize_ebc_with_vbe_kjt(self) -> None:
         model = self.generate_model_for_vbe_kjt()
-        kjt_1 = KeyedJaggedTensor(
-            keys=["f1", "f2", "f3"],
-            values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
-            lengths=torch.tensor([1, 2, 3, 2, 1, 1]),
-            stride_per_key_per_rank=torch.tensor([[3], [2], [1]]),
-            inverse_indices=(
-                ["f1", "f2", "f3"],
-                torch.tensor([[0, 1, 2], [0, 1, 0], [0, 0, 0]]),
-            ),
-        )
-        kjt_2 = KeyedJaggedTensor(
-            keys=["f1", "f2", "f3"],
-            values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
-            lengths=torch.tensor([1, 2, 3, 2, 1, 1]),
-            stride_per_key_per_rank=torch.tensor([[1], [2], [3]]),
-            inverse_indices=(
-                ["f1", "f2", "f3"],
-                torch.tensor([[0, 0, 0], [0, 1, 0], [0, 1, 2]]),
-            ),
+        id_list_features = KeyedJaggedTensor(
+            keys=["f1", "f2"],
+            values=torch.tensor([5, 6, 7, 1, 2, 3, 0, 1]),
+            lengths=torch.tensor([3, 3, 2]),
+            stride_per_key_per_rank=[[2], [1]],
+            inverse_indices=(["f1", "f2"], torch.tensor([[0, 1, 0], [0, 0, 0]])),
         )

-        eager_out = model(kjt_1)
-        eager_out_2 = model(kjt_2)
+        eager_out = model(id_list_features)

         # Serialize EBC
         model, sparse_fqns = encapsulate_ir_modules(model, JsonSerializer)
         ep = torch.export.export(
             model,
-            (kjt_1,),
+            (id_list_features,),
             {},
             strict=False,
             # Allows KJT to not be unflattened and run a forward on unflattened EP
             preserve_module_call_signature=(tuple(sparse_fqns)),
         )

         # Run forward on ExportedProgram
-        ep_output = ep.module()(kjt_1)
-        ep_output_2 = ep.module()(kjt_2)
+        ep_output = ep.module()(id_list_features)

-        self.assertEqual(len(ep_output), len(kjt_1.keys()))
-        self.assertEqual(len(ep_output_2), len(kjt_2.keys()))
         for i, tensor in enumerate(ep_output):
-            self.assertEqual(eager_out[i].shape[1], tensor.shape[1])
-        for i, tensor in enumerate(ep_output_2):
-            self.assertEqual(eager_out_2[i].shape[1], tensor.shape[1])
+            self.assertEqual(eager_out[i].shape, tensor.shape)

         # Deserialize EBC
         unflatten_ep = torch.export.unflatten(ep)
         deserialized_model = decapsulate_ir_modules(unflatten_ep, JsonSerializer)

         # check EBC config
-        for i in range(1):
+        for i in range(5):
             ebc_name = f"ebc{i + 1}"
             self.assertIsInstance(
                 getattr(deserialized_model, ebc_name), EmbeddingBagCollection
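The replacement input is the interesting part of this hunk: `stride_per_key_per_rank=[[2], [1]]` declares per-key batch sizes (two samples for "f1", one for "f2"), and `inverse_indices` expands the deduplicated samples back to a common logical batch of three rows. A standalone sketch of the same construction, as a hedged reading of torchrec's `KeyedJaggedTensor` arguments:

import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Same tensors as the test: values split as f1 -> [5, 6, 7] and [1, 2, 3]
# (two samples of length 3), f2 -> [0, 1] (one sample of length 2).
vbe_kjt = KeyedJaggedTensor(
    keys=["f1", "f2"],
    values=torch.tensor([5, 6, 7, 1, 2, 3, 0, 1]),
    lengths=torch.tensor([3, 3, 2]),
    # Per-key, per-rank batch sizes: 2 samples for f1, 1 for f2.
    stride_per_key_per_rank=[[2], [1]],
    # Row i of the logical 3-row batch reads sample inverse_indices[key][i],
    # e.g. f1's rows are samples 0, 1, 0 and f2's rows all reuse sample 0.
    inverse_indices=(["f1", "f2"], torch.tensor([[0, 1, 0], [0, 0, 0]])),
)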
@@ -367,22 +343,36 @@ def test_serialize_deserialize_ebc_with_vbe_kjt(self) -> None:
             self.assertEqual(deserialized.num_embeddings, orginal.num_embeddings)
             self.assertEqual(deserialized.feature_names, orginal.feature_names)

+        # check FPEBC config
+        for i in range(2):
+            fpebc_name = f"fpebc{i + 1}"
+            assert isinstance(
+                getattr(deserialized_model, fpebc_name),
+                FeatureProcessedEmbeddingBagCollection,
+            )
+
+            for deserialized, orginal in zip(
+                getattr(
+                    deserialized_model, fpebc_name
+                )._embedding_bag_collection.embedding_bag_configs(),
+                getattr(
+                    model, fpebc_name
+                )._embedding_bag_collection.embedding_bag_configs(),
+            ):
+                self.assertEqual(deserialized.name, orginal.name)
+                self.assertEqual(deserialized.embedding_dim, orginal.embedding_dim)
+                self.assertEqual(deserialized.num_embeddings, orginal.num_embeddings)
+                self.assertEqual(deserialized.feature_names, orginal.feature_names)
+
         # Run forward on deserialized model and compare the output
         deserialized_model.load_state_dict(model.state_dict())
-        deserialized_out = deserialized_model(kjt_1)
+        deserialized_out = deserialized_model(id_list_features)

         self.assertEqual(len(deserialized_out), len(eager_out))
         for deserialized, orginal in zip(deserialized_out, eager_out):
             self.assertEqual(deserialized.shape, orginal.shape)
             self.assertTrue(torch.allclose(deserialized, orginal))

-        deserialized_out_2 = deserialized_model(kjt_2)
-
-        self.assertEqual(len(deserialized_out_2), len(eager_out_2))
-        for deserialized, orginal in zip(deserialized_out_2, eager_out_2):
-            self.assertEqual(deserialized.shape, orginal.shape)
-            self.assertTrue(torch.allclose(deserialized, orginal))
-
     def test_dynamic_shape_ebc_disabled_in_oss_compatibility(self) -> None:
         model = self.generate_model()
         feature1 = KeyedJaggedTensor.from_offsets_sync(
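Taken together, the test exercises TorchRec's IR round trip: encapsulate sparse modules, export, unflatten, then decapsulate. A condensed sketch of that flow follows; the `torchrec.ir` import paths are an assumption based on common TorchRec layouts, and `model` / `id_list_features` stand in for the objects built in the test above.

import torch
from torchrec.ir.serializer import JsonSerializer  # assumed import path
from torchrec.ir.utils import (  # assumed import path
    decapsulate_ir_modules,
    encapsulate_ir_modules,
)

# 1) Swap sparse modules for serializable IR wrappers, recording their FQNs.
model, sparse_fqns = encapsulate_ir_modules(model, JsonSerializer)

# 2) Export; preserving the sparse modules' call signatures lets the KJT
#    input stay flattened (see the inline comment in the diff above).
ep = torch.export.export(
    model,
    (id_list_features,),
    {},
    strict=False,
    preserve_module_call_signature=tuple(sparse_fqns),
)

# 3) Unflatten and restore the original sparse modules from the IR metadata,
#    then copy weights so outputs can be compared against eager mode.
deserialized_model = decapsulate_ir_modules(torch.export.unflatten(ep), JsonSerializer)
deserialized_model.load_state_dict(model.state_dict())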