 import functools
 import itertools
 import logging
+import math
 import os
 import tempfile
 import threading
@@ -172,6 +173,7 @@ def __init__(
         res_params: Optional[RESParams] = None,  # raw embedding streaming sharding info
         flushing_block_size: int = 2_000_000_000,  # 2GB
         table_names: Optional[List[str]] = None,
+        optimizer_state_dtypes: Dict[str, SparseType] = {},  # noqa: B006
     ) -> None:
         super(SSDTableBatchedEmbeddingBags, self).__init__()
 
@@ -185,6 +187,7 @@ def __init__(
         assert weights_precision in (SparseType.FP32, SparseType.FP16)
         self.weights_precision = weights_precision
         self.output_dtype: int = output_dtype.as_int()
+        self.optimizer_state_dtypes: Dict[str, SparseType] = optimizer_state_dtypes
 
         # Zero collision TBE configurations
         self.kv_zch_params = kv_zch_params
@@ -987,13 +990,24 @@ def cache_row_dim(self) -> int:
         """
         if self.enable_optimizer_offloading:
             return self.max_D + pad4(
-                # Compute the number of elements of cache_dtype needed to store the
-                # optimizer state
-                self.optimizer.state_size_dim(self.weights_precision.as_dtype())
+                # Compute the number of elements of cache_dtype needed to store
+                # the optimizer state
+                self.optimizer_state_dim
             )
         else:
             return self.max_D
 
+    @cached_property
+    def optimizer_state_dim(self) -> int:
+        return int(
+            math.ceil(
+                self.optimizer.state_size_nbytes(
+                    self.max_D, self.optimizer_state_dtypes
+                )
+                / self.weights_precision.as_dtype().itemsize
+            )
+        )
+
     @property
     # pyre-ignore
     def ssd_db(self):
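
Note: a minimal sketch (not part of the diff) of what the new optimizer_state_dim property computes, assuming rowwise Adagrad keeps a single FP32 momentum value per row and the weights are FP16; the concrete byte counts below are illustrative assumptions, not values taken from this change:

    import math
    import torch

    # Assumption: the optimizer state for one row is a single FP32 value
    # (4 bytes), standing in for optimizer.state_size_nbytes(...), while
    # cache rows are stored in the weights dtype (FP16, 2 bytes/element).
    state_size_nbytes = torch.float32.itemsize  # 4
    cache_itemsize = torch.float16.itemsize     # 2
    optimizer_state_dim = int(math.ceil(state_size_nbytes / cache_itemsize))
    print(optimizer_state_dim)  # 2 cache-dtype elements reserved per row
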
@@ -2285,9 +2299,8 @@ def split_optimizer_states(
         table_offset = 0
 
         dtype = self.weights_precision.as_dtype()
-        optimizer_dim = self.optimizer.state_size_dim(dtype)
         logging.info(
-            f"split_optimizer_states: {optimizer_dim=}, {self.optimizer.dtype()=} {self.enable_load_state_dict_mode=}"
+            f"split_optimizer_states: {self.optimizer_state_dim=}, {self.optimizer.dtype()=} {self.enable_load_state_dict_mode=}"
         )
 
         for t, (emb_height, emb_dim) in enumerate(self.embedding_specs):
@@ -2373,7 +2386,7 @@ def split_optimizer_states(
                 # backend will return both weight and optimizer in one tensor, read the whole tensor
                 # out could OOM CPU memory.
                 tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
-                    shape=[emb_height, optimizer_dim],
+                    shape=[emb_height, self.optimizer_state_dim],
                     dtype=dtype,
                     row_offset=row_offset,
                     snapshot_handle=snapshot_handle,
@@ -2802,8 +2815,7 @@ def streaming_write_weight_and_id_per_table(
         # TODO: make chunk_size configurable or dynamic
         chunk_size = 10000
         row = weight_state.size(0)
-        optimizer_dim = self.optimizer.state_size_dim(dtype)
-        opt_state_2d = opt_state.view(dtype).view(-1, optimizer_dim)
+        opt_state_2d = opt_state.view(dtype).view(-1, self.optimizer_state_dim)
         for i in range(0, row, chunk_size):
             length = min(chunk_size, row - i)
             chunk_buffer = torch.empty(
@@ -2813,9 +2825,9 @@ def streaming_write_weight_and_id_per_table(
                 device="cpu",
             )
             chunk_buffer[:, : weight_state.size(1)] = weight_state[i : i + length, :]
-            chunk_buffer[:, D_rounded : D_rounded + optimizer_dim] = opt_state_2d[
-                i : i + length, :
-            ]
+            chunk_buffer[:, D_rounded : D_rounded + self.optimizer_state_dim] = (
+                opt_state_2d[i : i + length, :]
+            )
             kvt.set_weights_and_ids(chunk_buffer, id_tensor[i : i + length, :].view(-1))
 
     @torch.jit.ignore
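
Note: the chunked write above packs each cache row as [weights | 4-element-aligned padding | optimizer state]. A small sketch of that layout under assumed sizes (all shapes here are illustrative; only the slicing mirrors the diff):

    import torch

    # Assumed shapes: 5 rows, 8 weight elements (already 4-aligned),
    # 2 optimizer-state elements per row, FP16 cache dtype.
    rows, D_rounded, optimizer_state_dim = 5, 8, 2
    weight_state = torch.randn(rows, D_rounded, dtype=torch.float16)
    opt_state_2d = torch.randn(rows, optimizer_state_dim, dtype=torch.float16)

    chunk_buffer = torch.empty(
        rows, D_rounded + optimizer_state_dim, dtype=torch.float16
    )
    chunk_buffer[:, : weight_state.size(1)] = weight_state  # weight payload
    chunk_buffer[:, D_rounded : D_rounded + optimizer_state_dim] = opt_state_2d  # state tail
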
@@ -3454,20 +3466,35 @@ def fetch_from_l1_sp_w_row_ids(
         Fetch the optimizer states and/or weights from L1 and SP for given linearized row_ids.
         @return: updated_weights/optimizer_states, mask of which rows are filled
         """
+        if not self.enable_optimizer_offloading and only_get_optimizer_states:
+            raise RuntimeError(
+                "Optimizer states are not offloaded, while only_get_optimizer_states is True"
+            )
+
+        # NOTE: Remove this once there is support for fetching multiple
+        # optimizer states in fetch_from_l1_sp_w_row_ids
+        if self.optimizer != OptimType.EXACT_ROWWISE_ADAGRAD:
+            raise RuntimeError(
+                "Only rowwise adagrad is supported in fetch_from_l1_sp_w_row_ids at the moment"
+            )
+
         with torch.no_grad():
             weights_dtype = self.weights_precision.as_dtype()
             step = self.step
-            if not self.enable_optimizer_offloading and only_get_optimizer_states:
-                raise RuntimeError(
-                    "Optimizer states are not offloaded, while only_get_optimizer_states is True"
-                )
+
             if only_get_optimizer_states:
                 start_pos = pad4(self.max_D)
-                row_dim = self.optimizer.state_size_dim(weights_dtype)
-                result_dtype = self.optimizer.dtype()
+                # NOTE: This is a hack to keep fetch_from_l1_sp_w_row_ids working
+                # until it is upgraded to support optimizers with multiple states
+                # and dtypes
+                row_dim = int(
+                    math.ceil(torch.float32.itemsize / weights_dtype.itemsize)
+                )
+                result_dtype = torch.float32
                 result_dim = int(
                     ceil(row_dim / (result_dtype.itemsize / weights_dtype.itemsize))
                 )
+
             else:
                 start_pos = 0
                 # get the whole row
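
Note: a worked example (under the same rowwise-Adagrad assumption stated in the diff's NOTE comments) of the interim row_dim/result_dim computation above, for FP16 weights; the concrete dtype choice is illustrative:

    import math
    import torch

    weights_dtype = torch.float16  # example weights precision, itemsize 2

    # The interim hack pins the fetched state to one FP32 value per row:
    row_dim = int(math.ceil(torch.float32.itemsize / weights_dtype.itemsize))  # 2
    result_dtype = torch.float32
    # Fold row_dim weight-dtype elements into result-dtype elements:
    result_dim = int(
        math.ceil(row_dim / (result_dtype.itemsize / weights_dtype.itemsize))
    )  # 1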