
Commit cc754fc

Jianbo Liu authored and facebook-github-bot committed
Support get/set the whole row of metaheader+weight+optimizer from backend for checkpoint saving/loading (#4435)
Summary:

X-link: facebookresearch/FBGEMM#1500
X-link: pytorch/torchrec#3153
Pull Request resolved: #4435

# Context

In our current KVZCH checkpoint-loading flow, we hold the weight_id, weight, and optimizer tensors throughout the checkpoint-loading lifecycle. Only at the end, once all of these tensors have been downloaded, do we explicitly call `apply_state_dict` to write them chunk by chunk to the backend, ensuring the id->weight and id->opt mappings are correct. The problem is that with a large number of weights we run short of memory, since we must hold all three tensors at once (a double-memory issue).

To solve this, we save the whole row (metaheader + weight + optimizer) as a single "weight" tensor during checkpoint saving. When downloading the checkpoint, we can then extract the id from the metaheader and write the weight + optimizer portion directly to the backend by id. When loading the checkpoint for the optimizer, we add a no-op KVTensor, so optimizer states need not be written to the backend a second time.

# This diff only contains frontend changes

* added a `backend_return_whole_row` flag to KVZCH params, with validation to ensure it is only True when optimizer offloading is used
* added a `read_only_` flag to KVTensorWrapper for checkpoint calls; when read-only is True, all write operations to this KVT are no-ops
* added metadata recalculation for the optimizer state dict, because we now return a read-only KVT for the opt state dict, and the model store needs to correct the global metadata before creating the save plan for KVZCH opt tensors
* by default, optimizer offloading and return-whole-row are False on trunk, so existing KVZCH runs should not break

Reviewed By: emlin

Differential Revision: D77666892

Privacy Context Container: L1138451
1 parent 5d24e24 commit cc754fc

File tree

5 files changed, +547 -29 lines changed
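To make the row layout concrete before the diffs: the sketch below is a hypothetical helper, not part of FBGEMM's API. It decomposes one whole row of `metaheader + weight + optimizer` and recovers the id from the metaheader, following the conventions used in this commit (16-byte metaheader, id in its first 8 bytes). The `pad4` semantics (rounding an element count up to a multiple of 4) are an assumption inferred from the tests below.

```python
import torch

def pad4(dim: int) -> int:
    # assumed semantics of the pad4() helper used in the diffs below
    return (dim + 3) // 4 * 4

def split_whole_row(row: torch.Tensor, emb_dim: int, opt_dim: int):
    """Hypothetical helper: decompose one backend row laid out as
    [ metaheader | pad4(weight) | pad4(optimizer) ]."""
    metaheader_dim = 16 // row.element_size()  # 16-byte metaheader, in elements
    # the first 8 bytes of the metaheader hold the row id
    row_id = int(row[: metaheader_dim // 2].view(torch.int64)[0])
    weight_end = metaheader_dim + pad4(emb_dim)
    weight = row[metaheader_dim:weight_end][:emb_dim]
    opt_state = row[weight_end : weight_end + opt_dim]
    return row_id, weight, opt_state

# FP32 example: metaheader_dim = 4, so emb_dim=6, opt_dim=1 gives 4 + 8 + 4 = 16 elements
row = torch.zeros(16, dtype=torch.float32)
row[:2] = torch.tensor([42], dtype=torch.int64).view(torch.float32)  # id = 42
rid, w, opt = split_whole_row(row, emb_dim=6, opt_dim=1)
assert rid == 42 and w.numel() == 6 and opt.numel() == 1
```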

fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py

Lines changed: 6 additions & 0 deletions
```diff
@@ -106,13 +106,19 @@ class KVZCHParams(NamedTuple):
     bucket_sizes: List[int] = []
     # enable optimizer offloading or not
     enable_optimizer_offloading: bool = False
+    # when enabled, backend will return whole row (metaheader + weight + optimizer) instead of weight only
+    # can only be enabled when enable_optimizer_offloading is enabled
+    backend_return_whole_row: bool = False
    eviction_policy: Optional[EvictionPolicy] = None

    def validate(self) -> None:
        assert len(self.bucket_offsets) == len(self.bucket_sizes), (
            "bucket_offsets and bucket_sizes must have the same length, "
            f"actual {self.bucket_offsets} vs {self.bucket_sizes}"
        )
+        assert (
+            not self.backend_return_whole_row or self.enable_optimizer_offloading
+        ), "backend_return_whole_row can only be enabled when enable_optimizer_offloading is enabled"


 class BackendType(enum.IntEnum):
```
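A quick usage sketch of the new flag and its validation (assuming the remaining KVZCHParams fields keep their defaults; the import path follows the file above):

```python
from fbgemm_gpu.split_table_batched_embeddings_ops_common import KVZCHParams

# valid: whole-row return requires optimizer offloading
params = KVZCHParams(
    enable_optimizer_offloading=True,
    backend_return_whole_row=True,
)
params.validate()

# invalid: the new assert fires if offloading is left disabled
bad = KVZCHParams(backend_return_whole_row=True)
# bad.validate() raises AssertionError:
#   backend_return_whole_row can only be enabled when enable_optimizer_offloading is enabled
```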

fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

Lines changed: 116 additions & 26 deletions
```diff
@@ -189,15 +189,31 @@ def __init__(
         self.kv_zch_params = kv_zch_params
         self.backend_type = backend_type
         self.enable_optimizer_offloading: bool = False
+        self.backend_return_whole_row: bool = False
         if self.kv_zch_params:
             self.kv_zch_params.validate()
             self.enable_optimizer_offloading = (
                 # pyre-ignore [16]
                 self.kv_zch_params.enable_optimizer_offloading
             )
+            self.backend_return_whole_row = (
+                # pyre-ignore [16]
+                self.kv_zch_params.backend_return_whole_row
+            )

             if self.enable_optimizer_offloading:
                 logging.info("Optimizer state offloading is enabled")
+            if self.backend_return_whole_row:
+                assert (
+                    self.backend_type == BackendType.DRAM
+                ), f"Only DRAM backend supports backend_return_whole_row, but got {self.backend_type}"
+                logging.info(
+                    "Backend will return whole row including metaheader, weight and optimizer for checkpoint"
+                )
+
+        # same calculation as `virtual_table_eviction_policy.get_meta_header_len()`
+        # in `torchrec/modules/embedding_configs.py`
+        self.metaheader_dim: int = 16 // (weights_precision.bit_rate() // 8)

         self.pooling_mode = pooling_mode
         self.bounds_check_mode_int: int = bounds_check_mode.value
```
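The `metaheader_dim` above counts elements, not bytes: the 16-byte metaheader divided by the element size of the weight dtype. A quick check of the arithmetic:

```python
for name, bit_rate in (("FP32", 32), ("FP16", 16)):
    elem_bytes = bit_rate // 8
    print(name, 16 // elem_bytes)  # FP32 -> 4 elements, FP16 -> 8 elements
```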
```diff
@@ -615,13 +631,14 @@ def __init__(
             logging.info(
                 f"Logging DRAM offloading setup, tbe_unique_id:{tbe_unique_id}, l2_cache_size:{l2_cache_size}GB,"
                 f"num_shards={ssd_rocksdb_shards},num_threads={ssd_rocksdb_shards},"
-                f"max_D={self.max_D}"
+                f"max_D={self.max_D},"
                 f"uniform_init_lower={ssd_uniform_init_lower},uniform_init_upper={ssd_uniform_init_upper},"
                 f"row_storage_bitwidth={weights_precision.bit_rate()},"
                 f"self.cache_row_dim={self.cache_row_dim},"
                 f"enable_optimizer_offloading={self.enable_optimizer_offloading},"
                 f"feature_dims={self.feature_dims},"
-                f"hash_size_cumsum={self.hash_size_cumsum}"
+                f"hash_size_cumsum={self.hash_size_cumsum},"
+                f"backend_return_whole_row={self.backend_return_whole_row}"
             )
             table_dims = (
                 tensor_pad4(self.table_dims)
@@ -662,6 +679,7 @@ def __init__(
                     if self.enable_optimizer_offloading
                     else None
                 ),  # hash_size_cumsum
+                self.backend_return_whole_row,  # backend_return_whole_row
             )
         else:
             raise AssertionError(f"Invalid backend type {self.backend_type}")
```
```diff
@@ -2249,16 +2267,19 @@ def split_optimizer_states(
             # pyre-ignore
             bucket_size = self.kv_zch_params.bucket_sizes[t]
             row_offset = table_offset
-            if sorted_id_tensor is None or sorted_id_tensor[t].numel() == 0:
+            if not self.backend_return_whole_row and (
+                sorted_id_tensor is None or sorted_id_tensor[t].numel() == 0
+            ):
                 opt_list.append(
                     torch.empty(0, dtype=self.optimizer.dtype(), device="cpu")
                     # empty optimizer state for module initialization
+                    # which will NOT be used for cp loading
                 )
             else:
                 if not self.enable_optimizer_offloading:
                     # convert global id back to local id, then linearize with table offset
                     local_id_tensor = (
-                        sorted_id_tensor[t]
+                        sorted_id_tensor[t]  # pyre-ignore[16]
                         - bucket_id_start * bucket_size
                         + table_offset
                     )
```
```diff
@@ -2267,27 +2288,74 @@ def split_optimizer_states(
                     )
                 else:
                     row_offset = table_offset - (bucket_id_start * bucket_size)
-                    # using KVTensorWrapper to query backend to avoid OOM memory, since
-                    # backend will return both weight and optimizer in one tensor, read the whole tensor
-                    # out could OOM CPU memory.
-                    tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
-                        shape=[emb_height, optimizer_dim],
-                        dtype=dtype,
-                        row_offset=row_offset,
-                        snapshot_handle=snapshot_handle,
-                        sorted_indices=sorted_id_tensor[t],
-                        width_offset=pad4(emb_dim),
-                    )
-                    (
-                        tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
-                        if self.backend_type == BackendType.SSD
-                        else tensor_wrapper.set_dram_db_wrapper(self.ssd_db)
-                    )
-                    opt_list.append(
-                        self.get_offloaded_optimizer_states(
-                            tensor_wrapper, sorted_id_tensor[t].numel()
+                    if self.backend_return_whole_row:
+                        # When backend returns whole row, the optimizer will be returned as PMT directly
+                        if (
+                            sorted_id_tensor[t].size(0) == 0
+                            and self.local_weight_counts[t] > 0
+                        ):
+                            logging.info(
+                                f"before opt PMT loading, resetting id tensor with {self.local_weight_counts[t]}"
+                            )
+                            # pyre-ignore [16]
+                            sorted_id_tensor[t] = torch.zeros(
+                                (self.local_weight_counts[t], 1),
+                                device=torch.device("cpu"),
+                                dtype=torch.int64,
+                            )
+                        tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
+                            shape=[
+                                (
+                                    sorted_id_tensor[t].size(0)
+                                    if sorted_id_tensor is not None
+                                    and sorted_id_tensor[t].size(0) > 0
+                                    else emb_height
+                                ),
+                                optimizer_dim,
+                            ],
+                            dtype=dtype,
+                            row_offset=row_offset,
+                            snapshot_handle=snapshot_handle,
+                            sorted_indices=sorted_id_tensor[t],
+                            width_offset=(
+                                self.metaheader_dim  # metaheader is already padded so no need for pad4
+                                + pad4(emb_dim)
+                            ),
+                            read_only=True,  # optimizer written to DB with weights, so skip write here
+                        )
+                        (
+                            tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
+                            if self.backend_type == BackendType.SSD
+                            else tensor_wrapper.set_dram_db_wrapper(self.ssd_db)
+                        )
+                        opt_list.append(
+                            PartiallyMaterializedTensor(
+                                tensor_wrapper,
+                                True if self.kv_zch_params else False,
+                            )
+                        )
+                    else:
+                        # using KVTensorWrapper to query backend to avoid OOM memory, since
+                        # backend will return both weight and optimizer in one tensor, read the whole tensor
+                        # out could OOM CPU memory.
+                        tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
+                            shape=[emb_height, optimizer_dim],
+                            dtype=dtype,
+                            row_offset=row_offset,
+                            snapshot_handle=snapshot_handle,
+                            sorted_indices=sorted_id_tensor[t],
+                            width_offset=pad4(emb_dim),
+                        )
+                        (
+                            tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
+                            if self.backend_type == BackendType.SSD
+                            else tensor_wrapper.set_dram_db_wrapper(self.ssd_db)
+                        )
+                        opt_list.append(
+                            self.get_offloaded_optimizer_states(
+                                tensor_wrapper, sorted_id_tensor[t].numel()
+                            )
                     )
-                )
             table_offset += emb_height
             logging.info(
                 f"KV ZCH tables split_optimizer_states query latency: {(time.time() - start_time) * 1000} ms, "
```
```diff
@@ -2515,7 +2583,7 @@ def split_embedding_weights(
                 and self.local_weight_counts[i] > 0
             ):
                 logging.info(
-                    f"resetting bucket id tensor with {self.local_weight_counts[i]}"
+                    f"before weight PMT loading, resetting id tensor with {self.local_weight_counts[i]}"
                 )
                 bucket_ascending_id_tensor = torch.zeros(
                     (self.local_weight_counts[i], 1),
@@ -2541,7 +2609,19 @@ def split_embedding_weights(
                         if bucket_ascending_id_tensor is not None
                         else emb_height
                     ),
-                    emb_dim,
+                    (
+                        (
+                            self.metaheader_dim  # metaheader is already padded
+                            + pad4(emb_dim)
+                            + pad4(
+                                self.optimizer.state_size_dim(
+                                    self.weights_precision.as_dtype()
+                                )
+                            )
+                        )
+                        if self.backend_return_whole_row
+                        else emb_dim
+                    ),
                 ],
                 dtype=dtype,
                 row_offset=row_offset,
```
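With `backend_return_whole_row` set, the second shape dimension above becomes the full row width rather than `emb_dim`. A worked example of that expression (again assuming `pad4` rounds up to a multiple of 4):

```python
def whole_row_dim(metaheader_dim: int, emb_dim: int, opt_dim: int) -> int:
    pad4 = lambda d: (d + 3) // 4 * 4
    return metaheader_dim + pad4(emb_dim) + pad4(opt_dim)

# FP32 table, emb_dim=127, one FP32 optimizer value per row:
# 4 (metaheader) + 128 (padded weight) + 4 (padded optimizer) = 136 elements
assert whole_row_dim(4, 127, 1) == 136
```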
```diff
@@ -2578,6 +2658,11 @@ def split_embedding_weights(

     @torch.jit.ignore
     def apply_state_dict(self) -> None:
+        if self.backend_return_whole_row:
+            logging.info(
+                "backend_return_whole_row is enabled, no need to apply_state_dict"
+            )
+            return
         # After checkpoint loading, the _cached_kvzch_data will be loaded from checkpoint.
         # Caller should call this function to apply the cached states to backend.
         if self.load_state_dict is False:
@@ -2696,6 +2781,11 @@ def streaming_write_weight_and_id_per_table(

     @torch.jit.ignore
     def enable_load_state_dict_mode(self) -> None:
+        if self.backend_return_whole_row:
+            logging.info(
+                "backend_return_whole_row is enabled, no need to enable load_state_dict mode"
+            )
+            return
         # Enable load state dict mode before loading checkpoint
         if self.load_state_dict:
             return
```
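Together, these two guards mean that in whole-row mode the usual checkpoint-loading hooks degenerate to no-ops; a hedged sketch of the calling pattern (not an exact TorchRec integration):

```python
# emb: an SSD/DRAM TBE module constructed with backend_return_whole_row=True
emb.enable_load_state_dict_mode()  # logs and returns immediately
# ... checkpoint rows (metaheader + weight + opt) are written straight to the backend ...
emb.apply_state_dict()             # logs and returns immediately; no chunked re-write needed
```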

fbgemm_gpu/test/tbe/ssd/kv_backend_test.py

Lines changed: 87 additions & 0 deletions
```diff
@@ -784,3 +784,90 @@ def test_dram_kv_eviction(self) -> None:
         self.assertTrue(all(processed_counts >= shard_load))
         self.assertTrue(all(full_duration_ms > 0))
         self.assertTrue(all(exec_duration_ms >= 0))
+
+    @given(
+        T=st.integers(min_value=2, max_value=10),
+        D=st.integers(min_value=2, max_value=128),
+        log_E=st.integers(min_value=2, max_value=3),
+        weights_precision=st.sampled_from([SparseType.FP32, SparseType.FP16]),
+        enable_l2=st.sampled_from([True, False]),
+    )
+    @settings(**default_settings)
+    def test_dram_enable_backend_return_whole_row(
+        self,
+        T: int,
+        D: int,
+        log_E: int,
+        weights_precision: SparseType,
+        enable_l2: bool,
+    ) -> None:
+        kv_zch_params = KVZCHParams(
+            enable_optimizer_offloading=True,
+            backend_return_whole_row=True,  # whole row will be returned to KVT
+        )
+        metaheader_dim: int = 16 // (weights_precision.bit_rate() // 8)
+        opt_dim: int = 4 // (weights_precision.bit_rate() // 8)
+        emb, Es, Ds = self.generate_fbgemm_kv_tbe(
+            T,
+            D,
+            log_E,
+            weights_precision,
+            mixed=True,
+            enable_l2=enable_l2,
+            kv_zch_params=kv_zch_params,
+            backend_type=BackendType.DRAM,
+        )
+        dtype = weights_precision.as_dtype()
+        row_offset = 0
+        max_D = max(Ds)
+        N = 2
+
+        for E, D in zip(Es, Ds):
+            # create a random index tensor of size N, valued from [0, N-1], unordered
+            indices = torch.randperm(N)
+            # insert the weights with the corresponding indices into the table,
+            # which will also populate the metaheader with weight_id at the front
+            weights = torch.arange(N * D, dtype=dtype).view(N, D)
+            padded_weights = torch.nn.functional.pad(weights, (0, max_D - D))
+            # emb.ssd_db.set_kv_to_storage(indices + row_offset, padded_weights)
+            tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
+                shape=[E, D],  # only write D from weights
+                dtype=dtype,
+                row_offset=row_offset,
+                snapshot_handle=None,
+            )
+            tensor_wrapper.set_dram_db_wrapper(emb.ssd_db)
+            tensor_wrapper.set_weights_and_ids(padded_weights, indices)
+
+            # reset KVT's shape to the full dim to get the whole row
+            tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
+                shape=[E, metaheader_dim + pad4(D) + pad4(opt_dim)],
+                dtype=dtype,
+                row_offset=row_offset,
+                snapshot_handle=None,
+            )
+            tensor_wrapper.set_dram_db_wrapper(emb.ssd_db)
+
+            # Call narrow, which should fetch the whole row
+            narrowed = tensor_wrapper.narrow(0, 0, N)
+            opt_offset = metaheader_dim + pad4(D)
+
+            for i in range(N):
+                # Check that the id matches
+                torch.testing.assert_close(
+                    narrowed[i, : metaheader_dim // 2].view(torch.int64),
+                    torch.tensor([i + row_offset], dtype=torch.int64),
+                )
+
+                # Check that the weight matches the one passed in
+                torch.testing.assert_close(
+                    narrowed[i, metaheader_dim:opt_offset],
+                    weights[indices.tolist().index(i)],
+                )
+
+            # The trailing opt part should all be init'ed with 0s
+            torch.testing.assert_close(
+                narrowed[:, opt_offset : opt_offset + opt_dim],
+                torch.zeros(N, opt_dim, dtype=dtype),
+            )
+            row_offset += E
```
