import functools
import itertools
import logging
+ import math
import os
import tempfile
import threading
@@ -172,6 +173,7 @@ def __init__(
        res_params: Optional[RESParams] = None,  # raw embedding streaming sharding info
        flushing_block_size: int = 2_000_000_000,  # 2GB
        table_names: Optional[List[str]] = None,
+       optimizer_state_dtypes: Dict[str, SparseType] = {},  # noqa: B006
    ) -> None:
        super(SSDTableBatchedEmbeddingBags, self).__init__()
@@ -185,6 +187,7 @@ def __init__(
        assert weights_precision in (SparseType.FP32, SparseType.FP16)
        self.weights_precision = weights_precision
        self.output_dtype: int = output_dtype.as_int()
+       self.optimizer_state_dtypes: Dict[str, SparseType] = optimizer_state_dtypes

        # Zero collision TBE configurations
        self.kv_zch_params = kv_zch_params
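Editor note: the new optimizer_state_dtypes argument maps optimizer state names to the SparseType used to store them. A minimal sketch of how a caller might populate it, assuming a rowwise optimizer whose single state is named "momentum1" (the key name is an assumption, not confirmed by this diff):

# Per-state dtype overrides, keyed by optimizer state name ("momentum1" assumed).
optimizer_state_dtypes = {"momentum1": SparseType.FP32}
# Passed through as SSDTableBatchedEmbeddingBags(..., optimizer_state_dtypes=optimizer_state_dtypes)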
@@ -987,13 +990,24 @@ def cache_row_dim(self) -> int:
        """
        if self.enable_optimizer_offloading:
            return self.max_D + pad4(
-               # Compute the number of elements of cache_dtype needed to store the
-               # optimizer state
-               self.optimizer.state_size_dim(self.weights_precision.as_dtype())
+               # Compute the number of elements of cache_dtype needed to store
+               # the optimizer state
+               self.optimizer_state_dim
            )
        else:
            return self.max_D

+   @cached_property
+   def optimizer_state_dim(self) -> int:
+       return int(
+           math.ceil(
+               self.optimizer.state_size_nbytes(
+                   self.max_D, self.optimizer_state_dtypes
+               )
+               / self.weights_precision.as_dtype().itemsize
+           )
+       )
+
    @property
    # pyre-ignore
    def ssd_db(self):
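Editor note: for intuition, optimizer_state_dim converts the optimizer's per-row state size from bytes into whole elements of the weights dtype, rounding up so the state always fits in the cache row. A standalone sketch with assumed numbers (8 bytes of optimizer state per row, FP16 weights):

import math
import torch

state_nbytes = 8  # assumed: e.g. two FP32 scalars of optimizer state per row
weights_itemsize = torch.float16.itemsize  # 2 bytes per element
optimizer_state_dim = int(math.ceil(state_nbytes / weights_itemsize))
assert optimizer_state_dim == 4  # four FP16 slots reserved per cached row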
@@ -2285,9 +2299,8 @@ def split_optimizer_states(
        table_offset = 0

        dtype = self.weights_precision.as_dtype()
-       optimizer_dim = self.optimizer.state_size_dim(dtype)
        logging.info(
-           f"split_optimizer_states: {optimizer_dim=}, {self.optimizer.dtype()=} {self.enable_load_state_dict_mode=}"
+           f"split_optimizer_states: {self.optimizer_state_dim=}, {self.optimizer.dtype()=} {self.enable_load_state_dict_mode=}"
        )

        for t, (emb_height, emb_dim) in enumerate(self.embedding_specs):
@@ -2345,7 +2358,7 @@ def split_optimizer_states(
                            and sorted_id_tensor[t].size(0) > 0
                            else emb_height
                        ),
-                       optimizer_dim,
+                       self.optimizer_state_dim,
                    ],
                    dtype=dtype,
                    row_offset=row_offset,
@@ -2373,7 +2386,7 @@ def split_optimizer_states(
                # backend will return both weight and optimizer in one tensor, read the whole tensor
                # out could OOM CPU memory.
                tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
-                   shape=[emb_height, optimizer_dim],
+                   shape=[emb_height, self.optimizer_state_dim],
                    dtype=dtype,
                    row_offset=row_offset,
                    snapshot_handle=snapshot_handle,
@@ -2652,11 +2665,7 @@ def split_embedding_weights(
                    (
                        metaheader_dim  # metaheader is already padded
                        + pad4(emb_dim)
-                       + pad4(
-                           self.optimizer.state_size_dim(
-                               self.weights_precision.as_dtype()
-                           )
-                       )
+                       + pad4(self.optimizer_state_dim)
                    )
                    if self.backend_return_whole_row
                    else emb_dim
@@ -2802,8 +2811,7 @@ def streaming_write_weight_and_id_per_table(
        # TODO: make chunk_size configurable or dynamic
        chunk_size = 10000
        row = weight_state.size(0)
-       optimizer_dim = self.optimizer.state_size_dim(dtype)
-       opt_state_2d = opt_state.view(dtype).view(-1, optimizer_dim)
+       opt_state_2d = opt_state.view(dtype).view(-1, self.optimizer_state_dim)
        for i in range(0, row, chunk_size):
            length = min(chunk_size, row - i)
            chunk_buffer = torch.empty(
@@ -2813,9 +2821,9 @@ def streaming_write_weight_and_id_per_table(
                device="cpu",
            )
            chunk_buffer[:, : weight_state.size(1)] = weight_state[i : i + length, :]
-           chunk_buffer[:, D_rounded : D_rounded + optimizer_dim] = opt_state_2d[
-               i : i + length, :
-           ]
+           chunk_buffer[:, D_rounded : D_rounded + self.optimizer_state_dim] = (
+               opt_state_2d[i : i + length, :]
+           )
            kvt.set_weights_and_ids(chunk_buffer, id_tensor[i : i + length, :].view(-1))

    @torch.jit.ignore
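Editor note: each chunk row is packed as weights first, then optimizer state starting at offset D_rounded. A toy sketch of that packing with assumed sizes (weight dim 6 padded to D_rounded 8, two optimizer-state elements), not the real table dimensions:

import torch

length, weight_dim, D_rounded, opt_dim = 4, 6, 8, 2  # assumed toy sizes
chunk_buffer = torch.empty(length, D_rounded + opt_dim, dtype=torch.float16)
chunk_buffer[:, :weight_dim] = torch.zeros(length, weight_dim, dtype=torch.float16)  # weight columns
chunk_buffer[:, D_rounded : D_rounded + opt_dim] = torch.ones(length, opt_dim, dtype=torch.float16)  # optimizer-state columns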
@@ -3454,20 +3462,35 @@ def fetch_from_l1_sp_w_row_ids(
        Fetch the optimizer states and/or weights from L1 and SP for given linearized row_ids.
        @return: updated_weights/optimizer_states, mask of which rows are filled
        """
+       if not self.enable_optimizer_offloading and only_get_optimizer_states:
+           raise RuntimeError(
+               "Optimizer states are not offloaded, while only_get_optimizer_states is True"
+           )
+
+       # NOTE: Remove this once there is support for fetching multiple
+       # optimizer states in fetch_from_l1_sp_w_row_ids
+       if self.optimizer != OptimType.EXACT_ROWWISE_ADAGRAD:
+           raise RuntimeError(
+               "Only rowwise adagrad is supported in fetch_from_l1_sp_w_row_ids at the moment"
+           )
+
        with torch.no_grad():
            weights_dtype = self.weights_precision.as_dtype()
            step = self.step
-           if not self.enable_optimizer_offloading and only_get_optimizer_states:
-               raise RuntimeError(
-                   "Optimizer states are not offloaded, while only_get_optimizer_states is True"
-               )
+
            if only_get_optimizer_states:
                start_pos = pad4(self.max_D)
-               row_dim = self.optimizer.state_size_dim(weights_dtype)
-               result_dtype = self.optimizer.dtype()
+               # NOTE: This is a hack to keep fetch_from_l1_sp_w_row_ids working
+               # until it is upgraded to support optimizers with multiple states
+               # and dtypes
+               row_dim = int(
+                   math.ceil(torch.float32.itemsize / weights_dtype.itemsize)
+               )
+               result_dtype = torch.float32
                result_dim = int(
                    ceil(row_dim / (result_dtype.itemsize / weights_dtype.itemsize))
                )
+
            else:
                start_pos = 0
                # get the whole row
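Editor note: the interim values hard-code the assumption that rowwise Adagrad keeps one FP32 scalar of state per row: row_dim is how many weights-dtype elements that scalar occupies, and result_dim then collapses back to a single FP32 element. A sketch of the arithmetic under that assumption, with FP16 weights:

import math
import torch

weights_dtype = torch.float16  # assumed weights precision
row_dim = int(math.ceil(torch.float32.itemsize / weights_dtype.itemsize))  # 2 FP16 slots
result_dtype = torch.float32
result_dim = int(math.ceil(row_dim / (result_dtype.itemsize / weights_dtype.itemsize)))  # 1 FP32 value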