Co-authored with @jikunshang
Motivation.
This RFC aims to reuse `GPUWorker` and `GPUModelRunner` for any GPGPU device, such as CUDA, ROCm, and Intel GPU (aka XPU).
- By doing so, we avoid redundant duplication: instead of adding a new `XXXWorker`/`XXXModelRunner` for each device, vendors can derive from `GPUWorker`/`GPUModelRunner`.
- Any feature implemented in `GPUWorker`/`GPUModelRunner`, such as logits processors, sampling-output optimization, and spec_decode, is shared across all GPGPU hardware.
Status & Challenge
- A previous RFC from Huawei already did significant work here: [RFC]: Make device agnostic for diverse hardware support #9268
- Currently, `GPUWorker` and `GPUModelRunner` assume they will only be used with CUDA and ROCm, so hard-coded CUDA APIs appear throughout both files, e.g. `torch.cuda.XXX` or `tensor.to('cuda')`, as illustrated below.
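For illustration, a before/after sketch of the hard-coded pattern versus the platform-agnostic form this RFC proposes (`current_platform.synchronize` and `current_platform.device` are the abstractions proposed below, not existing APIs):

```python
import torch
from vllm.platforms import current_platform

# Today: device handling inside GPUWorker/GPUModelRunner is CUDA-only.
def profile_run_today(t: torch.Tensor) -> torch.Tensor:
    torch.cuda.synchronize()   # CUDA-specific API
    return t.to('cuda')        # hard-coded device string

# Proposed: the same logic routed through the platform abstraction,
# so CUDA, ROCm, and XPU share a single code path.
def profile_run_proposed(t: torch.Tensor) -> torch.Tensor:
    current_platform.synchronize()        # per-platform implementation
    return t.to(current_platform.device)  # 'cuda', 'xpu', ...
```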
Proposed Change.
- Add abstract APIs to `platforms/interface.py` and implement them in `cuda.py`, `rocm.py`, and `xpu.py` (see the sketch after this list).
- Update every `tensor.to('cuda')` or `tensor.cuda()` call to use `tensor.to(current_platform.device)`.
- Add skip checks for APIs that a platform does not support.
- Add a static check to the pre-commit hooks to steer future contributors toward `current_platform` instead of calling `torch.cuda` directly.
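A minimal sketch of what the abstraction could look like, assuming the platform methods simply mirror their `torch.cuda`/`torch.xpu` counterparts (per the Plan below); signatures are illustrative, not final:

```python
# vllm/platforms/interface.py (sketch)
import torch

class Platform:
    device_type: str  # "cuda", "xpu", ...

    @classmethod
    def empty_cache(cls) -> None:
        raise NotImplementedError

    @classmethod
    def synchronize(cls) -> None:
        raise NotImplementedError

    @classmethod
    def mem_get_info(cls) -> tuple[int, int]:
        """Return (free_bytes, total_bytes) for the current device."""
        raise NotImplementedError

# vllm/platforms/cuda.py (sketch) -- ROCm also goes through torch.cuda
class CudaPlatform(Platform):
    device_type = "cuda"

    @classmethod
    def empty_cache(cls) -> None:
        torch.cuda.empty_cache()

    @classmethod
    def synchronize(cls) -> None:
        torch.cuda.synchronize()

    @classmethod
    def mem_get_info(cls) -> tuple[int, int]:
        return torch.cuda.mem_get_info()

# vllm/platforms/xpu.py (sketch)
class XpuPlatform(Platform):
    device_type = "xpu"

    @classmethod
    def empty_cache(cls) -> None:
        torch.xpu.empty_cache()

    @classmethod
    def synchronize(cls) -> None:
        torch.xpu.synchronize()

    @classmethod
    def mem_get_info(cls) -> tuple[int, int]:
        return torch.xpu.mem_get_info()
```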

Plan
- Add abstract APIs
  - `torch.cuda.empty_cache` -> `current_platform.empty_cache`
  - `torch.cuda.set_device` -> `current_platform.set_device`
  - `torch.cuda.reset_peak_memory_stats` -> `current_platform.reset_peak_memory_stats`
  - `torch.cuda.mem_get_info` -> `current_platform.mem_get_info`
  - `torch.cuda.memory_stats` -> `current_platform.memory_stats`
  - `torch.cuda.memory_reserved` -> `current_platform.memory_reserved`
  - `torch.cuda.synchronize` -> `current_platform.synchronize`
- Abstract device / dist_device
  - `tensor.to('cuda')` -> `tensor.to(current_platform.device)`
  - Use `current_platform.dist_backend` for `init_worker_distributed_environment`. Done in [Refactor] Abstract Platform Interface for Distributed Backend and Add xccl Support for Intel XPU #19410
- Add skip checks (see the sketch after this list)
  - Add `current_platform.is_graph_mode_supported()`
  - TBD
- Add static checks
  - TBD
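For the skip check, a minimal sketch of how the proposed `is_graph_mode_supported()` predicate might gate the graph-capture path; `maybe_capture_graph` and `runner.capture_model` are illustrative names, not vLLM's actual code:

```python
from vllm.platforms import current_platform

def maybe_capture_graph(runner) -> None:
    """Capture CUDA graphs only on platforms that support graph mode."""
    if not current_platform.is_graph_mode_supported():
        # e.g. XPU today: no torch.cuda.CUDAGraph equivalent,
        # so stay in eager mode instead of crashing.
        return
    runner.capture_model()  # hypothetical CUDA-graph capture entry point
```

The static check (still TBD) could be as simple as a pre-commit grep rule that flags new `torch.cuda.` references outside `vllm/platforms/`.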
**PR list**
#19410 => Merged
#20751 => Open
**Additional**
Support status of torch APIs in `torch.cuda` and `torch.xpu`, as of Q2 2025:
No. | API | CUDA | XPU |
---|---|---|---|
1 | current_device | ✅ | ✅ |
2 | current_stream | ✅ | ✅ |
3 | device | ✅ | ✅ |
4 | device_count | ✅ | ✅ |
5 | get_device_capability | ✅ | ✅ |
6 | get_device_name | ✅ | ✅ |
7 | get_device_properties | ✅ | ✅ |
8 | init | ✅ | ✅ |
9 | is_available | ✅ | ✅ |
10 | is_initialized | ✅ | ✅ |
11 | set_device | ✅ | ✅ |
12 | set_stream | ✅ | ✅ |
13 | stream | ✅ | ✅ |
14 | synchronize | ✅ | ✅ |
15 | manual_seed | ✅ | ✅ |
16 | manual_seed_all | ✅ | ✅ |
17 | Stream | ✅ | ✅ |
18 | Event | ✅ | ✅ |
19 | empty_cache | ✅ | ✅ |
20 | mem_get_info | ✅ | ✅ |
21 | memory_stats, memory_stats_as_nested_dict | ✅ | ✅ |
22 | memory_allocated | ✅ | ✅ |
23 | max_memory_allocated | ✅ | ✅ |
24 | memory_reserved | ✅ | ✅ |
25 | reset_peak_memory_stats | ✅ | ✅ |
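Since every API in the table above shares its name between `torch.cuda` and `torch.xpu`, the matched portion of the interface could even be implemented generically by dispatching on the module name. A hedged sketch (the `GenericGpuPlatform` class is illustrative, not vLLM's actual code):

```python
import torch

class GenericGpuPlatform:
    device_type: str = "cuda"  # or "xpu"

    @property
    def _module(self):
        # Resolves to torch.cuda or torch.xpu based on the device string.
        return getattr(torch, self.device_type)

    def synchronize(self) -> None:
        self._module.synchronize()

    def empty_cache(self) -> None:
        self._module.empty_cache()

    def mem_get_info(self) -> tuple[int, int]:
        return self._module.mem_get_info()
```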
Unmatched APIs (no direct `torch.xpu` equivalent):
No. | API | CUDA | XPU |
---|---|---|---|
1 | cudart | ✅ | enable with alternative API |
2 | is_current_stream_capturing | ✅ | enable with alternative API |
3 | graph_pool_handle | ✅ | check and skip |
4 | CUDAGraph | ✅ | check and skip |
5 | Graph | ✅ | check and skip |
6 | CUDAPluggableAllocator | ✅ | check and skip |
7 | nvtx.range_push | ✅ | check and skip |
8 | nvtx.range_pop | ✅ | check and skip |
9 | nvtx.range | ✅ | check and skip |
10 | _lazy_init | ✅ | enable with alternative API |
11 | _is_compiled | ✅ | enable with alternative API |
12 | _device_count_amdsmi | ✅ | enable with alternative API |
13 | _device_count_nvml | ✅ | enable with alternative API |
14 | tunable | ✅ | enable with alternative API |
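For illustration, a hedged sketch of the two mitigation strategies named in the table: "check and skip" (NVTX ranges become no-ops off CUDA) and "enable with alternative API" (XPU has no graph capture, so stream capture can simply report `False`); the wrapper functions are hypothetical:

```python
import contextlib
import torch

def nvtx_range(name: str):
    # Check and skip: profiling annotations only where NVTX exists.
    if torch.cuda.is_available():
        return torch.cuda.nvtx.range(name)
    return contextlib.nullcontext()

def is_current_stream_capturing() -> bool:
    # Alternative API: XPU has no graph capture, so never "capturing".
    if torch.cuda.is_available():
        return torch.cuda.is_current_stream_capturing()
    return False
```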
Feedback Period.
No response
CC List.
@jikunshang @simon-mo @WoosukKwon @youkaichao @gshtras
Any Other Things.
No response