@@ -190,15 +190,27 @@ def __init__(
         self.kv_zch_params = kv_zch_params
         self.backend_type = backend_type
         self.enable_optimizer_offloading: bool = False
+        self.backend_return_whole_row: bool = False
         if self.kv_zch_params:
             self.kv_zch_params.validate()
             self.enable_optimizer_offloading = (
                 # pyre-ignore [16]
                 self.kv_zch_params.enable_optimizer_offloading
             )
+            self.backend_return_whole_row = (
+                # pyre-ignore [16]
+                self.kv_zch_params.backend_return_whole_row
+            )

             if self.enable_optimizer_offloading:
                 logging.info("Optimizer state offloading is enabled")
+            if self.backend_return_whole_row:
+                assert (
+                    self.backend_type == BackendType.DRAM
+                ), f"Only DRAM backend supports backend_return_whole_row, but got {self.backend_type}"
+                logging.info(
+                    "Backend will return whole row including metaheader, weight and optimizer for checkpoint"
+                )

         self.pooling_mode = pooling_mode
         self.bounds_check_mode_int: int = bounds_check_mode.value
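For orientation, here is a minimal stand-in sketch of the gating this hunk adds: the flag is read off `kv_zch_params` and is only accepted with the DRAM backend. `StubBackendType` and `StubKVZCHParams` below are local stand-ins, not FBGEMM types; only the field name `backend_return_whole_row` and the DRAM-only constraint come from this diff.

```python
from dataclasses import dataclass
from enum import Enum


class StubBackendType(Enum):
    # Local stand-in for fbgemm_gpu's BackendType enum
    SSD = 0
    DRAM = 1


@dataclass
class StubKVZCHParams:
    # Local stand-in for KVZCHParams; the field name comes from this diff
    enable_optimizer_offloading: bool = False
    backend_return_whole_row: bool = False


def resolve_backend_return_whole_row(params, backend_type) -> bool:
    # Mirrors the gating added in __init__: the flag comes from kv_zch_params
    # and is only legal with the DRAM backend.
    flag = bool(params and params.backend_return_whole_row)
    if flag:
        assert (
            backend_type == StubBackendType.DRAM
        ), f"Only DRAM backend supports backend_return_whole_row, but got {backend_type}"
    return flag


print(resolve_backend_return_whole_row(
    StubKVZCHParams(backend_return_whole_row=True), StubBackendType.DRAM
))  # True
```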
@@ -625,13 +637,14 @@ def __init__(
             logging.info(
                 f"Logging DRAM offloading setup, tbe_unique_id:{tbe_unique_id}, l2_cache_size:{l2_cache_size}GB,"
                 f"num_shards={ssd_rocksdb_shards},num_threads={ssd_rocksdb_shards},"
-                f"max_D={self.max_D}"
+                f"max_D={self.max_D}, "
                 f"uniform_init_lower={ssd_uniform_init_lower},uniform_init_upper={ssd_uniform_init_upper},"
                 f"row_storage_bitwidth={weights_precision.bit_rate()},"
                 f"self.cache_row_dim={self.cache_row_dim},"
                 f"enable_optimizer_offloading={self.enable_optimizer_offloading},"
                 f"feature_dims={self.feature_dims},"
-                f"hash_size_cumsum={self.hash_size_cumsum}"
+                f"hash_size_cumsum={self.hash_size_cumsum},"
+                f"backend_return_whole_row={self.backend_return_whole_row}"
             )
             table_dims = (
                 tensor_pad4(self.table_dims)
@@ -672,6 +685,7 @@ def __init__(
                     if self.enable_optimizer_offloading
                     else None
                 ),  # hash_size_cumsum
+                self.backend_return_whole_row,  # backend_return_whole_row
             )
         else:
             raise AssertionError(f"Invalid backend type {self.backend_type}")
@@ -2282,16 +2296,19 @@ def split_optimizer_states(
                 # pyre-ignore
                 bucket_size = self.kv_zch_params.bucket_sizes[t]
                 row_offset = table_offset
-                if sorted_id_tensor is None or sorted_id_tensor[t].numel() == 0:
+                if not self.backend_return_whole_row and (
+                    sorted_id_tensor is None or sorted_id_tensor[t].numel() == 0
+                ):
                     opt_list.append(
                         torch.empty(0, dtype=self.optimizer.dtype(), device="cpu")
                         # empty optimizer state for module initialization
+                        # which will NOT be used for cp loading
                     )
                 else:
                     if not self.enable_optimizer_offloading:
                         # convert global id back to local id, then linearize with table offset
                         local_id_tensor = (
-                            sorted_id_tensor[t]
+                            sorted_id_tensor[t]  # pyre-ignore[16]
                             - bucket_id_start * bucket_size
                             + table_offset
                         )
@@ -2300,27 +2317,79 @@ def split_optimizer_states(
                         )
                     else:
                         row_offset = table_offset - (bucket_id_start * bucket_size)
-                        # using KVTensorWrapper to query backend to avoid OOM memory, since
-                        # backend will return both weight and optimizer in one tensor, read the whole tensor
-                        # out could OOM CPU memory.
-                        tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
-                            shape=[emb_height, optimizer_dim],
-                            dtype=dtype,
-                            row_offset=row_offset,
-                            snapshot_handle=snapshot_handle,
-                            sorted_indices=sorted_id_tensor[t],
-                            width_offset=pad4(emb_dim),
-                        )
-                        (
-                            tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
-                            if self.backend_type == BackendType.SSD
-                            else tensor_wrapper.set_dram_db_wrapper(self.ssd_db)
-                        )
-                        opt_list.append(
-                            self.get_offloaded_optimizer_states(
-                                tensor_wrapper, sorted_id_tensor[t].numel()
+                        if self.backend_return_whole_row:
+                            # When backend returns whole row, the optimizer will be returned as PMT directly
+                            if (
+                                sorted_id_tensor[t].size(0) == 0
+                                and self.local_weight_counts[t] > 0
+                            ):
+                                logging.info(
+                                    f"before opt PMT loading, resetting id tensor with {self.local_weight_counts[t]}"
+                                )
+                                # pyre-ignore [16]
+                                sorted_id_tensor[t] = torch.zeros(
+                                    (self.local_weight_counts[t], 1),
+                                    device=torch.device("cpu"),
+                                    dtype=torch.int64,
+                                )
+
+                            metaheader_dim = (
+                                # pyre-ignore[16]
+                                self.kv_zch_params.eviction_policy.meta_header_lens[t]
+                            )
+                            tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
+                                shape=[
+                                    (
+                                        sorted_id_tensor[t].size(0)
+                                        if sorted_id_tensor is not None
+                                        and sorted_id_tensor[t].size(0) > 0
+                                        else emb_height
+                                    ),
+                                    optimizer_dim,
+                                ],
+                                dtype=dtype,
+                                row_offset=row_offset,
+                                snapshot_handle=snapshot_handle,
+                                sorted_indices=sorted_id_tensor[t],
+                                width_offset=(
+                                    metaheader_dim  # metaheader is already padded so no need for pad4
+                                    + pad4(emb_dim)
+                                ),
+                                read_only=True,  # optimizer written to DB with weights, so skip write here
+                            )
+                            (
+                                tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
+                                if self.backend_type == BackendType.SSD
+                                else tensor_wrapper.set_dram_db_wrapper(self.ssd_db)
+                            )
+                            opt_list.append(
+                                PartiallyMaterializedTensor(
+                                    tensor_wrapper,
+                                    True if self.kv_zch_params else False,
+                                )
+                            )
+                        else:
+                            # using KVTensorWrapper to query backend to avoid OOM memory, since
+                            # backend will return both weight and optimizer in one tensor, read the whole tensor
+                            # out could OOM CPU memory.
+                            tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
+                                shape=[emb_height, optimizer_dim],
+                                dtype=dtype,
+                                row_offset=row_offset,
+                                snapshot_handle=snapshot_handle,
+                                sorted_indices=sorted_id_tensor[t],
+                                width_offset=pad4(emb_dim),
+                            )
+                            (
+                                tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
+                                if self.backend_type == BackendType.SSD
+                                else tensor_wrapper.set_dram_db_wrapper(self.ssd_db)
+                            )
+                            opt_list.append(
+                                self.get_offloaded_optimizer_states(
+                                    tensor_wrapper, sorted_id_tensor[t].numel()
+                                )
                             )
-                        )
             table_offset += emb_height
         logging.info(
             f"KV ZCH tables split_optimizer_states query latency: {(time.time() - start_time) * 1000} ms, "
@@ -2515,10 +2584,15 @@ def split_embedding_weights(
             bucket_ascending_id_tensor = None
             bucket_t = None
             row_offset = table_offset
+            metaheader_dim = 0
             if self.kv_zch_params:
                 bucket_id_start, bucket_id_end = self.kv_zch_params.bucket_offsets[i]
                 # pyre-ignore
                 bucket_size = self.kv_zch_params.bucket_sizes[i]
+                metaheader_dim = (
+                    # pyre-ignore[16]
+                    self.kv_zch_params.eviction_policy.meta_header_lens[i]
+                )

                 # linearize with table offset
                 table_input_id_start = table_offset
@@ -2548,7 +2622,7 @@ def split_embedding_weights(
                     and self.local_weight_counts[i] > 0
                 ):
                     logging.info(
-                        f"resetting bucket id tensor with {self.local_weight_counts[i]}"
+                        f"before weight PMT loading, resetting id tensor with {self.local_weight_counts[i]}"
                     )
                     bucket_ascending_id_tensor = torch.zeros(
                         (self.local_weight_counts[i], 1),
@@ -2574,7 +2648,19 @@ def split_embedding_weights(
                        if bucket_ascending_id_tensor is not None
                        else emb_height
                    ),
-                   emb_dim,
+                   (
+                       (
+                           metaheader_dim  # metaheader is already padded
+                           + pad4(emb_dim)
+                           + pad4(
+                               self.optimizer.state_size_dim(
+                                   self.weights_precision.as_dtype()
+                               )
+                           )
+                       )
+                       if self.backend_return_whole_row
+                       else emb_dim
+                   ),
                ],
                dtype=dtype,
                row_offset=row_offset,
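The conditional shape width added in this hunk follows the same row layout: when the backend returns the whole row, the weight PMT spans metaheader plus padded weights plus padded optimizer state; otherwise it is just `emb_dim`. A small sketch of that arithmetic, again with a local `pad4` stand-in and illustrative dimensions rather than values from this PR:

```python
def pad4(dim: int) -> int:
    # Local stand-in for FBGEMM's 4-element padding helper
    return (dim + 3) // 4 * 4


def pmt_width(emb_dim: int, metaheader_dim: int, optimizer_state_dim: int,
              backend_return_whole_row: bool) -> int:
    # Width used for the weight tensor_wrapper shape in split_embedding_weights
    if backend_return_whole_row:
        return metaheader_dim + pad4(emb_dim) + pad4(optimizer_state_dim)
    return emb_dim


print(pmt_width(127, metaheader_dim=8, optimizer_state_dim=1, backend_return_whole_row=True))   # 140
print(pmt_width(127, metaheader_dim=8, optimizer_state_dim=1, backend_return_whole_row=False))  # 127
```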
@@ -2611,6 +2697,11 @@ def split_embedding_weights(

     @torch.jit.ignore
     def apply_state_dict(self) -> None:
+        if self.backend_return_whole_row:
+            logging.info(
+                "backend_return_whole_row is enabled, no need to apply_state_dict"
+            )
+            return
         # After checkpoint loading, the _cached_kvzch_data will be loaded from checkpoint.
         # Caller should call this function to apply the cached states to backend.
         if self.load_state_dict is False:
@@ -2729,6 +2820,11 @@ def streaming_write_weight_and_id_per_table(

     @torch.jit.ignore
     def enable_load_state_dict_mode(self) -> None:
+        if self.backend_return_whole_row:
+            logging.info(
+                "backend_return_whole_row is enabled, no need to enable load_state_dict mode"
+            )
+            return
         # Enable load state dict mode before loading checkpoint
         if self.load_state_dict:
             return