
Commit 0aab31f

Jianbo Liu authored and facebook-github-bot committed
Support get/set the whole row of metaheader+weight+optimizer from backend for checkpoint saving/loading (#4435)
Summary:
X-link: facebookresearch/FBGEMM#1500
X-link: pytorch/torchrec#3153
Pull Request resolved: #4435

# Context

In our current KVZCH checkpoint loading flow, we keep hold of the weight_id, weight, and optimizer tensors throughout the checkpoint loading lifecycle, and only once all of these tensors have been downloaded do we explicitly call "apply_state_dict" to write them chunk by chunk to the backend, ensuring id->weight and id->opt are mapped correctly. The problem is that with a large number of weights we run short of memory, since all three tensors must be held at once (a double-memory issue). To solve this, we now save the whole row (metaheader + weight + opt) as the single "weight" tensor during checkpoint saving; when downloading the checkpoint, we can extract the id from the metaheader and write the weight+opt part directly to the backend by id. When loading the checkpoint for the optimizer, we added a no-op KVTensor, so the optimizer states do not need to be written to the backend again.

# This diff only contains frontend changes

* added the `backend_return_whole_row` flag in the KVZCH params, with validation to make sure it is only True when optimizer offloading is used
* added the `read_only_` flag in KVTensorWrapper to be used for checkpoint calls; when read_only=True, all write operations to this KVT are no-ops
* added metadata recalculation for the optimizer state dict, because we now return a read-only KVT for the optimizer state dict, and the model store needs to correct the global metadata before creating the save plan for KVZCH optimizer tensors
* by default, optimizer offloading and return-whole-row are False on trunk, so existing KVZCH runs should not break

Reviewed By: emlin

Differential Revision: D77666892

Privacy Context Container: L1138451

fbshipit-source-id: b0ca5f0f880ede1a803f77d0d520abb3356a0c8d
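To make the saving/loading flow described above concrete, here is a rough sketch of the per-row layout and the id-keyed write-back during checkpoint load. The helper names (`load_whole_row_chunk`, `write_row_fn`) are illustrative only, not the FBGEMM API; the assumption that the first 8 bytes of the metaheader hold the int64 row id follows the way the new unit test reads it.

```python
import torch

def load_whole_row_chunk(
    chunk: torch.Tensor,   # [num_rows, metaheader_dim + weight_dim + opt_dim]
    metaheader_dim: int,   # metaheader length in elements of chunk.dtype
    write_row_fn,          # hypothetical backend call: (row_id, payload) -> None
) -> None:
    """Sketch: each checkpointed row is [metaheader | weight | optimizer].
    The first 8 bytes of the metaheader carry the int64 row id, so weight+opt
    can be written back by id without buffering separate id/weight/opt tensors."""
    id_elems = 8 // chunk.element_size()  # elements covering the 8-byte id
    for row in chunk:
        row_id = int(row[:id_elems].view(torch.int64).item())
        payload = row[metaheader_dim:]    # weight + optimizer portion
        write_row_fn(row_id, payload)
```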
1 parent 76c16e5 commit 0aab31f

File tree

5 files changed: +564 -32 lines changed


fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py

Lines changed: 9 additions & 3 deletions
@@ -95,6 +95,7 @@ class EvictionPolicy(NamedTuple):
         # wait at least # seconds before trigger next round of eviction, if last finished eviction is sufficient
         60
     )
+    meta_header_lens: Optional[List[int]] = None  # metaheader length for each table

     def validate(self) -> None:
         assert self.eviction_trigger_mode in [0, 1, 2, 3], (
@@ -171,15 +172,20 @@ class KVZCHParams(NamedTuple):
     bucket_sizes: List[int] = []
     # enable optimizer offloading or not
     enable_optimizer_offloading: bool = False
-    eviction_policy: Optional[EvictionPolicy] = None
+    # when enabled, backend will return whole row(metaheader + weight + optimizer) instead of weight only
+    # can only be enabled when enable_optimizer_offloading is enabled
+    backend_return_whole_row: bool = False
+    eviction_policy: EvictionPolicy = EvictionPolicy()

     def validate(self) -> None:
         assert len(self.bucket_offsets) == len(self.bucket_sizes), (
             "bucket_offsets and bucket_sizes must have the same length, "
             f"actual {self.bucket_offsets} vs {self.bucket_sizes}"
         )
-        if self.eviction_policy is not None:
-            self.eviction_policy.validate()
+        self.eviction_policy.validate()
+        assert (
+            not self.backend_return_whole_row or self.enable_optimizer_offloading
+        ), "backend_return_whole_row can only be enabled when enable_optimizer_offloading is enabled"


 class BackendType(enum.IntEnum):
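A minimal usage sketch of the new validation rule (hypothetical field values; assumes the remaining KVZCHParams defaults, including the default EvictionPolicy, pass validate()):

```python
from fbgemm_gpu.split_table_batched_embeddings_ops_common import KVZCHParams

# OK: whole-row mode together with optimizer offloading
params = KVZCHParams(
    bucket_offsets=[(0, 1)],
    bucket_sizes=[1024],
    enable_optimizer_offloading=True,
    backend_return_whole_row=True,
)
params.validate()

# Rejected: whole-row mode without optimizer offloading
bad = KVZCHParams(
    bucket_offsets=[(0, 1)],
    bucket_sizes=[1024],
    enable_optimizer_offloading=False,
    backend_return_whole_row=True,
)
try:
    bad.validate()
except AssertionError as e:
    print(e)  # backend_return_whole_row can only be enabled when enable_optimizer_offloading is enabled
```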

fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

Lines changed: 122 additions & 26 deletions
@@ -190,15 +190,27 @@ def __init__(
         self.kv_zch_params = kv_zch_params
         self.backend_type = backend_type
         self.enable_optimizer_offloading: bool = False
+        self.backend_return_whole_row: bool = False
         if self.kv_zch_params:
             self.kv_zch_params.validate()
             self.enable_optimizer_offloading = (
                 # pyre-ignore [16]
                 self.kv_zch_params.enable_optimizer_offloading
             )
+            self.backend_return_whole_row = (
+                # pyre-ignore [16]
+                self.kv_zch_params.backend_return_whole_row
+            )

             if self.enable_optimizer_offloading:
                 logging.info("Optimizer state offloading is enabled")
+            if self.backend_return_whole_row:
+                assert (
+                    self.backend_type == BackendType.DRAM
+                ), f"Only DRAM backend supports backend_return_whole_row, but got {self.backend_type}"
+                logging.info(
+                    "Backend will return whole row including metaheader, weight and optimizer for checkpoint"
+                )

         self.pooling_mode = pooling_mode
         self.bounds_check_mode_int: int = bounds_check_mode.value
@@ -625,13 +637,14 @@ def __init__(
             logging.info(
                 f"Logging DRAM offloading setup, tbe_unique_id:{tbe_unique_id}, l2_cache_size:{l2_cache_size}GB,"
                 f"num_shards={ssd_rocksdb_shards},num_threads={ssd_rocksdb_shards},"
-                f"max_D={self.max_D}"
+                f"max_D={self.max_D},"
                 f"uniform_init_lower={ssd_uniform_init_lower},uniform_init_upper={ssd_uniform_init_upper},"
                 f"row_storage_bitwidth={weights_precision.bit_rate()},"
                 f"self.cache_row_dim={self.cache_row_dim},"
                 f"enable_optimizer_offloading={self.enable_optimizer_offloading},"
                 f"feature_dims={self.feature_dims},"
-                f"hash_size_cumsum={self.hash_size_cumsum}"
+                f"hash_size_cumsum={self.hash_size_cumsum},"
+                f"backend_return_whole_row={self.backend_return_whole_row}"
             )
             table_dims = (
                 tensor_pad4(self.table_dims)
@@ -672,6 +685,7 @@ def __init__(
                     if self.enable_optimizer_offloading
                     else None
                 ),  # hash_size_cumsum
+                self.backend_return_whole_row,  # backend_return_whole_row
             )
         else:
             raise AssertionError(f"Invalid backend type {self.backend_type}")
@@ -2282,16 +2296,19 @@ def split_optimizer_states(
             # pyre-ignore
             bucket_size = self.kv_zch_params.bucket_sizes[t]
             row_offset = table_offset
-            if sorted_id_tensor is None or sorted_id_tensor[t].numel() == 0:
+            if not self.backend_return_whole_row and (
+                sorted_id_tensor is None or sorted_id_tensor[t].numel() == 0
+            ):
                 opt_list.append(
                     torch.empty(0, dtype=self.optimizer.dtype(), device="cpu")
                     # empty optimizer state for module initialization
+                    # which will NOT be used for cp loading
                 )
             else:
                 if not self.enable_optimizer_offloading:
                     # convert global id back to local id, then linearize with table offset
                     local_id_tensor = (
-                        sorted_id_tensor[t]
+                        sorted_id_tensor[t]  # pyre-ignore[16]
                         - bucket_id_start * bucket_size
                         + table_offset
                     )
@@ -2300,27 +2317,79 @@ def split_optimizer_states(
                     )
                 else:
                     row_offset = table_offset - (bucket_id_start * bucket_size)
-                    # using KVTensorWrapper to query backend to avoid OOM memory, since
-                    # backend will return both weight and optimizer in one tensor, read the whole tensor
-                    # out could OOM CPU memory.
-                    tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
-                        shape=[emb_height, optimizer_dim],
-                        dtype=dtype,
-                        row_offset=row_offset,
-                        snapshot_handle=snapshot_handle,
-                        sorted_indices=sorted_id_tensor[t],
-                        width_offset=pad4(emb_dim),
-                    )
-                    (
-                        tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
-                        if self.backend_type == BackendType.SSD
-                        else tensor_wrapper.set_dram_db_wrapper(self.ssd_db)
-                    )
-                    opt_list.append(
-                        self.get_offloaded_optimizer_states(
-                            tensor_wrapper, sorted_id_tensor[t].numel()
+                    if self.backend_return_whole_row:
+                        # When backend returns whole row, the optimizer will be returned as PMT directly
+                        if (
+                            sorted_id_tensor[t].size(0) == 0
+                            and self.local_weight_counts[t] > 0
+                        ):
+                            logging.info(
+                                f"before opt PMT loading, resetting id tensor with {self.local_weight_counts[t]}"
+                            )
+                            # pyre-ignore [16]
+                            sorted_id_tensor[t] = torch.zeros(
+                                (self.local_weight_counts[t], 1),
+                                device=torch.device("cpu"),
+                                dtype=torch.int64,
+                            )
+
+                        metaheader_dim = (
+                            # pyre-ignore[16]
+                            self.kv_zch_params.eviction_policy.meta_header_lens[t]
+                        )
+                        tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
+                            shape=[
+                                (
+                                    sorted_id_tensor[t].size(0)
+                                    if sorted_id_tensor is not None
+                                    and sorted_id_tensor[t].size(0) > 0
+                                    else emb_height
+                                ),
+                                optimizer_dim,
+                            ],
+                            dtype=dtype,
+                            row_offset=row_offset,
+                            snapshot_handle=snapshot_handle,
+                            sorted_indices=sorted_id_tensor[t],
+                            width_offset=(
+                                metaheader_dim  # metaheader is already padded so no need for pad4
+                                + pad4(emb_dim)
+                            ),
+                            read_only=True,  # optimizer written to DB with weights, so skip write here
+                        )
+                        (
+                            tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
+                            if self.backend_type == BackendType.SSD
+                            else tensor_wrapper.set_dram_db_wrapper(self.ssd_db)
+                        )
+                        opt_list.append(
+                            PartiallyMaterializedTensor(
+                                tensor_wrapper,
+                                True if self.kv_zch_params else False,
+                            )
+                        )
+                    else:
+                        # using KVTensorWrapper to query backend to avoid OOM memory, since
+                        # backend will return both weight and optimizer in one tensor, read the whole tensor
+                        # out could OOM CPU memory.
+                        tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
+                            shape=[emb_height, optimizer_dim],
+                            dtype=dtype,
+                            row_offset=row_offset,
+                            snapshot_handle=snapshot_handle,
+                            sorted_indices=sorted_id_tensor[t],
+                            width_offset=pad4(emb_dim),
+                        )
+                        (
+                            tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
+                            if self.backend_type == BackendType.SSD
+                            else tensor_wrapper.set_dram_db_wrapper(self.ssd_db)
+                        )
+                        opt_list.append(
+                            self.get_offloaded_optimizer_states(
+                                tensor_wrapper, sorted_id_tensor[t].numel()
+                            )
                         )
-                    )
             table_offset += emb_height
         logging.info(
             f"KV ZCH tables split_optimizer_states query latency: {(time.time() - start_time) * 1000} ms, "
@@ -2515,10 +2584,15 @@ def split_embedding_weights(
             bucket_ascending_id_tensor = None
             bucket_t = None
             row_offset = table_offset
+            metaheader_dim = 0
             if self.kv_zch_params:
                 bucket_id_start, bucket_id_end = self.kv_zch_params.bucket_offsets[i]
                 # pyre-ignore
                 bucket_size = self.kv_zch_params.bucket_sizes[i]
+                metaheader_dim = (
+                    # pyre-ignore[16]
+                    self.kv_zch_params.eviction_policy.meta_header_lens[i]
+                )

                 # linearize with table offset
                 table_input_id_start = table_offset
@@ -2548,7 +2622,7 @@ def split_embedding_weights(
                     and self.local_weight_counts[i] > 0
                 ):
                     logging.info(
-                        f"resetting bucket id tensor with {self.local_weight_counts[i]}"
+                        f"before weight PMT loading, resetting id tensor with {self.local_weight_counts[i]}"
                     )
                     bucket_ascending_id_tensor = torch.zeros(
                         (self.local_weight_counts[i], 1),
@@ -2574,7 +2648,19 @@ def split_embedding_weights(
                         if bucket_ascending_id_tensor is not None
                         else emb_height
                     ),
-                    emb_dim,
+                    (
+                        (
+                            metaheader_dim  # metaheader is already padded
+                            + pad4(emb_dim)
+                            + pad4(
+                                self.optimizer.state_size_dim(
+                                    self.weights_precision.as_dtype()
+                                )
+                            )
+                        )
+                        if self.backend_return_whole_row
+                        else emb_dim
+                    ),
                 ],
                 dtype=dtype,
                 row_offset=row_offset,
@@ -2611,6 +2697,11 @@ def split_embedding_weights(

     @torch.jit.ignore
     def apply_state_dict(self) -> None:
+        if self.backend_return_whole_row:
+            logging.info(
+                "backend_return_whole_row is enabled, no need to apply_state_dict"
+            )
+            return
         # After checkpoint loading, the _cached_kvzch_data will be loaded from checkpoint.
         # Caller should call this function to apply the cached states to backend.
         if self.load_state_dict is False:
@@ -2729,6 +2820,11 @@ def streaming_write_weight_and_id_per_table(

     @torch.jit.ignore
     def enable_load_state_dict_mode(self) -> None:
+        if self.backend_return_whole_row:
+            logging.info(
+                "backend_return_whole_row is enabled, no need to enable load_state_dict mode"
+            )
+            return
         # Enable load state dict mode before loading checkpoint
         if self.load_state_dict:
             return

fbgemm_gpu/test/tbe/ssd/kv_backend_test.py

Lines changed: 87 additions & 0 deletions
@@ -784,3 +784,90 @@ def test_dram_kv_eviction(self) -> None:
         self.assertTrue(all(processed_counts >= shard_load))
         self.assertTrue(all(full_duration_ms > 0))
         self.assertTrue(all(exec_duration_ms >= 0))
+
+    @given(
+        T=st.integers(min_value=2, max_value=10),
+        D=st.integers(min_value=2, max_value=128),
+        log_E=st.integers(min_value=2, max_value=3),
+        weights_precision=st.sampled_from([SparseType.FP32, SparseType.FP16]),
+        enable_l2=st.sampled_from([True, False]),
+    )
+    @settings(**default_settings)
+    def test_dram_enable_backend_return_whole_row(
+        self,
+        T: int,
+        D: int,
+        log_E: int,
+        weights_precision: SparseType,
+        enable_l2: bool,
+    ) -> None:
+        kv_zch_params = KVZCHParams(
+            enable_optimizer_offloading=True,
+            backend_return_whole_row=True,  # whole row will be returned to KVT
+        )
+        metaheader_dim: int = 16 // (weights_precision.bit_rate() // 8)
+        opt_dim: int = 4 // (weights_precision.bit_rate() // 8)
+        emb, Es, Ds = self.generate_fbgemm_kv_tbe(
+            T,
+            D,
+            log_E,
+            weights_precision,
+            mixed=True,
+            enable_l2=enable_l2,
+            kv_zch_params=kv_zch_params,
+            backend_type=BackendType.DRAM,
+        )
+        dtype = weights_precision.as_dtype()
+        row_offset = 0
+        max_D = max(Ds)
+        N = 2
+
+        for E, D in zip(Es, Ds):
+            # create random index tensor with size N, valued from [0, N-1] unordered
+            indices = torch.randperm(N)
+            # insert the weights with the corresponding indices into the table
+            # which will also populate the metaheader with weight_id at front
+            weights = torch.arange(N * D, dtype=dtype).view(N, D)
+            padded_weights = torch.nn.functional.pad(weights, (0, max_D - D))
+            # emb.ssd_db.set_kv_to_storage(indices + row_offset, padded_weights)
+            tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
+                shape=[E, D],  # only write D from weights
+                dtype=dtype,
+                row_offset=row_offset,
+                snapshot_handle=None,
+            )
+            tensor_wrapper.set_dram_db_wrapper(emb.ssd_db)
+            tensor_wrapper.set_weights_and_ids(padded_weights, indices)
+
+            # reset KVT's shape to full dim to get whole row
+            tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
+                shape=[E, metaheader_dim + pad4(D) + pad4(opt_dim)],
+                dtype=dtype,
+                row_offset=row_offset,
+                snapshot_handle=None,
+            )
+            tensor_wrapper.set_dram_db_wrapper(emb.ssd_db)
+
+            # Call narrow which should fetch the whole row
+            narrowed = tensor_wrapper.narrow(0, 0, N)
+            opt_offset = metaheader_dim + pad4(D)
+
+            for i in range(N):
+                # Check if the id matches
+                torch.testing.assert_close(
+                    narrowed[i, : metaheader_dim // 2].view(torch.int64),
+                    torch.tensor([i + row_offset], dtype=torch.int64),
+                )
+
+                # Check if weight matches the one passed in with weights
+                torch.testing.assert_close(
+                    narrowed[i, metaheader_dim:opt_offset],
+                    weights[indices.tolist().index(i)],
+                )
+
+            # The trailing opt part should all be init'ed with 0s
+            torch.testing.assert_close(
+                narrowed[:, opt_offset : opt_offset + opt_dim],
+                torch.zeros(N, opt_dim, dtype=dtype),
+            )
+            row_offset += E