Commit af40f39

Jianbo Liu authored and facebook-github-bot committed
Support get/set the whole row of metaheader+weight+optimizer from backend for checkpoint saving/loading (pytorch#3153)
Summary:

X-link: facebookresearch/FBGEMM#1500
X-link: pytorch/FBGEMM#4435

# Context

In the current KVZCH checkpoint loading flow, we hold the weight_id, weight, and optimizer tensors throughout the checkpoint loading lifecycle. Only at the end, once all of these tensors have been downloaded, do we explicitly call `apply_state_dict` to write them chunk by chunk to the backend, ensuring the id->weight and id->opt mappings are correct. The problem is that with a large number of weights we run short of memory, since we must hold all three tensors at once (a double-memory issue).

To solve this, we save the whole row of (metaheader + weight + opt) as a single "weight" tensor during checkpoint saving. When downloading the checkpoint, we can then extract the id from the metaheader and write the weight+opt portion directly to the backend by id. When loading a checkpoint for the optimizer, we added a no-op KVTensor, so optimizer states do not need to be written to the backend a second time.

# This diff only contains frontend changes

* Added a `backend_return_whole_row` flag in KVZCH params, with validation to ensure it can only be True when optimizer offloading is used.
* Added a `read_only_` flag in KVTensorWrapper to be used for checkpoint calls. When read-only is True, all write operations on this KVT are no-ops.
* Added metadata recalculation for the optimizer state dict: since we now return a read-only KVT for the opt state dict, the model store needs to correct the global metadata before creating the save plan for KVZCH opt tensors.
* Optimizer offloading and return-whole-row both default to False on trunk, so existing KVZCH runs should not break.

Reviewed By: emlin

Differential Revision: D77666892

Privacy Context Container: L1138451
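The whole-row scheme described above can be illustrated with a small, self-contained sketch. This is not the actual FBGEMM row format: the metaheader layout (`[id, flags]`), the dimensions, and the helper names (`pack_row`, `unpack_row`, `apply_packed_rows`, `backend`) are all hypothetical, chosen only to show how an id embedded in the header lets the loader write weight+opt back to the backend row by row, without holding three separate tensors.

```python
# Hypothetical sketch of a "whole row" checkpoint layout:
# each saved row is [metaheader | weight | opt], and the metaheader
# carries the row id so the loader can key writes by that id.

METAHEADER_DIM = 2  # assumed header: [id, flags]
WEIGHT_DIM = 4      # illustrative embedding dim
OPT_DIM = 1         # e.g. one momentum scalar per row

def pack_row(row_id, weight, opt):
    """Concatenate metaheader + weight + opt into one flat row."""
    header = [float(row_id), 0.0]
    return header + list(weight) + list(opt)

def unpack_row(row):
    """Recover (id, weight, opt) from a packed row."""
    row_id = int(row[0])
    weight = row[METAHEADER_DIM:METAHEADER_DIM + WEIGHT_DIM]
    opt = row[METAHEADER_DIM + WEIGHT_DIM:]
    return row_id, weight, opt

def apply_packed_rows(rows, backend):
    """Write each row's weight+opt to the backend keyed by its header id."""
    for row in rows:
        row_id, weight, opt = unpack_row(row)
        backend[row_id] = (weight, opt)
```

Because the id travels inside each row, rows can be streamed to the backend as they arrive, rather than buffering weight_id, weight, and optimizer tensors separately until the end.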
1 parent 5a73ae5 commit af40f39

File tree

1 file changed: +11 -1 lines changed


torchrec/distributed/embedding.py

Lines changed: 11 additions & 1 deletion
```diff
@@ -954,7 +954,17 @@ def _initialize_torch_state(self) -> None:  # noqa
         (
             [
                 # assuming virtual table only supports rw sharding for now
-                0 if dim == 0 else dim_size
+                # When backend return whole row, need to respect dim(1)
+                # otherwise will see shard dim exceeded tensor dim error
+                (
+                    0
+                    if dim == 0
+                    else (
+                        local_shards[0].metadata.shard_sizes[1]
+                        if dim == 1
+                        else dim_size
+                    )
+                )
                 for dim, dim_size in enumerate(
                     self._name_to_table_size[table_name]
                 )
```
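The per-dimension logic in the diff can be shown standalone. This is a sketch, not the torchrec code: `virtual_table_dims`, `shard_width`, and the explicit `backend_return_whole_row` parameter are illustrative names (in the actual diff, `local_shards[0].metadata.shard_sizes[1]` plays the role of `shard_width`, and the whole-row condition is implicit in which code path runs).

```python
# Sketch of the per-dimension size computation from the diff above.
# For a row-wise (rw) sharded virtual table, dim 0 (rows) is reported
# as 0. When the backend returns the whole row, dim 1 must come from
# the local shard's actual width (metaheader + weight + opt), not the
# table's configured embedding dim, or shard metadata would exceed
# the tensor dim.

def virtual_table_dims(table_size, shard_width, backend_return_whole_row):
    """Compute the reported tensor dims for a virtual (KV) table."""
    dims = []
    for dim, dim_size in enumerate(table_size):
        if dim == 0:
            dims.append(0)  # virtual tables report 0 rows
        elif dim == 1 and backend_return_whole_row:
            dims.append(shard_width)  # respect the widened whole-row shard
        else:
            dims.append(dim_size)
    return dims
```

For example, a table configured as 1000 x 128 whose local whole-row shards are 160 wide yields `[0, 160]` with whole-row enabled and `[0, 128]` without it.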
