You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Support get/set the whole row of metaheader+weight+optimizer from backend for checkpoint saving/loading (pytorch#3153)
Summary:
X-link: facebookresearch/FBGEMM#1500
X-link: pytorch/FBGEMM#4435
# Context
In our current KVZCH cp loading flow, we will keep hold of weight_id, weight, optimizer tensors throughout the checkpoint loading lifecycle, and at the end when all these tensors are downloaded in hand, we will explicitly call "apply_state_dict" to actually write them by chunk to the backend to ensure id->weight and id->opt are mapped correctly. The problem is when we have large number of weights, we will be short of memory since we need to hold all 3 tensors (double memory issue). To solve this challenge, we are going to save the whole row of (metaheader + weight + opt) as the same "weight" tensor during checkpoint saving, and when downloading the checkpoint, we will be able to extract the id from the header, and directly write the weight+opt part to the backend by id. When loading cp for optimizer, we added a no-op KVTensor, so it won't need to write to backend for optimizer states again.
# This diff only contains frontend changes
* added `backend_return_whole_row` flag in KVZCH params, with validation to make sure it's only True when opt_offloading is used
* added `read_only_` flag in KVTensorWrapper to be used for checkpoint calls. When read-only=True, all write operations to this KVT will be no-op
* added metadata recalc for optimizer state dict, because we are now returning read-only KVT for opt state dict, and model store will need to correct the global metadata before creating the save plan for KVZCH opt tensors
* by default the opt offloading and return whole row is False on trunk, so should not break existing KVZCH runs
Reviewed By: emlin
Differential Revision:
D77666892
Privacy Context Container: L1138451
0 commit comments