
[1/N] Refactor platform API to reduce torch.cuda call #20751


Open · wants to merge 6 commits into main
5 changes: 3 additions & 2 deletions vllm/distributed/eplb/eplb_state.py
@@ -37,6 +37,7 @@
from vllm.distributed.parallel_state import get_ep_group, get_node_count
from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import MixtureOfExperts
from vllm.platforms import current_platform

from .rebalance_algo import rebalance_experts
from .rebalance_execute import rearrange_expert_weights_inplace
@@ -348,7 +349,7 @@ def rearrange(self,
time_start = None
is_main_rank = ep_rank == 0
if is_main_rank:
torch.cuda.synchronize()
current_platform.synchronize()
Contributor:

Here and below: the same can be done with the standard PyTorch API, available starting from torch 2.6:

Suggested change

Current:
            current_platform.synchronize()
Suggested:
            torch.accelerator.synchronize()

Are there actual benefits to defining a similar device abstraction at the vLLM level? Using the standard PyTorch API would help keep the vLLM code base leaner. See https://docs.pytorch.org/docs/stable/generated/torch.accelerator.synchronize.html#torch.accelerator.synchronize

Contributor Author:

I feel we'd better implement the base-class Platform::synchronize() method using torch.accelerator.synchronize() and leave it to individual platforms to override it in case there are any tricks, like pytorch/pytorch#155668

Contributor:

This might be a good compromise. Which torch version does vLLM target across device backends? Note that torch.accelerator is available from 2.6. If vLLM needs to support a wider torch range, that would be a clear reason to abstract this API at the vLLM level. Also, if you see any missing APIs in torch.accelerator, please give feedback; we are willing to take care of that at the PyTorch level.
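For illustration, a minimal sketch of the base-class default proposed above, assuming torch >= 2.6 is available; the version guard and class names are placeholders for this discussion, not code from this PR:

    import torch

    class Platform:  # stand-in for vllm.platforms.interface.Platform
        @classmethod
        def synchronize(cls):
            # Default: use the device-agnostic API added in torch 2.6.
            # Platforms with backend-specific quirks can still override this.
            if hasattr(torch, "accelerator") and torch.accelerator.is_available():
                torch.accelerator.synchronize()
            else:
                raise NotImplementedError(
                    "torch.accelerator is unavailable on this torch version; "
                    "the platform must override synchronize().")

    class CudaPlatform(Platform):  # stand-in for vllm.platforms.cuda.CudaPlatform
        @classmethod
        def synchronize(cls):
            # CUDA-specific override, e.g. to work around issues like
            # pytorch/pytorch#155668.
            torch.cuda.synchronize()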

time_start = time.perf_counter()
logger.info("Rearranging experts %s...",
"(profile)" if is_profile else "")
@@ -423,7 +424,7 @@ def rearrange(self,

if is_main_rank:
assert time_start is not None
torch.cuda.synchronize()
current_platform.synchronize()
time_end = time.perf_counter()
logger.info(
"Rearranged experts%sin %.2f seconds.",
4 changes: 3 additions & 1 deletion vllm/distributed/eplb/rebalance_execute.py
@@ -13,6 +13,8 @@
from torch.distributed import (P2POp, ProcessGroup, all_gather,
batch_isend_irecv, get_global_rank)

from vllm.platforms import current_platform


def idx_local_to_global(
local_idx: int,
@@ -292,7 +294,7 @@ def rearrange_expert_weights_inplace(
for layer in range(num_moe_layers):
# NOTE(bowen): We need this synchronize to run, but I don't know why.
# If you figure out the reason, please let me know -- thank you!
torch.cuda.synchronize()
current_platform.synchronize()
shuffle_layer(
num_local_physical_experts,
ep_rank,
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/fused_moe/layer.py
@@ -297,7 +297,7 @@ def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
and (weight.stride(-2) * weight.element_size()) % 512 == 0):
num_pad = 256 // weight.element_size()
weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
torch.cuda.empty_cache()
current_platform.empty_cache()
return weight

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/quantization/fp8.py
@@ -324,7 +324,7 @@ def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
and (weight.stride(-2) * weight.element_size()) % 512 == 0):
num_pad = 256 // weight.element_size()
weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
torch.cuda.empty_cache()
current_platform.empty_cache()
return weight

def process_weights_after_loading(self, layer: Module) -> None:
2 changes: 1 addition & 1 deletion vllm/model_executor/model_loader/bitsandbytes_loader.py
@@ -755,7 +755,7 @@ def load_weights(self, model: nn.Module,
**stacked_quant_state_dict
}
self._bind_quant_states_to_params(model, stacked_quant_state_dict)
torch.cuda.empty_cache()
current_platform.empty_cache()

def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(model_config.model, model_config.revision)
27 changes: 27 additions & 0 deletions vllm/platforms/cpu.py
@@ -275,3 +275,30 @@ def default_v1(cls, model_config) -> bool:
arch = cls.get_cpu_architecture()
return (cls.supports_v1(model_config) and arch
in (CpuArchEnum.X86, CpuArchEnum.POWERPC, CpuArchEnum.ARM))

@classmethod
def empty_cache(cls):
pass

@classmethod
def reset_peak_memory_stats(cls):
pass

@classmethod
def mem_get_info(cls):
# FIXME: impl
return None
Comment on lines +288 to +290
Contributor:

medium

The mem_get_info method should be implemented for the CPU platform to provide memory information. Returning None might lead to unexpected behavior or errors in other parts of the code that rely on this information. This is also applicable to the other FIXME comments in this file.
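For example, a possible CPU implementation could report host RAM in the same (free, total) shape that torch.cuda.mem_get_info() returns. This is only a sketch and assumes psutil is an acceptable dependency here; it is not part of this PR:

    import psutil

    class CPUPlatform:  # minimal stand-in for vllm.platforms.cpu.CpuPlatform
        @classmethod
        def mem_get_info(cls):
            # Report host RAM as (free, total) bytes, matching the shape
            # of torch.cuda.mem_get_info().
            vm = psutil.virtual_memory()
            return vm.available, vm.total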


@classmethod
def memory_stats(cls):
# FIXME: impl
return None

@classmethod
def memory_reserved(cls):
# FIXME: impl
return None
Comment on lines +287 to +300
Contributor:

critical

The stubs for mem_get_info, memory_stats, and memory_reserved return None, which will cause runtime errors in vllm.utils.memory_profiling when running on the CPU platform.

  • mem_get_info() returning None will cause a TypeError on unpacking.
  • memory_stats() returning None will cause an AttributeError when .get() is called.
  • memory_reserved() returning None will cause a TypeError in arithmetic operations.

These methods should return sensible default values for the CPU platform to prevent crashes. For example:

  • mem_get_info: (0, 0)
  • memory_stats: {}
  • memory_reserved: 0
Suggested change

Current:
    @classmethod
    def mem_get_info(cls):
        # FIXME: impl
        return None

    @classmethod
    def memory_stats(cls):
        # FIXME: impl
        return None

    @classmethod
    def memory_reserved(cls):
        # FIXME: impl
        return None

Suggested:
    @classmethod
    def mem_get_info(cls):
        # FIXME: impl
        # Returning (0, 0) as a placeholder for (free, total) memory.
        return (0, 0)

    @classmethod
    def memory_stats(cls):
        # FIXME: impl
        # Returning an empty dict as a placeholder for memory stats.
        return {}

    @classmethod
    def memory_reserved(cls):
        # FIXME: impl
        # Returning 0 as a placeholder for reserved memory.
        return 0


@classmethod
def synchronize(cls):
pass
24 changes: 24 additions & 0 deletions vllm/platforms/cuda.py
@@ -443,6 +443,30 @@ def stateless_init_device_torch_dist_pg(
def device_count(cls) -> int:
return cuda_device_count_stateless()

@classmethod
def empty_cache(cls):
torch.cuda.empty_cache()

@classmethod
def reset_peak_memory_stats(cls):
torch.cuda.reset_peak_memory_stats()

@classmethod
def mem_get_info(cls):
return torch.cuda.mem_get_info()

@classmethod
def memory_stats(cls):
return torch.cuda.memory_stats()

@classmethod
def memory_reserved(cls):
return torch.cuda.memory_reserved()

@classmethod
def synchronize(cls):
torch.cuda.synchronize()


# NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
24 changes: 24 additions & 0 deletions vllm/platforms/interface.py
@@ -549,6 +549,30 @@ def stateless_init_device_torch_dist_pg(
"""
raise RuntimeError(f"Unsupported torch distributed backend: {backend}")

@classmethod
def empty_cache(cls):
raise NotImplementedError

@classmethod
def reset_peak_memory_stats(cls):
raise NotImplementedError

@classmethod
def mem_get_info(cls):
raise NotImplementedError

@classmethod
def memory_stats(cls):
raise NotImplementedError

@classmethod
def memory_reserved(cls):
raise NotImplementedError

@classmethod
def synchronize(cls):
raise NotImplementedError


class UnspecifiedPlatform(Platform):
_enum = PlatformEnum.UNSPECIFIED
24 changes: 24 additions & 0 deletions vllm/platforms/rocm.py
@@ -459,3 +459,27 @@ def stateless_init_device_torch_dist_pg(
@classmethod
def device_count(cls) -> int:
return cuda_device_count_stateless()

@classmethod
def empty_cache(cls):
torch.cuda.empty_cache()

@classmethod
def reset_peak_memory_stats(cls):
torch.cuda.reset_peak_memory_stats()

@classmethod
def mem_get_info(cls):
return torch.cuda.mem_get_info()

@classmethod
def memory_stats(cls):
return torch.cuda.memory_stats()

@classmethod
def memory_reserved(cls):
return torch.cuda.memory_reserved()

@classmethod
def synchronize(cls):
torch.cuda.synchronize()
37 changes: 37 additions & 0 deletions vllm/platforms/xpu.py
@@ -194,3 +194,40 @@ def supports_v1(cls, model_config: ModelConfig) -> bool:
@classmethod
def device_count(cls) -> int:
return torch.xpu.device_count()

@classmethod
def empty_cache(cls):
torch.xpu.empty_cache()

@classmethod
def reset_peak_memory_stats(cls):
torch.xpu.reset_peak_memory_stats()

@classmethod
def mem_get_info(cls):
if cls.is_data_center_gpu():
return torch.xpu.mem_get_info()
else:
# We take this branch because `torch.xpu.mem_get_info()` doesn't
# return the correct free_gpu_memory on Intel client GPUs, so we need
# to calculate/estimate it ourselves.
_, total_gpu_memory = torch.xpu.mem_get_info()
# FIXME: memory_allocated() doesn't count non-torch allocations,
# and we don't have any API to get them, so we assume 128 MB.
used_memory = torch.xpu.memory_allocated()
non_torch_allocations = 128 * 1024 * 1024
free_gpu_memory = total_gpu_memory - (used_memory +
non_torch_allocations)
Comment on lines +218 to +220
Contributor:

medium

This hardcoded value 128 * 1024 * 1024 is a "magic number". To improve readability and maintainability, it's better to define it as a constant. Using an uppercase variable name is a common convention for constants in Python.

Suggested change

Current:
            non_torch_allocations = 128 * 1024 * 1024
            free_gpu_memory = total_gpu_memory - (used_memory +
                                                  non_torch_allocations)
Suggested:
            NON_TORCH_ALLOCATIONS_BYTES = 128 * 1024 * 1024
            free_gpu_memory = total_gpu_memory - (used_memory +
                                                  NON_TORCH_ALLOCATIONS_BYTES)

return free_gpu_memory, total_gpu_memory

@classmethod
def memory_stats(cls):
return torch.xpu.memory_stats()

@classmethod
def memory_reserved(cls):
return torch.xpu.memory_reserved()

@classmethod
def synchronize(cls):
torch.xpu.synchronize()
26 changes: 15 additions & 11 deletions vllm/utils/__init__.py
@@ -2553,20 +2553,22 @@ def __post_init__(self):

def measure(self):
# we measure the torch peak memory usage via allocated_bytes,
# rather than `torch.cuda.memory_reserved()` .
# After `torch.cuda.reset_peak_memory_stats()`,
# `torch.cuda.memory_reserved()` will keep growing, and only shrink
# when we call `torch.cuda.empty_cache()` or OOM happens.
self.torch_peak = torch.cuda.memory_stats().get(
# rather than `current_platform.memory_reserved()` .
# After `current_platform.reset_peak_memory_stats()`,
# `current_platform.memory_reserved()` will keep growing, and only
# shrink when we call `current_platform.empty_cache()` or OOM happens.
from vllm.platforms import current_platform

self.torch_peak = current_platform.memory_stats().get(
"allocated_bytes.all.peak", 0)

self.free_memory, self.total_memory = torch.cuda.mem_get_info()
self.free_memory, self.total_memory = current_platform.mem_get_info()
self.cuda_memory = self.total_memory - self.free_memory

# torch.cuda.memory_reserved() is how many bytes
# current_platform.memory_reserved() is how many bytes
# PyTorch gets from cuda (by calling cudaMalloc, etc.)
# this is used to measure the non-torch memory usage
self.torch_memory = torch.cuda.memory_reserved()
self.torch_memory = current_platform.memory_reserved()

self.non_torch_memory = self.cuda_memory - self.torch_memory
self.timestamp = time.time()
@@ -2658,9 +2660,11 @@ def memory_profiling(

The increase of `non_torch_memory` from creating the current vLLM instance until after profiling to get (c.).
""" # noqa
from vllm.platforms import current_platform

gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
current_platform.empty_cache()
current_platform.reset_peak_memory_stats()

result = MemoryProfilingResult()

@@ -2673,7 +2677,7 @@ def memory_profiling(
yield result

gc.collect()
torch.cuda.empty_cache()
current_platform.empty_cache()

result.after_profile.measure()

15 changes: 10 additions & 5 deletions vllm/v1/worker/gpu_model_runner.py
@@ -38,6 +38,7 @@
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.platforms import current_platform
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
@@ -345,12 +346,16 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
def _init_device_properties(self) -> None:
"""Initialize attributes from torch.cuda.get_device_properties
"""
self.device_properties = torch.cuda.get_device_properties(self.device)
self.num_sms = self.device_properties.multi_processor_count
if current_platform.is_cuda():
self.device_properties = torch.cuda.get_device_properties(
self.device)
self.num_sms = self.device_properties.multi_processor_count
else:
Comment on lines +352 to +353
Contributor:

high

The code directly calls torch.cuda.synchronize(), but the intention of this PR is to abstract away from CUDA-specific calls. Use current_platform.synchronize() instead.

        current_platform.synchronize()

self.num_sms = None
Comment on lines 348 to +354
Contributor:

medium

The else block is executed when the platform is not CUDA, but torch.cuda.get_device_properties is CUDA-specific. This could lead to errors on other platforms. It's better to use current_platform to get device properties in a platform-agnostic way (see the sketch after this snippet).

        if current_platform.is_cuda():
            self.device_properties = torch.cuda.get_device_properties(
                self.device)
            self.num_sms = self.device_properties.multi_processor_count
        else:
            self.num_sms = None
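One way to push this behind the platform API as well is sketched below, using a hypothetical get_num_sms() hook; the method name and the stand-in classes are placeholders, not part of this PR:

    from typing import Optional

    import torch

    class Platform:  # stand-in for vllm.platforms.interface.Platform
        @classmethod
        def get_num_sms(cls) -> Optional[int]:
            # Platforms without an SM-like concept simply report None.
            return None

    class CudaPlatform(Platform):  # stand-in for vllm.platforms.cuda.CudaPlatform
        @classmethod
        def get_num_sms(cls) -> Optional[int]:
            return torch.cuda.get_device_properties(
                torch.cuda.current_device()).multi_processor_count

    # The runner would then need no CUDA check at all:
    #     self.num_sms = current_platform.get_num_sms()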


# Note: used for model runner override.
def _sync_device(self) -> None:
torch.cuda.synchronize()
current_platform.synchronize()

def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
"""Update the cached states and the persistent batch with the scheduler
@@ -2270,7 +2275,7 @@ def capture_model(self) -> None:
compilation_counter.num_gpu_runner_capture_triggers += 1

start_time = time.perf_counter()
start_free_gpu_memory = torch.cuda.mem_get_info()[0]
start_free_gpu_memory = current_platform.mem_get_info()[0]

# Trigger CUDA graph capture for specific shapes.
# Capture the large shapes first so that the smaller shapes
@@ -2296,7 +2301,7 @@
skip_eplb=True)

end_time = time.perf_counter()
end_free_gpu_memory = torch.cuda.mem_get_info()[0]
end_free_gpu_memory = current_platform.mem_get_info()[0]
elapsed_time = end_time - start_time
cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory
# This usually takes 5~20 seconds.
15 changes: 8 additions & 7 deletions vllm/v1/worker/gpu_worker.py
@@ -83,7 +83,7 @@ def __init__(
def sleep(self, level: int = 1) -> None:
from vllm.device_allocator.cumem import CuMemAllocator

free_bytes_before_sleep = torch.cuda.mem_get_info()[0]
free_bytes_before_sleep = current_platform.mem_get_info()[0]

# Save the buffers before level 2 sleep
if level == 2:
@@ -95,7 +95,7 @@ def sleep(self, level: int = 1) -> None:

allocator = CuMemAllocator.get_instance()
allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple())
free_bytes_after_sleep, total = torch.cuda.mem_get_info()
free_bytes_after_sleep, total = current_platform.mem_get_info()
freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
used_bytes = total - free_bytes_after_sleep
assert freed_bytes >= 0, "Memory usage increased after sleeping."
@@ -135,12 +135,13 @@ def init_device(self):

# This env var set by Ray causes exceptions with graph building.
os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
self.device = torch.device(f"cuda:{self.local_rank}")
self.device = torch.device(
f"{current_platform.device_name}:{self.local_rank}")
current_platform.set_device(self.device)

_check_if_gpu_supports_dtype(self.model_config.dtype)
gc.collect()
torch.cuda.empty_cache()
current_platform.empty_cache()

# take current memory snapshot
self.init_snapshot = MemorySnapshot()
@@ -209,8 +210,8 @@ def determine_available_memory(self) -> int:
You may limit the usage of GPU memory
by adjusting the `gpu_memory_utilization` parameter.
"""
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
current_platform.empty_cache()
current_platform.reset_peak_memory_stats()
GiB = lambda b: b / GiB_bytes

# Execute a forward pass with dummy inputs to profile the memory usage
@@ -285,7 +286,7 @@ def compile_or_warm_up_model(self) -> None:
# sampling related tensors of max possible shape to avoid memory
# fragmentation issue.
# NOTE: This is called after `capture_model` on purpose to prevent
# memory buffers from being cleared by `torch.cuda.empty_cache`.
# memory buffers from being cleared by `current_platform.empty_cache`.
if get_pp_group().is_last_rank:
max_num_reqs = min(self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens)
6 changes: 0 additions & 6 deletions vllm/v1/worker/xpu_model_runner.py
@@ -25,9 +25,3 @@ def __init__(
super().__init__(vllm_config, device)
# FIXME: To be verified.
self.cascade_attn_enabled = False

def _init_device_properties(self) -> None:
self.num_sms = None

def _sync_device(self) -> None:
torch.xpu.synchronize()