Support get/set the whole row of metaheader+weight+optimizer from backend for checkpoint saving/loading #4435

Open
wants to merge 1 commit into base: main

Conversation

bobbyliujb

Summary:

Context

In the current KVZCH checkpoint (cp) loading flow, we hold the weight_id, weight, and optimizer tensors for the entire checkpoint-loading lifecycle; only once all of these tensors have been downloaded do we explicitly call "apply_state_dict" to write them chunk by chunk to the backend, so that the id->weight and id->opt mappings are correct. The problem is that with a large number of weights we run out of memory, since all three tensors must be held at once (a double-memory issue). To solve this, we save the whole row of (metaheader + weight + opt) as a single "weight" tensor during checkpoint saving; when downloading the checkpoint, we can then extract the id from the metaheader and write the weight+opt portion directly to the backend by id. When loading the checkpoint for the optimizer, we return a no-op KVTensor, so the optimizer states do not need to be written to the backend a second time.
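
The row layout described above can be sketched roughly as follows. This is a hypothetical illustration, not the actual FBGEMM serialization format: the header width, dtypes, and the `write_by_id` backend call are assumptions made for the example.

```python
import torch

# Assumed layout constants for illustration only; the real metaheader format
# and widths are defined by the FBGEMM backend, not by this sketch.
METAHEADER_DIM = 4   # header slots; slot 0 holds the row id here
WEIGHT_DIM = 128     # embedding dimension
OPT_DIM = 128        # optimizer-state width (e.g. one momentum value per weight)

def pack_row(row_id: int, weight: torch.Tensor, opt: torch.Tensor) -> torch.Tensor:
    """Pack (metaheader + weight + opt) into one 'weight' row for checkpoint saving."""
    header = torch.zeros(METAHEADER_DIM, dtype=weight.dtype)
    header[0] = float(row_id)  # simplified: a real header would store the id as raw int64 bytes
    return torch.cat([header, weight, opt])

def apply_whole_row(row: torch.Tensor, backend) -> None:
    """On checkpoint load: recover the id from the header and write weight+opt by id."""
    row_id = int(row[0].item())
    weight = row[METAHEADER_DIM : METAHEADER_DIM + WEIGHT_DIM]
    opt = row[METAHEADER_DIM + WEIGHT_DIM :]
    # 'write_by_id' is a stand-in for whatever write API the KV backend exposes.
    backend.write_by_id(row_id, torch.cat([weight, opt]))
```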

This diff contains only frontend changes:

  • Added a `backend_return_whole_row` flag to the KVZCH params, with validation to ensure it can only be True when optimizer offloading is enabled.
  • Added a `read_only_` flag to KVTensorWrapper for checkpoint calls. When read_only=True, every write operation on the wrapper is a no-op (see the sketch after this list).
  • Added metadata recalculation for the optimizer state dict: since we now return a read-only KVT for the optimizer state dict, the model store must correct the global metadata before creating the save plan for the KVZCH optimizer tensors.
  • Optimizer offloading and whole-row return both default to False on trunk, so existing KVZCH runs should not be affected.
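
A minimal sketch of the first two behaviors, assuming hypothetical class and method names (the real KVZCH params and KVTensorWrapper APIs live in FBGEMM and differ in detail):

```python
class KVZCHParams:
    """Carries the new flag; validation mirrors the rule described above."""

    def __init__(self, backend_return_whole_row: bool = False,
                 enable_optimizer_offloading: bool = False):
        if backend_return_whole_row and not enable_optimizer_offloading:
            raise ValueError(
                "backend_return_whole_row=True requires optimizer offloading"
            )
        self.backend_return_whole_row = backend_return_whole_row
        self.enable_optimizer_offloading = enable_optimizer_offloading


class ReadOnlyKVTensorWrapper:
    """Checkpoint-facing KV tensor view; writes are dropped when read_only=True."""

    def __init__(self, backend, read_only: bool = False):
        self._backend = backend
        self._read_only = read_only

    def narrow(self, dim: int, start: int, length: int):
        # Reads always pass through to the backend.
        return self._backend.read(dim, start, length)

    def set_range(self, dim: int, start: int, length: int, value) -> None:
        if self._read_only:
            return  # no-op: optimizer state is already restored via the whole row
        self._backend.write(dim, start, length, value)
```

With this shape, the optimizer state dict can hand out read-only wrappers during checkpoint load, and the duplicate backend writes described in the Context section are skipped.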

Differential Revision: D77666892


netlify bot commented Jul 2, 2025

Deploy Preview for pytorch-fbgemm-docs ready!

🔨 Latest commit: cc754fc
🔍 Latest deploy log: https://app.netlify.com/projects/pytorch-fbgemm-docs/deploys/686ef8c7c227a50008e4e3dd
😎 Deploy Preview: https://deploy-preview-4435--pytorch-fbgemm-docs.netlify.app

@facebook-github-bot

This pull request was exported from Phabricator. Differential Revision: D77666892

bobbyliujb pushed a commit to bobbyliujb/torchrec that referenced this pull request Jul 7, 2025
…kend for checkpoint saving/loading
bobbyliujb pushed a commit to bobbyliujb/FBGEMM-1 that referenced this pull request Jul 7, 2025
…kend for checkpoint saving/loading (pytorch#4435)
bobbyliujb pushed a commit to bobbyliujb/torchrec that referenced this pull request Jul 7, 2025
…kend for checkpoint saving/loading
bobbyliujb pushed a commit to bobbyliujb/FBGEMM-1 that referenced this pull request Jul 7, 2025
…kend for checkpoint saving/loading (pytorch#4435)
bobbyliujb pushed a commit to bobbyliujb/torchrec that referenced this pull request Jul 7, 2025
…kend for checkpoint saving/loading
@facebook-github-bot

This pull request was exported from Phabricator. Differential Revision: D77666892

bobbyliujb pushed a commit to bobbyliujb/FBGEMM-1 that referenced this pull request Jul 7, 2025
…kend for checkpoint saving/loading (pytorch#4435)
bobbyliujb pushed a commit to bobbyliujb/torchrec that referenced this pull request Jul 7, 2025
…kend for checkpoint saving/loading (pytorch#3153)
@facebook-github-bot

This pull request was exported from Phabricator. Differential Revision: D77666892

bobbyliujb pushed a commit to bobbyliujb/FBGEMM-1 that referenced this pull request Jul 7, 2025
…kend for checkpoint saving/loading (pytorch#4435)
bobbyliujb pushed a commit to bobbyliujb/torchrec that referenced this pull request Jul 7, 2025
…kend for checkpoint saving/loading (pytorch#3153)
bobbyliujb pushed a commit to bobbyliujb/FBGEMM-1 that referenced this pull request Jul 7, 2025
…kend for checkpoint saving/loading (pytorch#4435)
bobbyliujb pushed a commit to bobbyliujb/torchrec that referenced this pull request Jul 7, 2025
…kend for checkpoint saving/loading (pytorch#3153)
bobbyliujb pushed a commit to bobbyliujb/FBGEMM-1 that referenced this pull request Jul 7, 2025
…kend for checkpoint saving/loading (pytorch#4435)
bobbyliujb pushed a commit to bobbyliujb/torchrec that referenced this pull request Jul 7, 2025
…kend for checkpoint saving/loading (pytorch#3153)
bobbyliujb pushed a commit to bobbyliujb/FBGEMM-1 that referenced this pull request Jul 7, 2025
…kend for checkpoint saving/loading (pytorch#4435)
bobbyliujb pushed a commit to bobbyliujb/torchrec that referenced this pull request Jul 7, 2025
…kend for checkpoint saving/loading (pytorch#3153)
@facebook-github-bot

This pull request was exported from Phabricator. Differential Revision: D77666892

bobbyliujb pushed a commit to bobbyliujb/FBGEMM-1 that referenced this pull request Jul 7, 2025
…kend for checkpoint saving/loading (pytorch#4435)
bobbyliujb pushed a commit to bobbyliujb/torchrec that referenced this pull request Jul 7, 2025
…kend for checkpoint saving/loading (pytorch#3153)
@facebook-github-bot

This pull request was exported from Phabricator. Differential Revision: D77666892

bobbyliujb pushed a commit to bobbyliujb/FBGEMM-1 that referenced this pull request Jul 7, 2025
…kend for checkpoint saving/loading (pytorch#4435)
bobbyliujb pushed a commit to bobbyliujb/torchrec that referenced this pull request Jul 7, 2025
…kend for checkpoint saving/loading (pytorch#3153)
@facebook-github-bot

This pull request was exported from Phabricator. Differential Revision: D77666892

bobbyliujb pushed a commit to bobbyliujb/FBGEMM-1 that referenced this pull request Jul 7, 2025
…kend for checkpoint saving/loading (pytorch#4435)
bobbyliujb pushed a commit to bobbyliujb/torchrec that referenced this pull request Jul 9, 2025
…kend for checkpoint saving/loading (pytorch#3153)
bobbyliujb pushed a commit to bobbyliujb/torchrec that referenced this pull request Jul 9, 2025
…kend for checkpoint saving/loading (pytorch#3153)
bobbyliujb pushed a commit to bobbyliujb/torchrec that referenced this pull request Jul 9, 2025
…kend for checkpoint saving/loading (pytorch#3153)
@facebook-github-bot

This pull request was exported from Phabricator. Differential Revision: D77666892

bobbyliujb pushed a commit to bobbyliujb/FBGEMM-1 that referenced this pull request Jul 9, 2025
…kend for checkpoint saving/loading (pytorch#4435)
bobbyliujb pushed a commit to bobbyliujb/torchrec that referenced this pull request Jul 9, 2025
…kend for checkpoint saving/loading (pytorch#3153)
@facebook-github-bot

This pull request was exported from Phabricator. Differential Revision: D77666892
