[WIP] Draft to remove torch.cuda calls and use Platform APIs. #20721

Open · wants to merge 2 commits into base: main

Changes from all commits
24 changes: 24 additions & 0 deletions vllm/platforms/cuda.py
@@ -422,6 +422,30 @@ def stateless_init_device_torch_dist_pg(
def device_count(cls) -> int:
return cuda_device_count_stateless()

@classmethod
def empty_cache(cls):
torch.cuda.empty_cache()

@classmethod
def reset_peak_memory_stats(cls):
torch.cuda.reset_peak_memory_stats()

@classmethod
def mem_get_info(cls):
return torch.cuda.mem_get_info()

@classmethod
def memory_stats(cls):
return torch.cuda.memory_stats()

@classmethod
def memory_reserved(cls):
return torch.cuda.memory_reserved()

@classmethod
def synchronize(cls):
return torch.cuda.synchronize()


# NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
24 changes: 24 additions & 0 deletions vllm/platforms/interface.py
@@ -548,6 +548,30 @@ def stateless_init_device_torch_dist_pg(
"""
raise RuntimeError(f"Unsupported torch distributed backend: {backend}")

@classmethod
def empty_cache(cls):
raise NotImplementedError

@classmethod
def reset_peak_memory_stats(cls):
raise NotImplementedError

@classmethod
def mem_get_info(cls):
raise NotImplementedError

@classmethod
def memory_stats(cls):
raise NotImplementedError

@classmethod
def memory_reserved(cls):
raise NotImplementedError

@classmethod
def synchronize(cls):
torch.accelerator.synchronize()


class UnspecifiedPlatform(Platform):
_enum = PlatformEnum.UNSPECIFIED
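For orientation (not part of the diff): the intent of these interface hooks is that device-agnostic code dispatches through current_platform instead of calling torch.cuda directly, mirroring the call sites changed later in this PR. A minimal sketch of the expected call pattern:

# Minimal sketch of intended usage; assumed, based on the call sites changed in this PR.
from vllm.platforms import current_platform

current_platform.empty_cache()        # torch.cuda / torch.xpu empty_cache underneath
free_bytes, total_bytes = current_platform.mem_get_info()
peak = current_platform.memory_stats().get("allocated_bytes.all.peak", 0)
current_platform.synchronize()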
26 changes: 25 additions & 1 deletion vllm/platforms/xpu.py
@@ -116,7 +116,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:

# check and update parallel config
parallel_config = vllm_config.parallel_config
-parallel_config.worker_cls = "vllm.v1.worker.xpu_worker.XPUWorker"
+parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"

if parallel_config.distributed_executor_backend is None:
if parallel_config.world_size > 1:
@@ -195,3 +195,27 @@ def supports_v1(cls, model_config: ModelConfig) -> bool:
@classmethod
def device_count(cls) -> int:
return torch.xpu.device_count()

@classmethod
def empty_cache(cls):
torch.xpu.empty_cache()

@classmethod
def reset_peak_memory_stats(cls):
torch.xpu.reset_peak_memory_stats()

@classmethod
def mem_get_info(cls):
return torch.xpu.mem_get_info()

@classmethod
def memory_stats(cls):
return torch.xpu.memory_stats()

@classmethod
def memory_reserved(cls):
return torch.xpu.memory_reserved()

@classmethod
def synchronize(cls):
return torch.xpu.synchronize()
16 changes: 10 additions & 6 deletions vllm/utils/__init__.py
@@ -2557,16 +2557,18 @@ def measure(self):
# After `torch.cuda.reset_peak_memory_stats()`,
# `torch.cuda.memory_reserved()` will keep growing, and only shrink
# when we call `torch.cuda.empty_cache()` or OOM happens.
-self.torch_peak = torch.cuda.memory_stats().get(
+from vllm.platforms import current_platform
+
+self.torch_peak = current_platform.memory_stats().get(
"allocated_bytes.all.peak", 0)

-self.free_memory, self.total_memory = torch.cuda.mem_get_info()
+self.free_memory, self.total_memory = current_platform.mem_get_info()
self.cuda_memory = self.total_memory - self.free_memory

# torch.cuda.memory_reserved() is how many bytes
# PyTorch gets from cuda (by calling cudaMalloc, etc.)
# this is used to measure the non-torch memory usage
-self.torch_memory = torch.cuda.memory_reserved()
+self.torch_memory = current_platform.memory_reserved()

self.non_torch_memory = self.cuda_memory - self.torch_memory
self.timestamp = time.time()
@@ -2658,9 +2660,11 @@ def memory_profiling(

The increase of `non_torch_memory` from creating the current vLLM instance until after profiling to get (c.).
""" # noqa
+from vllm.platforms import current_platform
+
gc.collect()
-torch.cuda.empty_cache()
-torch.cuda.reset_peak_memory_stats()
+current_platform.empty_cache()
+current_platform.reset_peak_memory_stats()

result = MemoryProfilingResult()

@@ -2673,7 +2677,7 @@ def memory_profiling(
yield result

gc.collect()
-torch.cuda.empty_cache()
+current_platform.empty_cache()

result.after_profile.measure()

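As a quick illustration of what measure() computes above (the numbers here are made up, not taken from the PR):

# Illustrative numbers only; shows how the snapshot decomposes device memory.
GiB = 1024 ** 3
free_memory, total_memory = 60 * GiB, 80 * GiB   # current_platform.mem_get_info()
torch_memory = 15 * GiB                          # current_platform.memory_reserved()

cuda_memory = total_memory - free_memory         # 20 GiB in use on the device
non_torch_memory = cuda_memory - torch_memory    # 5 GiB not owned by PyTorch (NCCL buffers, graphs, etc.)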
13 changes: 10 additions & 3 deletions vllm/v1/worker/gpu_model_runner.py
@@ -346,12 +346,19 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
def _init_device_properties(self) -> None:
"""Initialize attributes from torch.cuda.get_device_properties
"""
-self.device_properties = torch.cuda.get_device_properties(self.device)
-self.num_sms = self.device_properties.multi_processor_count
+from vllm.platforms import current_platform
+
+if current_platform.is_cuda():
+    self.device_properties = torch.cuda.get_device_properties(
+        self.device)
+    self.num_sms = self.device_properties.multi_processor_count
+else:
+    self.num_sms = 0

# Note: used for model runner override.
def _sync_device(self) -> None:
-torch.cuda.synchronize()
+from vllm.platforms import current_platform
+current_platform.synchronize()

def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
"""Update the cached states and the persistent batch with the scheduler
16 changes: 9 additions & 7 deletions vllm/v1/worker/gpu_worker.py
@@ -79,7 +79,7 @@ def __init__(
self.profiler = None

def sleep(self, level: int = 1) -> None:
-free_bytes_before_sleep = torch.cuda.mem_get_info()[0]
+free_bytes_before_sleep = current_platform.mem_get_info()[0]

# Save the buffers before level 2 sleep
if level == 2:
@@ -91,7 +91,7 @@ def sleep(self, level: int = 1) -> None:

allocator = CuMemAllocator.get_instance()
allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple())
-free_bytes_after_sleep, total = torch.cuda.mem_get_info()
+free_bytes_after_sleep, total = current_platform.mem_get_info()
freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
used_bytes = total - free_bytes_after_sleep
assert freed_bytes >= 0, "Memory usage increased after sleeping."
@@ -118,7 +118,8 @@ def initialize_cache(self, num_gpu_blocks: int,
self.cache_config.num_cpu_blocks = num_cpu_blocks

def init_device(self):
-if self.device_config.device.type == "cuda":
+if self.device_config.device.type == "cuda" or \
+        self.device_config.device.type == "xpu":
Review comment (Contributor) on lines +121 to +122, severity: high

This condition hardcodes device types 'cuda' and 'xpu', which is not easily extensible to other platforms like ROCm. The environment variables being set within this block (TORCH_NCCL_AVOID_RECORD_STREAMS and NCCL_ASYNC_ERROR_HANDLING) are specific to NCCL, which is used by the CUDA backend. It's not clear if these are applicable or correct for the XPU backend, which uses ccl or xccl. Applying NCCL-specific workarounds to other platforms could lead to unexpected behavior or bugs.

A better approach would be to abstract this platform-specific setup into the Platform classes. For example, you could add an on_worker_init() method to the Platform interface and implement it in CudaPlatform and XpuPlatform with their respective environment variable settings. This would make the code more modular, maintainable, and less prone to errors when adding new hardware backends.
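
To make the suggestion concrete, here is a minimal, self-contained sketch of that abstraction. It is illustrative only and not part of the PR: on_worker_init() is the reviewer's proposed hook name, the classes are simplified stand-ins for vLLM's real Platform classes, and the XPU body is a placeholder, since the correct ccl/xccl setup is exactly what the comment says is unclear.

# Hypothetical sketch of the reviewer's suggestion; not part of this PR.
import os


class Platform:
    @classmethod
    def on_worker_init(cls) -> None:
        """Per-platform setup run from Worker.init_device(); no-op by default."""


class CudaPlatform(Platform):
    @classmethod
    def on_worker_init(cls) -> None:
        # NCCL-specific workarounds stay with the CUDA backend.
        os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
        os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)


class XpuPlatform(Platform):
    @classmethod
    def on_worker_init(cls) -> None:
        # Placeholder: whatever ccl/xccl-specific setup XPU actually requires.
        pass

Worker.init_device() would then call current_platform.on_worker_init() instead of branching on the device type string.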

# torch.distributed.all_reduce does not free the input tensor until
# the synchronization point. This causes the memory usage to grow
# as the number of all_reduce calls increases. This env var disables
@@ -129,12 +130,13 @@ def init_device(self):

# This env var set by Ray causes exceptions with graph building.
os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
-self.device = torch.device(f"cuda:{self.local_rank}")
+self.device = torch.device(
+    f"{current_platform.device_name}:{self.local_rank}")
current_platform.set_device(self.device)

_check_if_gpu_supports_dtype(self.model_config.dtype)
gc.collect()
-torch.cuda.empty_cache()
+current_platform.empty_cache()

# take current memory snapshot
self.init_snapshot = MemorySnapshot()
@@ -198,8 +200,8 @@ def determine_available_memory(self) -> int:
You may limit the usage of GPU memory
by adjusting the `gpu_memory_utilization` parameter.
"""
-torch.cuda.empty_cache()
-torch.cuda.reset_peak_memory_stats()
+current_platform.empty_cache()
+current_platform.reset_peak_memory_stats()
GiB = lambda b: b / GiB_bytes

# Execute a forward pass with dummy inputs to profile the memory usage
33 changes: 0 additions & 33 deletions vllm/v1/worker/xpu_model_runner.py

This file was deleted.
