
[RFC]: Continue on Device agnostic API abstraction to current_platform.XXX #20708

Open
@xuechendi

Description


Co-authored with @jikunshang

Motivation.

This RFC aims to reuse GPUWorker and GPUModelRunner for all GPGPU devices, such as CUDA, ROCm, and Intel GPU (aka XPU).

  • By doing so, we remove redundant duplication: there is no need to add a new XXXWorker/XXXModelRunner per device that merely derives from GPUWorker/GPUModelRunner.
  • Any feature implemented in GPUWorker/GPUModelRunner, such as logits processors, sampling-output optimization, and spec_decode, can be shared across all GPGPU hardware.

Status & Challenge

  • A previous RFC from Huawei already made significant progress on this: [RFC]: Make device agnostic for diverse hardware support #9268

  • Currently, GPUWorker and GPUModelRunner assume they will only be used with CUDA and ROCm, so hard-coded CUDA APIs appear throughout both files, e.g. torch.cuda.XXX or tensor.to('cuda').
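As a concrete illustration, here is a sketch of the call-site migration this RFC proposes (the "after" form uses the current_platform names introduced in the plan below, so it shows intent rather than a final API):

```python
import torch

# Before: hard-coded CUDA calls in GPUWorker/GPUModelRunner; these
# break on any non-CUDA/ROCm device such as XPU.
x = torch.empty(16).to("cuda")
torch.cuda.synchronize()

# After: routed through the platform abstraction proposed below.
from vllm.platforms import current_platform

x = torch.empty(16).to(current_platform.device)
current_platform.synchronize()
```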

Proposed Change.

  1. Add abstract APIs to platforms/interface.py and implement them in cuda.py, rocm.py, and xpu.py (see the sketch below).
  2. Update any tensor.to('cuda') or tensor.cuda() call to tensor.to(current_platform.device).
  3. Add a skip check for cases of API mismatch.
  4. Add a static check to the pre-commit hooks to guide future contributors toward current_platform instead of calling torch.cuda directly.
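A minimal sketch of what steps 1 and 2 could look like; the method names come from the plan below, the file and class names follow vLLM's existing platforms/ layout, and the class structure is simplified for illustration:

```python
import torch


# platforms/interface.py
class Platform:
    device_name: str  # "cuda", "xpu", ...

    @property
    def device(self) -> torch.device:
        # Target for the tensor.to(current_platform.device) rewrite.
        return torch.device(self.device_name)

    def empty_cache(self) -> None:
        raise NotImplementedError

    def synchronize(self) -> None:
        raise NotImplementedError

    def mem_get_info(self) -> tuple[int, int]:
        raise NotImplementedError


# platforms/cuda.py (rocm.py is analogous, since ROCm builds of
# PyTorch expose the same torch.cuda namespace)
class CudaPlatform(Platform):
    device_name = "cuda"

    def empty_cache(self) -> None:
        torch.cuda.empty_cache()

    def synchronize(self) -> None:
        torch.cuda.synchronize()

    def mem_get_info(self) -> tuple[int, int]:
        return torch.cuda.mem_get_info()


# platforms/xpu.py
class XpuPlatform(Platform):
    device_name = "xpu"

    def empty_cache(self) -> None:
        torch.xpu.empty_cache()

    def synchronize(self) -> None:
        torch.xpu.synchronize()

    def mem_get_info(self) -> tuple[int, int]:
        return torch.xpu.mem_get_info()
```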

Plan

  1. Add abstract APIs
  • torch.cuda.empty_cache -> current_platform.empty_cache
  • torch.cuda.set_device -> current_platform.set_device
  • torch.cuda.reset_peak_memory_stats -> current_platform.reset_peak_memory_stats
  • torch.cuda.mem_get_info -> current_platform.mem_get_info
  • torch.cuda.memory_stats -> current_platform.memory_stats
  • torch.cuda.memory_reserved -> current_platform.memory_reserved
  • torch.cuda.synchronize -> current_platform.synchronize
  2. Abstract device / dist_device
  3. Add skip checks (see the first sketch after this list)
  • add current_platform.is_graph_mode_supported()
  • TBD
  4. Add static checks (see the second sketch after this list)
  • TBD
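To make step 3 concrete, here is a minimal sketch of how the skip check could guard graph capture. `is_graph_mode_supported()` is the name proposed above; the `capture_model` call site and `runner.capture_graphs()` are hypothetical stand-ins for the existing CUDA-graph path:

```python
import logging

from vllm.platforms import current_platform

logger = logging.getLogger(__name__)


def capture_model(runner) -> None:
    # Hypothetical call site: platforms without CUDA-graph-style capture
    # (e.g. XPU today) fall back to eager execution instead of crashing
    # on torch.cuda.CUDAGraph / torch.cuda.graph_pool_handle.
    if not current_platform.is_graph_mode_supported():
        logger.warning(
            "Graph capture is not supported on %s; falling back to "
            "eager execution.", current_platform.device_name)
        return
    runner.capture_graphs()  # hypothetical: the existing capture path
```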
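For step 4, one option is a small pre-commit hook that flags direct torch.cuda usage outside the platform implementations. A sketch, where the allow-list and hook wiring are assumptions rather than an agreed design:

```python
#!/usr/bin/env python3
# Hypothetical pre-commit hook: reject direct torch.cuda.* calls outside
# vllm/platforms/, pointing contributors to current_platform instead.
import re
import sys

# Assumption: only the platform implementations may call torch.cuda.
ALLOWED_PREFIXES = ("vllm/platforms/",)
PATTERN = re.compile(r"\btorch\.cuda\.\w+")


def main(paths: list[str]) -> int:
    status = 0
    for path in paths:
        if path.startswith(ALLOWED_PREFIXES):
            continue
        with open(path, encoding="utf-8") as f:
            for lineno, line in enumerate(f, start=1):
                if PATTERN.search(line):
                    print(f"{path}:{lineno}: use current_platform.* "
                          f"instead of torch.cuda.*")
                    status = 1
    return status


if __name__ == "__main__":
    sys.exit(main(sys.argv[1:]))
```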

**PR list**
#19410 => Merged
#20751 => Open

**Additional**

Torch APIs supported in both torch.cuda and torch.xpu (as of Q2 2025):

| no. | API | torch.cuda | torch.xpu |
|----:|-----|:----------:|:---------:|
| 1 | current_device | ✓ | ✓ |
| 2 | current_stream | ✓ | ✓ |
| 3 | device | ✓ | ✓ |
| 4 | device_count | ✓ | ✓ |
| 5 | get_device_capability | ✓ | ✓ |
| 6 | get_device_name | ✓ | ✓ |
| 7 | get_device_properties | ✓ | ✓ |
| 8 | init | ✓ | ✓ |
| 9 | is_available | ✓ | ✓ |
| 10 | is_initialized | ✓ | ✓ |
| 11 | set_device | ✓ | ✓ |
| 12 | set_stream | ✓ | ✓ |
| 13 | stream | ✓ | ✓ |
| 14 | synchronize | ✓ | ✓ |
| 15 | manual_seed | ✓ | ✓ |
| 16 | manual_seed_all | ✓ | ✓ |
| 17 | Stream | ✓ | ✓ |
| 18 | Event | ✓ | ✓ |
| 19 | empty_cache | ✓ | ✓ |
| 20 | mem_get_info | ✓ | ✓ |
| 21 | memory_stats, memory_stats_as_nested_dict | ✓ | ✓ |
| 22 | memory_allocated | ✓ | ✓ |
| 23 | max_memory_allocated | ✓ | ✓ |
| 24 | memory_reserved | ✓ | ✓ |
| 25 | reset_peak_memory_stats | ✓ | ✓ |

APIs in torch.cuda with no direct torch.xpu match, and how each will be handled:

| no. | torch.cuda API | Handling |
|----:|----------------|----------|
| 1 | cudart | enable with alternative API |
| 2 | is_current_stream_capturing | enable with alternative API |
| 3 | graph_pool_handle | check and skip |
| 4 | CUDAGraph | check and skip |
| 5 | graph | check and skip |
| 6 | CUDAPluggableAllocator | check and skip |
| 7 | nvtx.range_push | check and skip |
| 8 | nvtx.range_pop | check and skip |
| 9 | nvtx.range | check and skip |
| 10 | _lazy_init | enable with alternative API |
| 11 | _is_compiled | enable with alternative API |
| 12 | _device_count_amdsmi | enable with alternative API |
| 13 | _device_count_nvml | enable with alternative API |
| 14 | tunable | enable with alternative API |

Feedback Period.

No response

CC List.

@jikunshang @simon-mo @WoosukKwon @youkaichao @gshtras

Any Other Things.

No response

