
Commit e6ed891

Jianbo Liu authored and facebook-github-bot committed
Support get/set the whole row of metaheader+weight+optimizer from backend for checkpoint saving/loading (#4435)
Summary:

# Context

In our current KVZCH checkpoint loading flow, we keep hold of the weight_id, weight, and optimizer tensors throughout the checkpoint loading lifecycle, and only at the end, once all of these tensors have been downloaded, do we explicitly call "apply_state_dict" to write them to the backend chunk by chunk, ensuring that id->weight and id->opt are mapped correctly. The problem is that with a large number of weights we run short of memory, since we must hold all three tensors at once (a double-memory issue).

To solve this, during checkpoint saving we save the whole row (metaheader + weight + opt) as a single "weight" tensor. When downloading the checkpoint, we can then extract the id from the metaheader and write the weight+opt part directly to the backend by id. When loading the checkpoint for the optimizer, we add a no-op KVTensor, so the optimizer states do not need to be written to the backend a second time.

# This diff only contains frontend changes

* added a `backend_return_whole_row` flag in KVZCH params, with validation to make sure it's only True when optimizer offloading is used
* added a `read_only_` flag in KVTensorWrapper to be used for checkpoint calls; when read_only=True, all write operations to this KVT are no-ops
* added metadata recalculation for the optimizer state dict, because we now return a read-only KVT for the opt state dict, and the model store needs to correct the global metadata before creating the save plan for KVZCH opt tensors
* by default, optimizer offloading and return-whole-row are False on trunk, so this should not break existing KVZCH runs

Differential Revision: D77666892
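For illustration only (not code from this diff): a minimal sketch of how a "whole row" returned by the backend could be split back into id, weight, and optimizer state on the frontend. The helper name and arguments are hypothetical; the layout and the id-in-the-first-8-bytes assumption mirror the kv_backend_test.py test added below.

import torch

def split_whole_row(row: torch.Tensor, metaheader_dim: int, emb_dim_padded: int):
    # Assumed layout: [metaheader (16 bytes, int64 id in the first 8) | padded weight | padded opt state]
    weight_id = row[: metaheader_dim // 2].view(torch.int64).item()
    weight = row[metaheader_dim : metaheader_dim + emb_dim_padded]
    opt_state = row[metaheader_dim + emb_dim_padded :]
    return weight_id, weight, opt_state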
1 parent 3daa067 commit e6ed891

File tree

5 files changed: +546 −29 lines


fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py

Lines changed: 6 additions & 0 deletions
@@ -106,13 +106,19 @@ class KVZCHParams(NamedTuple):
     bucket_sizes: List[int] = []
     # enable optimizer offloading or not
     enable_optimizer_offloading: bool = False
+    # when enabled, backend will return whole row(metaheader + weight + optimizer) instead of weight only
+    # can only be enabled when enable_optimizer_offloading is enabled
+    backend_return_whole_row: bool = False
     eviction_policy: Optional[EvictionPolicy] = None

     def validate(self) -> None:
         assert len(self.bucket_offsets) == len(self.bucket_sizes), (
             "bucket_offsets and bucket_sizes must have the same length, "
             f"actual {self.bucket_offsets} vs {self.bucket_sizes}"
         )
+        assert (
+            not self.backend_return_whole_row or self.enable_optimizer_offloading
+        ), "backend_return_whole_row can only be enabled when enable_optimizer_offloading is enabled"


 class BackendType(enum.IntEnum):
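For reference, a minimal usage sketch of the new flag and its validation. This assumes the public import path matches the file path above and that the remaining KVZCHParams fields keep their defaults; it is not part of this diff.

from fbgemm_gpu.split_table_batched_embeddings_ops_common import KVZCHParams

# Valid: whole-row return requires optimizer offloading as well.
params = KVZCHParams(
    enable_optimizer_offloading=True,
    backend_return_whole_row=True,
)
params.validate()

# Invalid: backend_return_whole_row without enable_optimizer_offloading
# would fail the new assertion in validate().
bad = KVZCHParams(backend_return_whole_row=True)
# bad.validate()  # raises AssertionError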

fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

Lines changed: 115 additions & 26 deletions
@@ -187,15 +187,30 @@ def __init__(
         self.kv_zch_params = kv_zch_params
         self.backend_type = backend_type
         self.enable_optimizer_offloading: bool = False
+        self.backend_return_whole_row: bool = False
         if self.kv_zch_params:
             self.kv_zch_params.validate()
             self.enable_optimizer_offloading = (
                 # pyre-ignore [16]
                 self.kv_zch_params.enable_optimizer_offloading
             )
+            self.backend_return_whole_row = (
+                # pyre-ignore [16]
+                self.kv_zch_params.backend_return_whole_row
+            )

             if self.enable_optimizer_offloading:
                 logging.info("Optimizer state offloading is enabled")
+            if self.backend_return_whole_row:
+                assert (
+                    self.backend_type == BackendType.DRAM
+                ), f"Only DRAM backend supports backend_return_whole_row, but got {self.backend_type}"
+                logging.info(
+                    "Backend will return whole row including metaheader, weight and optimizer for checkpoint"
+                )
+
+        # TODO: the metaheader is 16 bytes fixed.
+        self.metaheader_dim: int = 16 // (weights_precision.bit_rate() // 8)

         self.pooling_mode = pooling_mode
         self.bounds_check_mode_int: int = bounds_check_mode.value
@@ -612,13 +627,14 @@ def __init__(
             logging.info(
                 f"Logging DRAM offloading setup, tbe_unique_id:{tbe_unique_id}, l2_cache_size:{l2_cache_size}GB,"
                 f"num_shards={ssd_rocksdb_shards},num_threads={ssd_rocksdb_shards},"
-                f"max_D={self.max_D}"
+                f"max_D={self.max_D},"
                 f"uniform_init_lower={ssd_uniform_init_lower},uniform_init_upper={ssd_uniform_init_upper},"
                 f"row_storage_bitwidth={weights_precision.bit_rate()},"
                 f"self.cache_row_dim={self.cache_row_dim},"
                 f"enable_optimizer_offloading={self.enable_optimizer_offloading},"
                 f"feature_dims={self.feature_dims},"
-                f"hash_size_cumsum={self.hash_size_cumsum}"
+                f"hash_size_cumsum={self.hash_size_cumsum},"
+                f"backend_return_whole_row={self.backend_return_whole_row}"
             )
             table_dims = (
                 tensor_pad4(self.table_dims)
@@ -659,6 +675,7 @@ def __init__(
                     if self.enable_optimizer_offloading
                     else None
                 ), # hash_size_cumsum
+                self.backend_return_whole_row, # backend_return_whole_row
             )
         else:
             raise AssertionError(f"Invalid backend type {self.backend_type}")
@@ -2246,16 +2263,19 @@ def split_optimizer_states(
             # pyre-ignore
             bucket_size = self.kv_zch_params.bucket_sizes[t]
             row_offset = table_offset
-            if sorted_id_tensor is None or sorted_id_tensor[t].numel() == 0:
+            if not self.backend_return_whole_row and (
+                sorted_id_tensor is None or sorted_id_tensor[t].numel() == 0
+            ):
                 opt_list.append(
                     torch.empty(0, dtype=self.optimizer.dtype(), device="cpu")
                     # empty optimizer state for module initialization
+                    # which will NOT be used for cp loading
                 )
             else:
                 if not self.enable_optimizer_offloading:
                     # convert global id back to local id, then linearize with table offset
                     local_id_tensor = (
-                        sorted_id_tensor[t]
+                        sorted_id_tensor[t] # pyre-ignore[16]
                         - bucket_id_start * bucket_size
                         + table_offset
                     )
@@ -2264,27 +2284,74 @@ def split_optimizer_states(
                     )
                 else:
                     row_offset = table_offset - (bucket_id_start * bucket_size)
-                    # using KVTensorWrapper to query backend to avoid OOM memory, since
-                    # backend will return both weight and optimizer in one tensor, read the whole tensor
-                    # out could OOM CPU memory.
-                    tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
-                        shape=[emb_height, optimizer_dim],
-                        dtype=dtype,
-                        row_offset=row_offset,
-                        snapshot_handle=snapshot_handle,
-                        sorted_indices=sorted_id_tensor[t],
-                        width_offset=pad4(emb_dim),
-                    )
-                    (
-                        tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
-                        if self.backend_type == BackendType.SSD
-                        else tensor_wrapper.set_dram_db_wrapper(self.ssd_db)
-                    )
-                    opt_list.append(
-                        self.get_offloaded_optimizer_states(
-                            tensor_wrapper, sorted_id_tensor[t].numel()
+                    if self.backend_return_whole_row:
+                        # When backend returns whole row, the optimizer will be returned as PMT directly
+                        if (
+                            sorted_id_tensor[t].size(0) == 0
+                            and self.local_weight_counts[t] > 0
+                        ):
+                            logging.info(
+                                f"before opt PMT loading, resetting id tensor with {self.local_weight_counts[t]}"
+                            )
+                            # pyre-ignore [16]
+                            sorted_id_tensor[t] = torch.zeros(
+                                (self.local_weight_counts[t], 1),
+                                device=torch.device("cpu"),
+                                dtype=torch.int64,
+                            )
+                        tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
+                            shape=[
+                                (
+                                    sorted_id_tensor[t].size(0)
+                                    if sorted_id_tensor is not None
+                                    and sorted_id_tensor[t].size(0) > 0
+                                    else emb_height
+                                ),
+                                optimizer_dim,
+                            ],
+                            dtype=dtype,
+                            row_offset=row_offset,
+                            snapshot_handle=snapshot_handle,
+                            sorted_indices=sorted_id_tensor[t],
+                            width_offset=(
+                                self.metaheader_dim # metaheader is already padded so no need for pad4
+                                + pad4(emb_dim)
+                            ),
+                            read_only=True, # optimizer written to DB with weights, so skip write here
+                        )
+                        (
+                            tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
+                            if self.backend_type == BackendType.SSD
+                            else tensor_wrapper.set_dram_db_wrapper(self.ssd_db)
+                        )
+                        opt_list.append(
+                            PartiallyMaterializedTensor(
+                                tensor_wrapper,
+                                True if self.kv_zch_params else False,
+                            )
+                        )
+                    else:
+                        # using KVTensorWrapper to query backend to avoid OOM memory, since
+                        # backend will return both weight and optimizer in one tensor, read the whole tensor
+                        # out could OOM CPU memory.
+                        tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
+                            shape=[emb_height, optimizer_dim],
+                            dtype=dtype,
+                            row_offset=row_offset,
+                            snapshot_handle=snapshot_handle,
+                            sorted_indices=sorted_id_tensor[t],
+                            width_offset=pad4(emb_dim),
+                        )
+                        (
+                            tensor_wrapper.set_embedding_rocks_dp_wrapper(self.ssd_db)
+                            if self.backend_type == BackendType.SSD
+                            else tensor_wrapper.set_dram_db_wrapper(self.ssd_db)
+                        )
+                        opt_list.append(
+                            self.get_offloaded_optimizer_states(
+                                tensor_wrapper, sorted_id_tensor[t].numel()
+                            )
                         )
-                    )
             table_offset += emb_height
             logging.info(
                 f"KV ZCH tables split_optimizer_states query latency: {(time.time() - start_time) * 1000} ms, "
@@ -2513,7 +2580,7 @@ def split_embedding_weights(
                 and self.local_weight_counts[i] > 0
             ):
                 logging.info(
-                    f"resetting bucket id tensor with {self.local_weight_counts[i]}"
+                    f"before weight PMT loading, resetting id tensor with {self.local_weight_counts[i]}"
                 )
                 bucket_ascending_id_tensor = torch.zeros(
                     (self.local_weight_counts[i], 1),
@@ -2539,7 +2606,19 @@ def split_embedding_weights(
                         if bucket_ascending_id_tensor is not None
                         else emb_height
                     ),
-                    emb_dim,
+                    (
+                        (
+                            self.metaheader_dim # metaheader is already padded
+                            + pad4(emb_dim)
+                            + pad4(
+                                self.optimizer.state_size_dim(
+                                    self.weights_precision.as_dtype()
+                                )
+                            )
+                        )
+                        if self.backend_return_whole_row
+                        else emb_dim
+                    ),
                 ],
                 dtype=dtype,
                 row_offset=row_offset,
@@ -2576,6 +2655,11 @@ def split_embedding_weights(

     @torch.jit.ignore
     def apply_state_dict(self) -> None:
+        if self.backend_return_whole_row:
+            logging.info(
+                "backend_return_whole_row is enabled, no need to apply_state_dict"
+            )
+            return
         # After checkpoint loading, the _cached_kvzch_data will be loaded from checkpoint.
         # Caller should call this function to apply the cached states to backend.
         if self.load_state_dict is False:
@@ -2694,6 +2778,11 @@ def streaming_write_weight_and_id_per_table(

     @torch.jit.ignore
     def enable_load_state_dict_mode(self) -> None:
+        if self.backend_return_whole_row:
+            logging.info(
+                "backend_return_whole_row is enabled, no need to enable load_state_dict mode"
+            )
+            return
         # Enable load state dict mode before loading checkpoint
         if self.load_state_dict:
             return
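For concreteness, a small worked example of the whole-row width math used above (a sketch, mirroring the diff; pad4 is assumed to round an element count up to the next multiple of 4, and the 4-byte per-row optimizer state matches the test below).

def pad4(n: int) -> int:
    return (n + 3) // 4 * 4

bytes_per_elem = 2                        # FP16
metaheader_dim = 16 // bytes_per_elem     # 8 elements (16 bytes, already a multiple of 4)
emb_dim = 66                              # example embedding dim
opt_dim = 4 // bytes_per_elem             # 2 elements for a 4-byte optimizer state

whole_row_dim = metaheader_dim + pad4(emb_dim) + pad4(opt_dim)
print(whole_row_dim)                      # 8 + 68 + 4 = 80 elements per returned row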

fbgemm_gpu/test/tbe/ssd/kv_backend_test.py

Lines changed: 87 additions & 0 deletions
@@ -784,3 +784,90 @@ def test_dram_kv_eviction(self) -> None:
         self.assertTrue(all(processed_counts >= shard_load))
         self.assertTrue(all(full_duration_ms > 0))
         self.assertTrue(all(exec_duration_ms >= 0))
+
+    @given(
+        T=st.integers(min_value=2, max_value=10),
+        D=st.integers(min_value=2, max_value=128),
+        log_E=st.integers(min_value=2, max_value=3),
+        weights_precision=st.sampled_from([SparseType.FP32, SparseType.FP16]),
+        enable_l2=st.sampled_from([True, False]),
+    )
+    @settings(**default_settings)
+    def test_dram_enable_backend_return_whole_row(
+        self,
+        T: int,
+        D: int,
+        log_E: int,
+        weights_precision: SparseType,
+        enable_l2: bool,
+    ) -> None:
+        kv_zch_params = KVZCHParams(
+            enable_optimizer_offloading=True,
+            backend_return_whole_row=True, # whole row will be returned to KVT
+        )
+        metaheader_dim: int = 16 // (weights_precision.bit_rate() // 8)
+        opt_dim: int = 4 // (weights_precision.bit_rate() // 8)
+        emb, Es, Ds = self.generate_fbgemm_kv_tbe(
+            T,
+            D,
+            log_E,
+            weights_precision,
+            mixed=True,
+            enable_l2=enable_l2,
+            kv_zch_params=kv_zch_params,
+            backend_type=BackendType.DRAM,
+        )
+        dtype = weights_precision.as_dtype()
+        row_offset = 0
+        max_D = max(Ds)
+        N = 2
+
+        for E, D in zip(Es, Ds):
+            # create random index tensor with size N, valued from [0, N-1] unordered
+            indices = torch.randperm(N)
+            # insert the weights with the corresponding indices into the table
+            # which will also populate the metaheader with weight_id at front
+            weights = torch.arange(N * D, dtype=dtype).view(N, D)
+            padded_weights = torch.nn.functional.pad(weights, (0, max_D - D))
+            # emb.ssd_db.set_kv_to_storage(indices + row_offset, padded_weights)
+            tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
+                shape=[E, D], # only write D from weights
+                dtype=dtype,
+                row_offset=row_offset,
+                snapshot_handle=None,
+            )
+            tensor_wrapper.set_dram_db_wrapper(emb.ssd_db)
+            tensor_wrapper.set_weights_and_ids(padded_weights, indices)
+
+            # reset KVT's shape to full dim to get whole row
+            tensor_wrapper = torch.classes.fbgemm.KVTensorWrapper(
+                shape=[E, metaheader_dim + pad4(D) + pad4(opt_dim)],
+                dtype=dtype,
+                row_offset=row_offset,
+                snapshot_handle=None,
+            )
+            tensor_wrapper.set_dram_db_wrapper(emb.ssd_db)
+
+            # Call narrow which should fetch the whole row
+            narrowed = tensor_wrapper.narrow(0, 0, N)
+            opt_offset = metaheader_dim + pad4(D)
+
+            for i in range(N):
+                # Check if the id matches
+                torch.testing.assert_close(
+                    narrowed[i, : metaheader_dim // 2].view(torch.int64),
+                    torch.tensor([i + row_offset], dtype=torch.int64),
+                )
+
+                # Check if weight matches the one passed in with weights
+                torch.testing.assert_close(
+                    narrowed[i, metaheader_dim:opt_offset],
+                    weights[indices.tolist().index(i)],
+                )
+
+            # The trailing opt part should all be init'ed with 0s
+            torch.testing.assert_close(
+                narrowed[:, opt_offset : opt_offset + opt_dim],
+                torch.zeros(N, opt_dim, dtype=dtype),
+            )
+            row_offset += E
