extend mergekit to make it work on xpu #580
Conversation
@cg123, could you please help review? Thanks very much.
Thank you for the PR, this is a very welcome addition! Left some comments but nothing huge.
@@ -529,7 +529,7 @@ def _move_tensors(
         self, value: Any, device: torch.device, non_blocking: Optional[bool] = None
     ) -> Any:
         if non_blocking is None:
-            non_blocking = device.type == "cuda"
+            non_blocking = device.type in ["cuda", "xpu"]
👍
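For context, here is a minimal sketch of the pattern the diff above follows; the helper name is illustrative, not mergekit's actual API. Non-blocking copies only pay off for host-to-accelerator transfers, so the flag is derived from the target device type:

```python
import torch

def move_tensor(value: torch.Tensor, device: torch.device) -> torch.Tensor:
    # Async (non-blocking) copies are beneficial for accelerator targets;
    # for CPU targets a blocking copy is the safe default.
    non_blocking = device.type in ["cuda", "xpu"]
    return value.to(device, non_blocking=non_blocking)
```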
-        stream = torch.cuda.Stream(device=device)
-        with torch.cuda.stream(stream):
+        stream = (
+            torch.Stream(device=device)
Maybe we should just use `torch.Stream` in all cases?
We can. The only concern is that `torch.Stream` is only available from torch 2.5, but `mergekit`'s `torch` dependency is >= 2.0 now. So if a user installs torch 2.4, it will crash. Please let me know your thoughts.
Okay, fair! This is good for now then. I'll look at it if I bump the minimum torch version in the future.
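A hedged sketch of the compromise discussed above, assuming only what the thread states (`torch.Stream` is available from torch 2.5; `torch.cuda.Stream` exists on older versions). The guard and function name here are illustrative, not mergekit's actual code:

```python
import torch
from packaging import version

def make_stream(device: torch.device):
    # torch.Stream is the device-agnostic stream API added in torch 2.5;
    # on older versions only the CUDA-specific stream API exists.
    if version.parse(torch.__version__) >= version.parse("2.5"):
        return torch.Stream(device=device)
    if device.type == "cuda":
        return torch.cuda.Stream(device=device)
    raise RuntimeError(
        f"No stream API for device {device} on torch {torch.__version__}"
    )
```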
@cg123, thanks very much for your comments. I've updated per your comments; please help review again, thanks.
@cg123, could you please help review again? Thanks.
Think this is good to go now. Thanks for your patience and the PR!
Since PyTorch 2.5, xpu has been a built-in device of PyTorch. In this PR, I extend the accelerator device of `mergekit` from CUDA only to Intel XPU (which is the name of Intel's GPUs), and potentially to more devices, by adding a new `device` field in `MergeOptions`. I've passed the UT cases with:

@cg123, please help review and let me know what I need to do next, thanks very much.
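To illustrate how a device option like this might be consumed, here is a sketch under assumptions: the `device` field itself is from the PR description, but the function name and the fallback probing order below are hypothetical, not mergekit's actual wiring.

```python
from typing import Optional
import torch

def resolve_device(requested: Optional[str] = None) -> torch.device:
    # Honor an explicit device choice first (e.g. "cuda", "xpu", "cpu"),
    # otherwise probe for whichever accelerator is present.
    if requested is not None:
        return torch.device(requested)
    if torch.cuda.is_available():
        return torch.device("cuda")
    # torch.xpu only exists on builds with XPU support, hence the hasattr guard.
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return torch.device("xpu")
    return torch.device("cpu")
```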