From b8e8567799233e5ab8008984ab6e0d74304f660e Mon Sep 17 00:00:00 2001 From: Kunshang Ji Date: Thu, 26 Jun 2025 05:29:16 +0800 Subject: [PATCH 1/2] refactor platform Signed-off-by: Kunshang Ji --- vllm/platforms/cuda.py | 24 ++++++++++++++++++++++++ vllm/platforms/interface.py | 24 ++++++++++++++++++++++++ vllm/platforms/xpu.py | 26 +++++++++++++++++++++++++- vllm/utils/__init__.py | 16 ++++++++++------ vllm/v1/worker/gpu_model_runner.py | 13 ++++++++++--- vllm/v1/worker/gpu_worker.py | 16 +++++++++------- 6 files changed, 102 insertions(+), 17 deletions(-) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index b53d7e71a03..fede1c6b488 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -422,6 +422,30 @@ def stateless_init_device_torch_dist_pg( def device_count(cls) -> int: return cuda_device_count_stateless() + @classmethod + def empty_cache(cls, ): + torch.cuda.empty_cache() + + @classmethod + def reset_peak_memory_stats(cls): + torch.cuda.reset_peak_memory_stats() + + @classmethod + def mem_get_info(cls): + return torch.cuda.mem_get_info() + + @classmethod + def memory_stats(cls): + return torch.cuda.memory_stats() + + @classmethod + def memory_reserved(cls): + return torch.cuda.memory_reserved() + + @classmethod + def synchronize(cls): + return torch.cuda.synchronize() + # NVML utils # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`, diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index d3060685e98..3bbcb18fad5 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -548,6 +548,30 @@ def stateless_init_device_torch_dist_pg( """ raise RuntimeError(f"Unsupported torch distributed backend: {backend}") + @classmethod + def empty_cache(cls, ): + raise NotImplementedError + + @classmethod + def reset_peak_memory_stats(cls): + raise NotImplementedError + + @classmethod + def mem_get_info(cls): + raise NotImplementedError + + @classmethod + def memory_stats(cls): + raise NotImplementedError + + @classmethod + def memory_reserved(cls): + raise NotImplementedError + + @classmethod + def synchronize(cls): + torch.accelerator.synchronize() + class UnspecifiedPlatform(Platform): _enum = PlatformEnum.UNSPECIFIED diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index fb69ed36af0..eb6be153ee2 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -116,7 +116,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: # check and update parallel config parallel_config = vllm_config.parallel_config - parallel_config.worker_cls = "vllm.v1.worker.xpu_worker.XPUWorker" + parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker" if parallel_config.distributed_executor_backend is None: if parallel_config.world_size > 1: @@ -195,3 +195,27 @@ def supports_v1(cls, model_config: ModelConfig) -> bool: @classmethod def device_count(cls) -> int: return torch.xpu.device_count() + + @classmethod + def empty_cache(cls, ): + torch.xpu.empty_cache() + + @classmethod + def reset_peak_memory_stats(cls): + torch.xpu.reset_peak_memory_stats() + + @classmethod + def mem_get_info(cls): + return torch.xpu.mem_get_info() + + @classmethod + def memory_stats(cls): + return torch.xpu.memory_stats() + + @classmethod + def memory_reserved(cls): + return torch.xpu.memory_reserved() + + @classmethod + def synchronize(cls): + return torch.xpu.synchronize() diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index cf7320a19e4..d4a4421cd03 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ 
-2557,16 +2557,18 @@ def measure(self): # After `torch.cuda.reset_peak_memory_stats()`, # `torch.cuda.memory_reserved()` will keep growing, and only shrink # when we call `torch.cuda.empty_cache()` or OOM happens. - self.torch_peak = torch.cuda.memory_stats().get( + from vllm.platforms import current_platform + + self.torch_peak = current_platform.memory_stats().get( "allocated_bytes.all.peak", 0) - self.free_memory, self.total_memory = torch.cuda.mem_get_info() + self.free_memory, self.total_memory = current_platform.mem_get_info() self.cuda_memory = self.total_memory - self.free_memory # torch.cuda.memory_reserved() is how many bytes # PyTorch gets from cuda (by calling cudaMalloc, etc.) # this is used to measure the non-torch memory usage - self.torch_memory = torch.cuda.memory_reserved() + self.torch_memory = current_platform.memory_reserved() self.non_torch_memory = self.cuda_memory - self.torch_memory self.timestamp = time.time() @@ -2658,9 +2660,11 @@ def memory_profiling( The increase of `non_torch_memory` from creating the current vLLM instance until after profiling to get (c.). """ # noqa + from vllm.platforms import current_platform + gc.collect() - torch.cuda.empty_cache() - torch.cuda.reset_peak_memory_stats() + current_platform.empty_cache() + current_platform.reset_peak_memory_stats() result = MemoryProfilingResult() @@ -2673,7 +2677,7 @@ def memory_profiling( yield result gc.collect() - torch.cuda.empty_cache() + current_platform.empty_cache() result.after_profile.measure() diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index ef03626cf14..615c99cab3c 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -346,12 +346,19 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: def _init_device_properties(self) -> None: """Initialize attributes from torch.cuda.get_device_properties """ - self.device_properties = torch.cuda.get_device_properties(self.device) - self.num_sms = self.device_properties.multi_processor_count + from vllm.platforms import current_platform + + if current_platform.is_cuda(): + self.device_properties = torch.cuda.get_device_properties( + self.device) + self.num_sms = self.device_properties.multi_processor_count + else: + self.num_sms = 0 # Note: used for model runner override. def _sync_device(self) -> None: - torch.cuda.synchronize() + from vllm.platforms import current_platform + current_platform.synchronize() def _update_states(self, scheduler_output: "SchedulerOutput") -> None: """Update the cached states and the persistent batch with the scheduler diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 916052ca5eb..d907e1d2b16 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -79,7 +79,7 @@ def __init__( self.profiler = None def sleep(self, level: int = 1) -> None: - free_bytes_before_sleep = torch.cuda.mem_get_info()[0] + free_bytes_before_sleep = current_platform.mem_get_info()[0] # Save the buffers before level 2 sleep if level == 2: @@ -91,7 +91,7 @@ def sleep(self, level: int = 1) -> None: allocator = CuMemAllocator.get_instance() allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple()) - free_bytes_after_sleep, total = torch.cuda.mem_get_info() + free_bytes_after_sleep, total = current_platform.mem_get_info() freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep used_bytes = total - free_bytes_after_sleep assert freed_bytes >= 0, "Memory usage increased after sleeping." 
@@ -118,7 +118,8 @@ def initialize_cache(self, num_gpu_blocks: int, self.cache_config.num_cpu_blocks = num_cpu_blocks def init_device(self): - if self.device_config.device.type == "cuda": + if self.device_config.device.type == "cuda" or \ + self.device_config.device.type == "xpu": # torch.distributed.all_reduce does not free the input tensor until # the synchronization point. This causes the memory usage to grow # as the number of all_reduce calls increases. This env var disables @@ -129,12 +130,13 @@ def init_device(self): # This env var set by Ray causes exceptions with graph building. os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None) - self.device = torch.device(f"cuda:{self.local_rank}") + self.device = torch.device( + f"{current_platform.device_name}:{self.local_rank}") current_platform.set_device(self.device) _check_if_gpu_supports_dtype(self.model_config.dtype) gc.collect() - torch.cuda.empty_cache() + current_platform.empty_cache() # take current memory snapshot self.init_snapshot = MemorySnapshot() @@ -198,8 +200,8 @@ def determine_available_memory(self) -> int: You may limit the usage of GPU memory by adjusting the `gpu_memory_utilization` parameter. """ - torch.cuda.empty_cache() - torch.cuda.reset_peak_memory_stats() + current_platform.empty_cache() + current_platform.reset_peak_memory_stats() GiB = lambda b: b / GiB_bytes # Execute a forward pass with dummy inputs to profile the memory usage From b3e6666b3ed5f0433ef14cf1726a93a8a570139b Mon Sep 17 00:00:00 2001 From: Kunshang Ji Date: Thu, 26 Jun 2025 05:58:31 +0800 Subject: [PATCH 2/2] remove xpu worker/runner Signed-off-by: Kunshang Ji --- vllm/v1/worker/xpu_model_runner.py | 33 ------ vllm/v1/worker/xpu_worker.py | 165 ----------------------------- 2 files changed, 198 deletions(-) delete mode 100644 vllm/v1/worker/xpu_model_runner.py delete mode 100644 vllm/v1/worker/xpu_worker.py diff --git a/vllm/v1/worker/xpu_model_runner.py b/vllm/v1/worker/xpu_model_runner.py deleted file mode 100644 index 4cedc913c2a..00000000000 --- a/vllm/v1/worker/xpu_model_runner.py +++ /dev/null @@ -1,33 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING - -import torch - -from vllm.config import VllmConfig -from vllm.logger import init_logger -from vllm.v1.worker.gpu_model_runner import GPUModelRunner - -if TYPE_CHECKING: - pass - -logger = init_logger(__name__) - - -class XPUModelRunner(GPUModelRunner): - """A model runner for XPU devices.""" - - def __init__( - self, - vllm_config: VllmConfig, - device: torch.device, - ): - super().__init__(vllm_config, device) - # FIXME: To be verified. 
- self.cascade_attn_enabled = False - - def _init_device_properties(self) -> None: - pass - - def _sync_device(self) -> None: - torch.xpu.synchronize() diff --git a/vllm/v1/worker/xpu_worker.py b/vllm/v1/worker/xpu_worker.py deleted file mode 100644 index da271b2159a..00000000000 --- a/vllm/v1/worker/xpu_worker.py +++ /dev/null @@ -1,165 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import os - -import torch -import torch.distributed - -import vllm.envs as envs -from vllm.config import VllmConfig -from vllm.logger import init_logger -from vllm.model_executor import set_random_seed -from vllm.platforms import current_platform -from vllm.v1.worker.gpu_worker import (Worker, - init_worker_distributed_environment) -from vllm.v1.worker.xpu_model_runner import XPUModelRunner - -logger = init_logger(__name__) - - -class XPUWorker(Worker): - """A XPU worker class.""" - - def __init__( - self, - vllm_config: VllmConfig, - local_rank: int, - rank: int, - distributed_init_method: str, - is_driver_worker: bool = False, - ): - super().__init__(vllm_config, local_rank, rank, - distributed_init_method, is_driver_worker) - device_config = self.device_config - assert device_config.device_type == "xpu" - assert current_platform.is_xpu() - - # Torch profiler. Enabled and configured through env vars: - # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace - if envs.VLLM_TORCH_PROFILER_DIR: - torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR - logger.info("Profiling enabled. Traces will be saved to: %s", - torch_profiler_trace_dir) - self.profiler = torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.XPU, - ], - with_stack=True, - on_trace_ready=torch.profiler.tensorboard_trace_handler( - torch_profiler_trace_dir, use_gzip=True)) - else: - self.profiler = None - - # we provide this function due to `torch.xpu.mem_get_info()` doesn't - # return correct free_gpu_memory on intel client GPU. We need to - # calculate/estiamte it. - def xpu_get_mem_info(self): - if current_platform.is_data_center_gpu(): - return torch.xpu.mem_get_info() - else: - _, total_gpu_memory = torch.xpu.mem_get_info() - # FIXME: memory_allocated() doesn't count non-torch allocations, - # and we don't have any API to get it. so we mark it as 128MB. - used_memory = torch.xpu.memory_allocated() - non_torch_allocations = 128 * 1024 * 1024 - free_gpu_memory = total_gpu_memory - (used_memory + - non_torch_allocations) - return free_gpu_memory, total_gpu_memory - - @torch.inference_mode() - def determine_available_memory(self) -> int: - """Profiles the peak memory usage of the model to determine how many - KV blocks may be allocated without OOMs. - The engine will first conduct a profiling of the existing memory usage. - Then, it calculate the maximum possible number of GPU and CPU blocks - that can be allocated with the remaining free memory. - .. tip:: - You may limit the usage of GPU memory - by adjusting the `gpu_memory_utilization` parameter. - """ - # Profile the memory usage of the model and get the maximum number of - # cache blocks that can be allocated with the remaining free memory. 
- torch.xpu.empty_cache() - torch.xpu.reset_peak_memory_stats() - - free_gpu_memory, total_gpu_memory = torch.xpu.mem_get_info() - current_allocated_bytes = torch.xpu.memory_allocated() - msg = ("Before memory profiling run, " - f"total GPU memory: {total_gpu_memory / 1024**2:.2f} MB, " - f"model load takes {current_allocated_bytes / 1024**2:.2f} MB, " - f"free gpu memory is {free_gpu_memory / 1024**2:.2f} MB.") - logger.info(msg) - # Execute a forward pass with dummy inputs to profile the memory usage - # of the model. - self.model_runner.profile_run() - - free_gpu_memory, _ = self.xpu_get_mem_info() - # NOTE(woosuk): Here we assume that the other processes using the same - # GPU did not change their memory usage during the profiling. - assert self.init_gpu_memory > free_gpu_memory, ( - "Error in memory profiling. " - f"Initial free memory {self.init_gpu_memory}, current free memory" - f" {free_gpu_memory}. This happens when the GPU memory was " - "not properly cleaned up before initializing the vLLM instance.") - - # Get the peak memory allocation recorded by torch - peak_memory = torch.xpu.memory_stats()["allocated_bytes.all.peak"] - - torch.xpu.empty_cache() - torch_allocated_bytes = torch.xpu.memory_stats( - )["allocated_bytes.all.current"] - total_allocated_bytes = self.xpu_get_mem_info( - )[1] - self.xpu_get_mem_info()[0] - - non_torch_allocations = total_allocated_bytes - torch_allocated_bytes - if non_torch_allocations > 0: - peak_memory += non_torch_allocations - available_kv_cache_memory = ( - total_gpu_memory * self.cache_config.gpu_memory_utilization - - peak_memory) - - msg = ("After memory profiling run, " - f"peak memory usage is {peak_memory / 1024**2:.2f} MB," - f"torch mem is {torch_allocated_bytes / 1024**2:.2f} MB, " - f"non-torch mem is {non_torch_allocations / 1024**2:.2f} MB, " - f"free gpu memory is {free_gpu_memory / 1024**2:.2f} MB.") - logger.info(msg) - - return int(available_kv_cache_memory) - - def init_device(self): - if self.device_config.device.type == "xpu" and current_platform.is_xpu( - ): - self.device = torch.device(f"xpu:{self.local_rank}") - current_platform.set_device(self.device) - torch.xpu.empty_cache() - self.init_gpu_memory = torch.xpu.get_device_properties( - self.local_rank).total_memory - else: - raise RuntimeError( - f"Not support device type: {self.device_config.device}") - - ENV_CCL_ZE_IPC_EXCHANGE = os.getenv("CCL_ZE_IPC_EXCHANGE", "drmfd") - ENV_CCL_ATL_TRANSPORT = os.getenv("CCL_ATL_TRANSPORT", "ofi") - ENV_LOCAL_WORLD_SIZE = os.getenv("LOCAL_WORLD_SIZE", - str(self.parallel_config.world_size)) - os.environ["CCL_ZE_IPC_EXCHANGE"] = ENV_CCL_ZE_IPC_EXCHANGE - os.environ["CCL_ATL_TRANSPORT"] = ENV_CCL_ATL_TRANSPORT - os.environ["LOCAL_WORLD_SIZE"] = ENV_LOCAL_WORLD_SIZE - os.environ["LOCAL_RANK"] = str(self.local_rank) - - init_worker_distributed_environment(self.vllm_config, self.rank, - self.distributed_init_method, - self.local_rank, - current_platform.dist_backend) - - # global all_reduce needed for overall oneccl warm up - torch.distributed.all_reduce(torch.zeros(1).xpu()) - - # Set random seed. - set_random_seed(self.model_config.seed) - - # Construct the model runner - self.model_runner = XPUModelRunner( # type: ignore - self.vllm_config, self.device)
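
Reviewer note (not part of the patches): with this refactor, call sites no longer
touch torch.cuda.* / torch.xpu.* for memory queries and instead go through the
classmethods added to Platform in patch 1. A minimal sketch of what that buys
downstream code, assuming only the method names introduced in
vllm/platforms/interface.py; the report_memory helper itself is illustrative and
does not exist in the tree:

    # Illustrative, device-agnostic memory report built on the new platform hooks.
    from vllm.platforms import current_platform

    def report_memory(tag: str) -> dict[str, float]:
        """Summarize device memory in GiB without referencing torch.cuda/torch.xpu."""
        GiB = 1024 ** 3
        current_platform.synchronize()  # drain in-flight kernels before sampling
        free_bytes, total_bytes = current_platform.mem_get_info()
        reserved_bytes = current_platform.memory_reserved()
        peak_bytes = current_platform.memory_stats().get(
            "allocated_bytes.all.peak", 0)
        summary = {
            "free_gib": free_bytes / GiB,
            "total_gib": total_bytes / GiB,
            "torch_reserved_gib": reserved_bytes / GiB,
            "torch_peak_gib": peak_bytes / GiB,
        }
        print(f"[{tag}] " + ", ".join(f"{k}={v:.2f}" for k, v in summary.items()))
        return summary

The same helper runs unmodified on CUDA and XPU, which is the point of routing
memory queries through current_platform instead of a hard-coded torch.cuda.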
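
On the provider side, the series also fixes the override surface for future
backends: the interface defaults raise NotImplementedError (except synchronize),
so a platform only needs to override the six classmethods, mirroring cuda.py and
xpu.py above. The class below is hypothetical — the name and the torch.cuda
calls are placeholders for whatever device APIs a real backend exposes, and a
real platform would register its own PlatformEnum entry rather than reuse
UNSPECIFIED:

    # Hypothetical platform sketch showing only the override surface added in patch 1.
    import torch

    from vllm.platforms.interface import Platform, PlatformEnum

    class MyAcceleratorPlatform(Platform):
        _enum = PlatformEnum.UNSPECIFIED  # placeholder enum value

        @classmethod
        def empty_cache(cls):
            torch.cuda.empty_cache()  # stand-in for the backend's cache flush

        @classmethod
        def reset_peak_memory_stats(cls):
            torch.cuda.reset_peak_memory_stats()

        @classmethod
        def mem_get_info(cls):
            return torch.cuda.mem_get_info()  # (free_bytes, total_bytes)

        @classmethod
        def memory_stats(cls):
            return torch.cuda.memory_stats()

        @classmethod
        def memory_reserved(cls):
            return torch.cuda.memory_reserved()

        @classmethod
        def synchronize(cls):
            torch.cuda.synchronize()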