-
Notifications
You must be signed in to change notification settings - Fork 566
extend mergekit to make it work on xpu #580
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
70f07c9
9655271
1b1ee80
cfd9b73
6db889e
ced573e
c57f362
b7ecbb3
15b6727
1fcd025
30e63f2
e4bfb30
b756f5e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -529,7 +529,7 @@ def _move_tensors( | |
self, value: Any, device: torch.device, non_blocking: Optional[bool] = None | ||
) -> Any: | ||
if non_blocking is None: | ||
non_blocking = device.type == "cuda" | ||
non_blocking = device.type in ["cuda", "xpu"] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 👍 |
||
if isinstance(value, torch.Tensor): | ||
if value.device == device: | ||
return value | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -75,9 +75,15 @@ def __init__( | |
self.results: Dict[TaskHandle, Any] = {} | ||
self.storage_device = storage_device | ||
|
||
self.accelerator_type = ( | ||
getattr(torch, torch.acclerator.current_accelerator().type) | ||
yao-matrix marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if hasattr(torch, "accelerator") | ||
else "cuda" | ||
) | ||
torch_accelerator_module = getattr(torch, self.accelerator_type) | ||
if num_gpus is None: | ||
num_gpus = torch.cuda.device_count() | ||
LOG.info(f"Using {num_gpus} GPUs for parallel execution") | ||
num_gpus = torch_accelerator_module.device_count() | ||
LOG.info(f"Using {num_gpus} {accelerator_type} for parallel execution") | ||
|
||
self.universe = TaskUniverse(targets) | ||
self.targets = set([self.universe.get_handle(t) for t in targets]) | ||
|
@@ -309,12 +315,14 @@ def _assign_islands_to_gpus( | |
continue | ||
# don't need to sort, inner executor will handle | ||
island_tasks = [TaskHandle(self.universe, idx) for idx in island] | ||
# assign to GPU with fewest tasks (load balancing) | ||
# assign to accelerator with fewest tasks (load balancing) | ||
device_idx = min( | ||
range(num_gpus), | ||
key=lambda i: len(assignments.get(torch.device(f"cuda:{i}"), [])), | ||
key=lambda i: len( | ||
assignments.get(torch.device(f"{self.accelerator_type}:{i}"), []) | ||
), | ||
) | ||
device = torch.device(f"cuda:{device_idx}") | ||
device = torch.device(f"{self.accelerator_type}:{device_idx}") | ||
assignments[device] = assignments.get(device, []) + island_tasks | ||
return assignments | ||
|
||
|
@@ -339,9 +347,16 @@ def _device_worker( | |
quiet: Whether to suppress progress bar output | ||
""" | ||
LOG.debug(f"Device {device} starting") | ||
torch_accelerator_module = getattr(torch, self.accelerator_type) | ||
with torch.device(device): | ||
stream = torch.cuda.Stream(device=device) | ||
with torch.cuda.stream(stream): | ||
stream = ( | ||
torch.Stream(device=device) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe we should just use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we can. The only concern is There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Okay, fair! This is good for now then. I'll look at it if I bump the minimum torch version in the future. |
||
if self.accelerator_type == "xpu" | ||
else torch.cuda.Stream(device=device) | ||
) | ||
with ( | ||
stream if self.accelerator_type == "xpu" else torch.cuda.stream(stream) | ||
): | ||
exec = Executor( | ||
targets=task_list, | ||
math_device=device, | ||
|
@@ -358,5 +373,5 @@ def _device_worker( | |
): | ||
result = None | ||
self.task_completion_queue.put((task_handle._index, result)) | ||
torch.cuda.synchronize(device=device) | ||
torch_accelerator_module.synchronize(device=device) | ||
LOG.debug(f"Device {device} done") |
Uh oh!
There was an error while loading. Please reload this page.