@@ -187,15 +187,30 @@ def __init__(
         self.kv_zch_params = kv_zch_params
         self.backend_type = backend_type
         self.enable_optimizer_offloading: bool = False
+        self.backend_return_whole_row: bool = False

         if self.kv_zch_params:
             self.kv_zch_params.validate()
             self.enable_optimizer_offloading = (
                 # pyre-ignore [16]
                 self.kv_zch_params.enable_optimizer_offloading
             )
+            self.backend_return_whole_row = (
+                # pyre-ignore [16]
+                self.kv_zch_params.backend_return_whole_row
+            )

             if self.enable_optimizer_offloading:
                 logging.info("Optimizer state offloading is enabled")
+            if self.backend_return_whole_row:
+                assert (
+                    self.backend_type == BackendType.DRAM
+                ), f"Only DRAM backend supports backend_return_whole_row, but got {self.backend_type}"
+                logging.info(
+                    "Backend will return whole row including metaheader, weight and optimizer for checkpoint"
+                )
+
+        # TODO: the metaheader is 16 bytes fixed.
+        self.metaheader_dim: int = 16 // (weights_precision.bit_rate() // 8)

         self.pooling_mode = pooling_mode
         self.bounds_check_mode_int: int = bounds_check_mode.value
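For context on the new metaheader_dim field added above: it converts the fixed 16-byte per-row metaheader into a count of elements in the weights' dtype. A minimal sketch of that arithmetic, where the bit_rate values are illustrative assumptions rather than anything taken from this PR:

# Hedged sketch: how metaheader_dim scales with the weight element width.
def metaheader_dim(bit_rate: int, metaheader_bytes: int = 16) -> int:
    elem_bytes = bit_rate // 8  # bytes per weight element
    return metaheader_bytes // elem_bytes

assert metaheader_dim(32) == 4   # fp32 rows: metaheader occupies 4 elements
assert metaheader_dim(16) == 8   # fp16 rows: metaheader occupies 8 elements
assert metaheader_dim(8) == 16   # int8 rows: metaheader occupies 16 elements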
@@ -612,13 +627,14 @@ def __init__(
             logging.info(
                 f"Logging DRAM offloading setup, tbe_unique_id:{tbe_unique_id}, l2_cache_size:{l2_cache_size}GB,"
                 f"num_shards={ssd_rocksdb_shards},num_threads={ssd_rocksdb_shards},"
-                f"max_D={self.max_D}"
+                f"max_D={self.max_D}, "
                 f"uniform_init_lower={ssd_uniform_init_lower},uniform_init_upper={ssd_uniform_init_upper},"
                 f"row_storage_bitwidth={weights_precision.bit_rate()},"
                 f"self.cache_row_dim={self.cache_row_dim},"
                 f"enable_optimizer_offloading={self.enable_optimizer_offloading},"
                 f"feature_dims={self.feature_dims},"
-                f"hash_size_cumsum={self.hash_size_cumsum}"
+                f"hash_size_cumsum={self.hash_size_cumsum},"
+                f"backend_return_whole_row={self.backend_return_whole_row}"
             )
             table_dims = (
                 tensor_pad4(self.table_dims)
@@ -659,6 +675,7 @@ def __init__(
                     if self.enable_optimizer_offloading
                     else None
                 ),  # hash_size_cumsum
+                self.backend_return_whole_row,  # backend_return_whole_row
             )
         else:
             raise AssertionError(f"Invalid backend type {self.backend_type}")
@@ -2246,16 +2263,19 @@ def split_optimizer_states(
             # pyre-ignore
             bucket_size = self.kv_zch_params.bucket_sizes[t]
             row_offset = table_offset
-            if sorted_id_tensor is None or sorted_id_tensor[t].numel() == 0:
+            if not self.backend_return_whole_row and (
+                sorted_id_tensor is None or sorted_id_tensor[t].numel() == 0
+            ):
                 opt_list.append(
                     torch.empty(0, dtype=self.optimizer.dtype(), device="cpu")
                     # empty optimizer state for module initialization
+                    # which will NOT be used for cp loading
                 )
             else:
                 if not self.enable_optimizer_offloading:
                     # convert global id back to local id, then linearize with table offset
                     local_id_tensor = (
-                        sorted_id_tensor[t]
+                        sorted_id_tensor[t]  # pyre-ignore[16]
                         - bucket_id_start * bucket_size
                         + table_offset
                     )
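The local_id_tensor expression above linearizes a bucket-sharded global id into the row index used against the in-memory optimizer state. A small worked example of that arithmetic, where every value is made up for illustration:

# Hedged example of the global-id -> linearized-local-id math used above.
import torch

bucket_id_start = 2    # first bucket owned by this shard for table t (assumed)
bucket_size = 100      # ids per bucket (assumed)
table_offset = 1000    # rows occupied by preceding tables (assumed)

sorted_ids = torch.tensor([200, 205, 310])  # global ids falling in buckets 2 and 3
local_ids = sorted_ids - bucket_id_start * bucket_size + table_offset
assert local_ids.tolist() == [1000, 1005, 1110]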
@@ -2264,27 +2284,74 @@ def split_optimizer_states(
                     )
                 else:
                     row_offset = table_offset - (bucket_id_start * bucket_size)
-                    # using KVTensorWrapper to query backend to avoid OOM memory, since
-                    # backend will return both weight and optimizer in one tensor, read the whole tensor
-                    # out could OOM CPU memory.
-                    tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
-                        shape=[emb_height, optimizer_dim],
-                        dtype=dtype,
-                        row_offset=row_offset,
-                        snapshot_handle=snapshot_handle,
-                        sorted_indices=sorted_id_tensor[t],
-                        width_offset=pad4(emb_dim),
-                    )
-                    (
-                        tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
-                        if self.backend_type == BackendType.SSD
-                        else tensor_wrapper.set_dram_db_wrapper(self.ssd_db)
-                    )
-                    opt_list.append(
-                        self.get_offloaded_optimizer_states(
-                            tensor_wrapper, sorted_id_tensor[t].numel()
+                    if self.backend_return_whole_row:
+                        # When the backend returns the whole row, the optimizer is returned as a PMT directly
+                        if (
+                            sorted_id_tensor[t].size(0) == 0
+                            and self.local_weight_counts[t] > 0
+                        ):
+                            logging.info(
+                                f"before opt PMT loading, resetting id tensor with {self.local_weight_counts[t]}"
+                            )
+                            # pyre-ignore [16]
+                            sorted_id_tensor[t] = torch.zeros(
+                                (self.local_weight_counts[t], 1),
+                                device=torch.device("cpu"),
+                                dtype=torch.int64,
+                            )
+                        tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
+                            shape=[
+                                (
+                                    sorted_id_tensor[t].size(0)
+                                    if sorted_id_tensor is not None
+                                    and sorted_id_tensor[t].size(0) > 0
+                                    else emb_height
+                                ),
+                                optimizer_dim,
+                            ],
+                            dtype=dtype,
+                            row_offset=row_offset,
+                            snapshot_handle=snapshot_handle,
+                            sorted_indices=sorted_id_tensor[t],
+                            width_offset=(
+                                self.metaheader_dim  # metaheader is already padded, so no need for pad4
+                                + pad4(emb_dim)
+                            ),
+                            read_only=True,  # optimizer is written to the DB together with the weights, so skip writes here
+                        )
+                        (
+                            tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
+                            if self.backend_type == BackendType.SSD
+                            else tensor_wrapper.set_dram_db_wrapper(self.ssd_db)
+                        )
+                        opt_list.append(
+                            PartiallyMaterializedTensor(
+                                tensor_wrapper,
+                                True if self.kv_zch_params else False,
+                            )
+                        )
+                    else:
+                        # using KVTensorWrapper to query the backend to avoid OOM, since the
+                        # backend returns both weight and optimizer in one tensor; reading the
+                        # whole tensor out could OOM CPU memory.
+                        tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
+                            shape=[emb_height, optimizer_dim],
+                            dtype=dtype,
+                            row_offset=row_offset,
+                            snapshot_handle=snapshot_handle,
+                            sorted_indices=sorted_id_tensor[t],
+                            width_offset=pad4(emb_dim),
+                        )
+                        (
+                            tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
+                            if self.backend_type == BackendType.SSD
+                            else tensor_wrapper.set_dram_db_wrapper(self.ssd_db)
+                        )
+                        opt_list.append(
+                            self.get_offloaded_optimizer_states(
+                                tensor_wrapper, sorted_id_tensor[t].numel()
+                            )
                         )
-                    )
             table_offset += emb_height
         logging.info(
             f"KV ZCH tables split_optimizer_states query latency: {(time.time() - start_time) * 1000} ms, "
@@ -2513,7 +2580,7 @@ def split_embedding_weights(
                     and self.local_weight_counts[i] > 0
                 ):
                     logging.info(
-                        f"resetting bucket id tensor with {self.local_weight_counts[i]}"
+                        f"before weight PMT loading, resetting id tensor with {self.local_weight_counts[i]}"
                     )
                     bucket_ascending_id_tensor = torch.zeros(
                         (self.local_weight_counts[i], 1),
@@ -2539,7 +2606,19 @@ def split_embedding_weights(
                         if bucket_ascending_id_tensor is not None
                         else emb_height
                     ),
-                    emb_dim,
+                    (
+                        (
+                            self.metaheader_dim  # metaheader is already padded
+                            + pad4(emb_dim)
+                            + pad4(
+                                self.optimizer.state_size_dim(
+                                    self.weights_precision.as_dtype()
+                                )
+                            )
+                        )
+                        if self.backend_return_whole_row
+                        else emb_dim
+                    ),
                 ],
                 dtype=dtype,
                 row_offset=row_offset,
@@ -2576,6 +2655,11 @@ def split_embedding_weights(

     @torch.jit.ignore
     def apply_state_dict(self) -> None:
+        if self.backend_return_whole_row:
+            logging.info(
+                "backend_return_whole_row is enabled, no need to apply_state_dict"
+            )
+            return
         # After checkpoint loading, the _cached_kvzch_data will be loaded from checkpoint.
         # Caller should call this function to apply the cached states to backend.
         if self.load_state_dict is False:
@@ -2694,6 +2778,11 @@ def streaming_write_weight_and_id_per_table(

     @torch.jit.ignore
     def enable_load_state_dict_mode(self) -> None:
+        if self.backend_return_whole_row:
+            logging.info(
+                "backend_return_whole_row is enabled, no need to enable load_state_dict mode"
+            )
+            return
         # Enable load state dict mode before loading checkpoint
         if self.load_state_dict:
             return