extend mergekit to make it work on xpu #580


Merged

merged 13 commits from yao-matrix:xpu into arcee-ai:main on Jun 6, 2025

Conversation

yao-matrix (Contributor) commented May 16, 2025

Since PyTorch 2.5, xpu has been a built-in device type in PyTorch.

In this PR, I extend mergekit's accelerator support from CUDA only to Intel XPU (the device name for Intel's GPUs), and potentially to more devices, by adding a new device field to MergeOption.
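As a sketch of what such a device option enables (the helper name and selection order here are hypothetical, not mergekit's actual API):

```python
import torch

def pick_accelerator(preferred: str = "auto") -> torch.device:
    # Hypothetical helper: resolve a user-supplied device string,
    # preferring CUDA, then XPU, and falling back to CPU.
    if preferred != "auto":
        return torch.device(preferred)
    if torch.cuda.is_available():
        return torch.device("cuda")
    # torch.xpu exists only on builds with XPU support (torch >= 2.5),
    # so guard the attribute lookup.
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return torch.device("xpu")
    return torch.device("cpu")
```

With a `device` field on the options object, tensors and modules can then be moved with `.to(pick_accelerator(options.device))` instead of hard-coding `"cuda"`.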

I've passed the unit test cases with:

================== 70 passed, 2 warnings in 60.49s (0:01:00) ===================

@cg123, please help review and let me know what I need to do next. Thanks very much.

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

github-actions bot commented May 16, 2025

All contributors have signed the CLA ✍️ ✅
Posted by the CLA Assistant Lite bot.

yao-matrix (Contributor, Author)

I have read the CLA Document and I hereby sign the CLA

yao-matrix (Contributor, Author)

recheck

yao-matrix changed the title from "extend mergekit to work on xpu" to "extend mergekit to make it work on xpu" on May 16, 2025
yao-matrix (Contributor, Author)

@cg123, could you please help review? Thanks very much.

cg123 (Collaborator) left a comment

Thank you for the PR, this is a very welcome addition! I left some comments, but nothing huge.

@@ -529,7 +529,7 @@ def _move_tensors(
         self, value: Any, device: torch.device, non_blocking: Optional[bool] = None
     ) -> Any:
         if non_blocking is None:
-            non_blocking = device.type == "cuda"
+            non_blocking = device.type in ["cuda", "xpu"]
cg123 (Collaborator): 👍
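The change in this hunk can be illustrated with a minimal, self-contained sketch (the function name is hypothetical; only the `non_blocking` default mirrors the diff):

```python
import torch

def move_tensor(value: torch.Tensor, device: torch.device) -> torch.Tensor:
    # Non-blocking (async) copies only pay off on stream-ordered accelerator
    # backends such as CUDA and XPU; on a CPU target the flag is a no-op.
    non_blocking = device.type in ["cuda", "xpu"]
    return value.to(device, non_blocking=non_blocking)
```

Treating XPU the same as CUDA here keeps host-to-device copies asynchronous on Intel GPUs as well, instead of silently serializing them.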

-        stream = torch.cuda.Stream(device=device)
-        with torch.cuda.stream(stream):
+        stream = (
+            torch.Stream(device=device)
cg123 (Collaborator):

Maybe we should just use torch.Stream in all cases?

yao-matrix (Contributor, Author) commented May 22, 2025

We can. The only concern is that torch.Stream has only been available since torch 2.5, while mergekit's torch dependency is currently >= 2.0, so if a user installs torch 2.4 it will crash. Please let me know your thoughts.

cg123 (Collaborator):

Okay, fair! This is good for now, then. I'll revisit it if I bump the minimum torch version in the future.
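One hedged way to reconcile the two constraints discussed above (a sketch, assuming the goal is to prefer the device-generic `torch.Stream` added in PyTorch 2.5 while remaining importable on older versions; both helper names are hypothetical):

```python
import torch

def supports_generic_stream(torch_version: str = torch.__version__) -> bool:
    # torch.Stream was added in PyTorch 2.5; parse "major.minor" and compare.
    # Handles suffixed versions such as "2.5.0+cu121" as well.
    major, minor = (int(p) for p in torch_version.split(".")[:2])
    return (major, minor) >= (2, 5)

def make_stream(device: torch.device):
    # Prefer the device-generic API when present; otherwise fall back to the
    # CUDA-specific one, which is what the pre-2.5 code path relies on.
    if hasattr(torch, "Stream"):
        return torch.Stream(device=device)
    return torch.cuda.Stream(device=device)
```

The `hasattr` check avoids crashing at import time on torch 2.4, at the cost of keeping a CUDA-only fallback path until the minimum torch version is raised.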

yao-matrix (Contributor, Author)

@cg123, thanks very much for your comments. I've updated per your comments; please help review again. Thanks.


yao-matrix (Contributor, Author)

@cg123, could you please help review again? Thanks.

cg123 (Collaborator) commented Jun 6, 2025

I think this is good to go now. Thanks for your patience and the PR!

cg123 merged commit d86ddbc into arcee-ai:main on Jun 6, 2025 (5 checks passed)
The github-actions bot locked and limited conversation to collaborators on Jun 6, 2025
yao-matrix deleted the xpu branch on June 8, 2025