Skip to content

Commit cba04eb

Browse files
Jianbo Liu authored and facebook-github-bot committed
Support get/set the whole row of metaheader+weight+optimizer from backend for checkpoint saving/loading (#3153)
Summary: X-link: facebookresearch/FBGEMM#1500 Pull Request resolved: #3153 X-link: pytorch/FBGEMM#4435 # Context In our current KVZCH cp loading flow, we will keep hold of the weight_id, weight, and optimizer tensors throughout the checkpoint loading lifecycle, and at the end, when all these tensors are downloaded and in hand, we will explicitly call "apply_state_dict" to actually write them by chunk to the backend to ensure id->weight and id->opt are mapped correctly. The problem is that when we have a large number of weights, we will be short of memory, since we need to hold all 3 tensors (double memory issue). To solve this challenge, we are going to save the whole row of (metaheader + weight + opt) as the same "weight" tensor during checkpoint saving, and when downloading the checkpoint, we will be able to extract the id from the header and directly write the weight+opt part to the backend by id. When loading the checkpoint for the optimizer, we added a no-op KVTensor, so it won't need to write the optimizer states to the backend again. # This diff only contains frontend changes * added `backend_return_whole_row` flag in KVZCH params, with validation to make sure it's only True when opt_offloading is used * added `read_only_` flag in KVTensorWrapper to be used for checkpoint calls. When read_only=True, all write operations to this KVT will be no-ops * added metadata recalculation for the optimizer state dict, because we are now returning a read-only KVT for the opt state dict, and the model store will need to correct the global metadata before creating the save plan for KVZCH opt tensors * by default, opt offloading and return-whole-row are False on trunk, so this should not break existing KVZCH runs Reviewed By: emlin Differential Revision: D77666892 Privacy Context Container: L1138451 fbshipit-source-id: b0ca5f0f880ede1a803f77d0d520abb3356a0c8d
1 parent d599766 commit cba04eb

File tree

2 files changed

+30
-9
lines changed

2 files changed

+30
-9
lines changed

torchrec/distributed/batched_embedding_kernel.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,7 @@ def _populate_zero_collision_tbe_params(
209209
tbe_params: Dict[str, Any],
210210
sharded_local_buckets: List[Tuple[int, int, int]],
211211
config: GroupedEmbeddingConfig,
212+
backend_type: BackendType,
212213
) -> None:
213214
"""
214215
Construct Zero Collision TBE params from config and fused params dict.
@@ -220,11 +221,15 @@ def _populate_zero_collision_tbe_params(
220221
bucket_sizes: List[int] = [size for _, _, size in sharded_local_buckets]
221222

222223
enabled = False
223-
for table in config.embedding_tables:
224-
if table.virtual_table_eviction_policy is not None and not isinstance(
225-
table.virtual_table_eviction_policy, NoEvictionPolicy
226-
):
227-
enabled = True
224+
meta_header_lens = [0] * len(config.embedding_tables)
225+
for i, table in enumerate(config.embedding_tables):
226+
# virtual_table_eviction_policy won't be None in reality: https://fburl.com/code/864a0w0f
227+
if table.virtual_table_eviction_policy is not None:
228+
meta_header_lens[i] = (
229+
table.virtual_table_eviction_policy.get_meta_header_len()
230+
)
231+
if not isinstance(table.virtual_table_eviction_policy, NoEvictionPolicy):
232+
enabled = True
228233
if enabled:
229234
counter_thresholds = [0] * len(config.embedding_tables)
230235
ttls_in_mins = [0] * len(config.embedding_tables)
@@ -283,14 +288,16 @@ def _populate_zero_collision_tbe_params(
283288
ttls_in_mins=ttls_in_mins,
284289
counter_decay_rates=counter_decay_rates,
285290
l2_weight_thresholds=l2_weight_thresholds,
291+
meta_header_lens=meta_header_lens,
286292
)
287293
else:
288-
eviction_policy = None
294+
eviction_policy = EvictionPolicy(meta_header_lens=meta_header_lens)
289295

290296
tbe_params["kv_zch_params"] = KVZCHParams(
291297
bucket_offsets=bucket_offsets,
292298
bucket_sizes=bucket_sizes,
293299
enable_optimizer_offloading=True,
300+
backend_return_whole_row=(backend_type == BackendType.DRAM),
294301
eviction_policy=eviction_policy,
295302
)
296303

@@ -1395,7 +1402,9 @@ def __init__(
13951402
self._config.embedding_tables, self._pg
13961403
)
13971404
)
1398-
_populate_zero_collision_tbe_params(ssd_tbe_params, self._bucket_spec, config)
1405+
_populate_zero_collision_tbe_params(
1406+
ssd_tbe_params, self._bucket_spec, config, backend_type
1407+
)
13991408
compute_kernel = config.embedding_tables[0].compute_kernel
14001409
embedding_location = compute_kernel_to_embedding_location(compute_kernel)
14011410

@@ -2201,7 +2210,9 @@ def __init__(
22012210
self._config.embedding_tables, self._pg
22022211
)
22032212
)
2204-
_populate_zero_collision_tbe_params(ssd_tbe_params, self._bucket_spec, config)
2213+
_populate_zero_collision_tbe_params(
2214+
ssd_tbe_params, self._bucket_spec, config, backend_type
2215+
)
22052216
compute_kernel = config.embedding_tables[0].compute_kernel
22062217
embedding_location = compute_kernel_to_embedding_location(compute_kernel)
22072218

torchrec/distributed/embedding.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -956,7 +956,17 @@ def _initialize_torch_state(self) -> None: # noqa
956956
(
957957
[
958958
# assuming virtual table only supports rw sharding for now
959-
0 if dim == 0 else dim_size
959+
# When backend return whole row, need to respect dim(1)
960+
# otherwise will see shard dim exceeded tensor dim error
961+
(
962+
0
963+
if dim == 0
964+
else (
965+
local_shards[0].metadata.shard_sizes[1]
966+
if dim == 1
967+
else dim_size
968+
)
969+
)
960970
for dim, dim_size in enumerate(
961971
self._name_to_table_size[table_name]
962972
)

0 commit comments

Comments
 (0)