[Performance] EPLB Execution Optimization #20990
@@ -0,0 +1,82 @@
from abc import ABC, abstractmethod

import torch
import torch.distributed as dist
from torch.distributed import P2POp  # needed for the type hints below

try:
    # torch_npu is only available in Ascend NPU environments.
    import torch_npu
except ImportError:
    torch_npu = None


class DeviceBackend(ABC):
    """Abstract base class for hardware backends, defining a unified interface."""

    @abstractmethod
    def synchronize(self) -> None:
        """Synchronize all pending operations on the current device."""

    @abstractmethod
    def create_buffer_like(self, tensor: torch.Tensor) -> torch.Tensor:
        """Create a buffer with the same dtype and shape as the input tensor on this device."""

    @abstractmethod
    def all_gather(self, output_tensor_list: list[torch.Tensor],
                   input_tensor: torch.Tensor, group=None) -> None:
        """Perform an all_gather collective communication operation."""

    @abstractmethod
    def batch_isend_irecv(self, p2p_ops: list[P2POp]) -> list[dist.Work]:
        """Perform batched asynchronous point-to-point send/receive operations."""

    @abstractmethod
    def barrier(self, group=None) -> None:
        """Perform a barrier synchronization."""


class CUDABackend(DeviceBackend):
    """CUDA/NVIDIA GPU backend implementation."""

    def synchronize(self) -> None:
        torch.cuda.synchronize()

    def create_buffer_like(self, tensor: torch.Tensor) -> torch.Tensor:
        return torch.empty_like(tensor, device='cuda')

    def all_gather(self, output_tensor_list: list[torch.Tensor],
                   input_tensor: torch.Tensor, group=None) -> None:
        dist.all_gather(output_tensor_list, input_tensor, group=group)

    def batch_isend_irecv(self, p2p_ops: list[P2POp]) -> list[dist.Work]:
        return dist.batch_isend_irecv(p2p_ops)

    def barrier(self, group=None) -> None:
        dist.barrier(group=group)


class NPUBackend(DeviceBackend):
    """Ascend NPU backend implementation."""

    def synchronize(self) -> None:
        torch_npu.npu.synchronize()
Comment on lines +59 to +61

The `synchronize` method of `NPUBackend` was submitted as a bare `pass`, so it silently skips device synchronization on NPU; it should call `torch_npu.npu.synchronize()`. Similarly, other methods like `all_gather`, `batch_isend_irecv`, and `barrier` were submitted as empty stubs and need real implementations before this backend is usable.
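Suggested change (a sketch, assuming torch_npu.npu.synchronize() is the NPU synchronization entry point, matching the implementation shown above):

    def synchronize(self) -> None:
        # Block until all work queued on the current NPU has completed.
        torch_npu.npu.synchronize()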
    def create_buffer_like(self, tensor: torch.Tensor) -> torch.Tensor:
        return torch.empty_like(tensor, device='npu')

    def all_gather(self, output_tensor_list: list[torch.Tensor],
                   input_tensor: torch.Tensor, group=None) -> None:
        # The HCCL process group exposes the same collective API as NCCL.
        dist.all_gather(output_tensor_list, input_tensor, group=group)

    def batch_isend_irecv(self, p2p_ops: list[P2POp]) -> list[dist.Work]:
        return dist.batch_isend_irecv(p2p_ops)

    def barrier(self, group=None) -> None:
        dist.barrier(group=group)


# Create the appropriate backend for the available hardware.
def create_device_backend(use_cuda: bool = True) -> DeviceBackend:
    return CUDABackend() if use_cuda else NPUBackend()
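For context, a minimal usage sketch of this abstraction (illustrative only, not part of the diff; it assumes the default process group has already been initialized, e.g. via torchrun):

    import torch
    import torch.distributed as dist

    backend = create_device_backend(use_cuda=torch.cuda.is_available())
    device = 'cuda' if torch.cuda.is_available() else 'npu'

    local = torch.ones(4, device=device)
    gathered = [backend.create_buffer_like(local) for _ in range(dist.get_world_size())]

    backend.all_gather(gathered, local)  # collect every rank's tensor
    backend.barrier()                    # wait for all ranks to reach this point
    backend.synchronize()                # drain the device queue before reading results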
`P2POp` is used as a type hint here but is not defined or imported in this file. This will cause a `NameError` at runtime. The same issue exists on lines 49 and 70. You should import it from `torch.distributed` at the top of the file.
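The fix is a single import, from torch.distributed import P2POp, as applied at the top of the file above. For completeness, a short sketch of how P2POp objects feed batch_isend_irecv (the ring pattern and tensor shapes here are illustrative assumptions, not from the PR):

    import torch
    import torch.distributed as dist
    from torch.distributed import P2POp

    # Ring exchange: send to (rank + 1), receive from (rank - 1).
    rank, world = dist.get_rank(), dist.get_world_size()
    payload = torch.full((4,), float(rank), device='cuda')
    buffer = torch.empty_like(payload)

    ops = [P2POp(dist.isend, payload, (rank + 1) % world),
           P2POp(dist.irecv, buffer, (rank - 1) % world)]
    for work in dist.batch_isend_irecv(ops):
        work.wait()  # block until the batched transfers complete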