@@ -757,7 +757,7 @@ def execute_ssd_forward_(
757
757
758
758
def split_optimizer_states_ (
759
759
self , emb : SSDTableBatchedEmbeddingBags
760
- ) -> List [torch .Tensor ]:
760
+ ) -> List [List [ torch .Tensor ] ]:
761
761
_ , bucket_asc_ids_list , _ = emb .split_embedding_weights (
762
762
no_snapshot = False , should_flush = True
763
763
)
@@ -962,7 +962,7 @@ def test_ssd_backward_adagrad(
962
962
# pyre-fixme[16]: Optional type has no attribute `float`.
963
963
ref_optimizer_state = emb_ref [f ].weight .grad .float ().to_dense ().pow (2 )
964
964
torch .testing .assert_close (
965
- split_optimizer_states [t ].float (),
965
+ split_optimizer_states [t ][ 0 ] .float (),
966
966
ref_optimizer_state .mean (dim = 1 ),
967
967
atol = tolerance ,
968
968
rtol = tolerance ,
@@ -978,7 +978,7 @@ def test_ssd_backward_adagrad(
978
978
emb_r .weight .float (),
979
979
value = - lr ,
980
980
tensor1 = emb_r .weight .grad .float ().to_dense (),
981
- tensor2 = split_optimizer_states [t ]
981
+ tensor2 = split_optimizer_states [t ][ 0 ]
982
982
.float ()
983
983
.sqrt_ ()
984
984
.add_ (eps )
@@ -1113,7 +1113,10 @@ def test_ssd_emb_state_dict(
1113
1113
emb_r .weight .float (),
1114
1114
value = - lr ,
1115
1115
tensor1 = emb_r .weight .grad .float ().to_dense (), # pyre-ignore[16]
1116
- tensor2 = split_optimizer_states [table_index ]
1116
+ # NOTE: The [0] index is a hack since the test is fixed to use
1117
+ # EXACT_ROWWISE_ADAGRAD optimizer. The test in general should
1118
+ # be upgraded in the future to support multiple optimizers
1119
+ tensor2 = split_optimizer_states [table_index ][0 ]
1117
1120
.float ()
1118
1121
.sqrt_ ()
1119
1122
.add_ (eps )
@@ -1188,7 +1191,8 @@ def execute_ssd_cache_pipeline_( # noqa C901
1188
1191
)
1189
1192
1190
1193
optimizer_states_ref = [
1191
- s .clone ().float () for s in self .split_optimizer_states_ (emb )
1194
+ [s .clone ().float () for s in states ]
1195
+ for states in self .split_optimizer_states_ (emb )
1192
1196
]
1193
1197
1194
1198
Es = [emb .embedding_specs [t ][0 ] for t in range (T )]
@@ -1334,8 +1338,11 @@ def _prefetch(b_it: int) -> int:
1334
1338
# Compare optimizer states
1335
1339
split_optimizer_states = self .split_optimizer_states_ (emb )
1336
1340
for f , t in self .get_physical_table_arg_indices_ (emb .feature_table_map ):
1337
- optim_state_r = optimizer_states_ref [t ]
1338
- optim_state_t = split_optimizer_states [t ]
1341
+ optim_state_r = optimizer_states_ref [t ][0 ]
1342
+ # NOTE: The [0] index is a hack since the test is fixed to use
1343
+ # EXACT_ROWWISE_ADAGRAD optimizer. The test in general should
1344
+ # be upgraded in the future to support multiple optimizers
1345
+ optim_state_t = split_optimizer_states [t ][0 ]
1339
1346
emb_r = emb_ref [f ]
1340
1347
1341
1348
optim_state_r .add_ (
@@ -1753,7 +1760,10 @@ def test_kv_emb_state_dict(
1753
1760
dim = 1
1754
1761
)
1755
1762
torch .testing .assert_close (
1756
- split_optimizer_states [t ].float (),
1763
+ # NOTE: The [0] index is a hack since the test is fixed to use
1764
+ # EXACT_ROWWISE_ADAGRAD optimizer. The test in general should
1765
+ # be upgraded in the future to support multiple optimizers
1766
+ split_optimizer_states [t ][0 ].float (),
1757
1767
ref_opt_mean .cpu (),
1758
1768
atol = tolerance ,
1759
1769
rtol = tolerance ,
@@ -1799,8 +1809,11 @@ def test_kv_emb_state_dict(
1799
1809
.to_dense ()[bucket_asc_ids_list [table_index ].view (- 1 )]
1800
1810
)
1801
1811
self .assertLess (table_index , len (emb_state_dict_list ))
1802
- assert len (split_optimizer_states [table_index ]) == num_ids
1803
- opt = split_optimizer_states [table_index ]
1812
+ assert len (split_optimizer_states [table_index ][0 ]) == num_ids
1813
+ # NOTE: The [0] index is a hack since the test is fixed to use
1814
+ # EXACT_ROWWISE_ADAGRAD optimizer. The test in general should
1815
+ # be upgraded in the future to support multiple optimizers
1816
+ opt = split_optimizer_states [table_index ][0 ]
1804
1817
new_ref_weight = torch .addcdiv (
1805
1818
emb_r_w .float (),
1806
1819
value = - lr ,
@@ -1985,7 +1998,10 @@ def test_kv_opt_state_w_offloading(
1985
1998
# pyre-fixme[16]: Undefined attribute: `Optional` has no attribute `__getitem__`.
1986
1999
ref_kv_opt = ref_optimizer_state [bucket_asc_ids_list [t ]].view (- 1 )
1987
2000
torch .testing .assert_close (
1988
- split_optimizer_states [t ].float (),
2001
+ # NOTE: The [0] index is a hack since the test is fixed to use
2002
+ # EXACT_ROWWISE_ADAGRAD optimizer. The test in general should
2003
+ # be upgraded in the future to support multiple optimizers
2004
+ split_optimizer_states [t ][0 ].float (),
1989
2005
ref_kv_opt ,
1990
2006
atol = tolerance ,
1991
2007
rtol = tolerance ,
@@ -2031,8 +2047,11 @@ def test_kv_opt_state_w_offloading(
2031
2047
.to_dense ()[bucket_asc_ids_list [table_index ].view (- 1 )]
2032
2048
)
2033
2049
self .assertLess (table_index , len (emb_state_dict_list ))
2034
- assert len (split_optimizer_states [table_index ]) == num_ids
2035
- opt = split_optimizer_states [table_index ]
2050
+ # NOTE: The [0] index is a hack since the test is fixed to use
2051
+ # EXACT_ROWWISE_ADAGRAD optimizer. The test in general should
2052
+ # be upgraded in the future to support multiple optimizers
2053
+ assert len (split_optimizer_states [table_index ][0 ]) == num_ids
2054
+ opt = split_optimizer_states [table_index ][0 ]
2036
2055
new_ref_weight = torch .addcdiv (
2037
2056
emb_r_w .float (),
2038
2057
value = - lr ,
@@ -2221,7 +2240,10 @@ def test_kv_state_dict_w_backend_return_whole_row(
2221
2240
# pyre-fixme[16]: Undefined attribute: `Optional` has no attribute `__getitem__`.
2222
2241
ref_kv_opt = ref_optimizer_state [bucket_asc_ids_list [t ]].view (- 1 )
2223
2242
opt = (
2224
- split_optimizer_states [t ]
2243
+ # NOTE: The [0] index is a hack since the test is fixed to use
2244
+ # EXACT_ROWWISE_ADAGRAD optimizer. The test in general should
2245
+ # be upgraded in the future to support multiple optimizers
2246
+ split_optimizer_states [t ][0 ]
2225
2247
.narrow (0 , 0 , bucket_asc_ids_list [t ].size (0 ))
2226
2248
.view (- 1 )
2227
2249
.view (torch .float32 )
@@ -2276,7 +2298,10 @@ def test_kv_state_dict_w_backend_return_whole_row(
2276
2298
.to_dense ()[bucket_asc_ids_list [table_index ].view (- 1 )]
2277
2299
)
2278
2300
self .assertLess (table_index , len (emb_state_dict_list ))
2279
- assert split_optimizer_states [table_index ].size (0 ) == num_ids
2301
+ # NOTE: The [0] index is a hack since the test is fixed to use
2302
+ # EXACT_ROWWISE_ADAGRAD optimizer. The test in general should
2303
+ # be upgraded in the future to support multiple optimizers
2304
+ assert split_optimizer_states [table_index ][0 ].size (0 ) == num_ids
2280
2305
new_ref_weight = torch .addcdiv (
2281
2306
emb_r_w .float (),
2282
2307
value = - lr ,
@@ -2501,9 +2526,12 @@ def test_apply_kv_state_dict(
2501
2526
# pyre-fixme[16]: Undefined attribute: Item `torch._tensor.Tensor` of `typing.Uni...
2502
2527
emb_state_dict_list [i ].full_tensor ()
2503
2528
)
2529
+ # NOTE: The [0] index is a hack since the test is fixed to use
2530
+ # EXACT_ROWWISE_ADAGRAD optimizer. The test in general should
2531
+ # be upgraded in the future to support multiple optimizers
2504
2532
# pyre-ignore [16]
2505
2533
emb2 ._cached_kvzch_data .cached_optimizer_state_per_table [i ].copy_ (
2506
- split_optimizer_states [i ]
2534
+ split_optimizer_states [i ][ 0 ]
2507
2535
)
2508
2536
# pyre-ignore [16]
2509
2537
emb2 ._cached_kvzch_data .cached_id_tensor_per_table [i ].copy_ (
@@ -2547,8 +2575,8 @@ def test_apply_kv_state_dict(
2547
2575
rtol = tolerance ,
2548
2576
)
2549
2577
torch .testing .assert_close (
2550
- split_optimizer_states [t ][sorted_ids .indices ],
2551
- split_optimizer_states2 [t ][sorted_ids2 .indices ],
2578
+ split_optimizer_states [t ][0 ][ sorted_ids .indices ],
2579
+ split_optimizer_states2 [t ][0 ][ sorted_ids2 .indices ],
2552
2580
atol = tolerance ,
2553
2581
rtol = tolerance ,
2554
2582
)
@@ -2820,7 +2848,10 @@ def copy_weights_hook(
2820
2848
# pyre-fixme[16]: Optional type has no attribute `float`.
2821
2849
ref_optimizer_state = emb_ref [f ].weight .grad .float ().to_dense ().pow (2 )
2822
2850
torch .testing .assert_close (
2823
- split_optimizer_states [t ].float (),
2851
+ # NOTE: The [0] index is a hack since the test is fixed to use
2852
+ # EXACT_ROWWISE_ADAGRAD optimizer. The test in general should
2853
+ # be upgraded in the future to support multiple optimizers
2854
+ split_optimizer_states [t ][0 ].float (),
2824
2855
ref_optimizer_state .mean (dim = 1 ),
2825
2856
atol = tolerance ,
2826
2857
rtol = tolerance ,
@@ -3036,7 +3067,10 @@ def copy_opt_states_hook(
3036
3067
cursor += local_idxes .numel ()
3037
3068
3038
3069
torch .testing .assert_close (
3039
- split_optimizer_states [t ][indices ].float (),
3070
+ # NOTE: The [0] index is a hack since the test is fixed to use
3071
+ # EXACT_ROWWISE_ADAGRAD optimizer. The test in general should
3072
+ # be upgraded in the future to support multiple optimizers
3073
+ split_optimizer_states [t ][0 ][indices ].float (),
3040
3074
opt_states_per_tb .cpu ().float (),
3041
3075
atol = tolerance ,
3042
3076
rtol = tolerance ,
0 commit comments