diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py
index 6b0a126ca9b..0fb81ff6a6f 100644
--- a/vllm/distributed/eplb/eplb_state.py
+++ b/vllm/distributed/eplb/eplb_state.py
@@ -37,6 +37,7 @@
 from vllm.distributed.parallel_state import get_ep_group, get_node_count
 from vllm.logger import init_logger
 from vllm.model_executor.models.interfaces import MixtureOfExperts
+from vllm.platforms import current_platform
 
 from .rebalance_algo import rebalance_experts
 from .rebalance_execute import rearrange_expert_weights_inplace
@@ -348,7 +349,7 @@ def rearrange(self,
         time_start = None
         is_main_rank = ep_rank == 0
         if is_main_rank:
-            torch.cuda.synchronize()
+            current_platform.synchronize()
             time_start = time.perf_counter()
             logger.info("Rearranging experts %s...",
                         "(profile)" if is_profile else "")
@@ -423,7 +424,7 @@ def rearrange(self,
 
         if is_main_rank:
             assert time_start is not None
-            torch.cuda.synchronize()
+            current_platform.synchronize()
             time_end = time.perf_counter()
             logger.info(
                 "Rearranged experts%sin %.2f seconds.",
diff --git a/vllm/distributed/eplb/rebalance_execute.py b/vllm/distributed/eplb/rebalance_execute.py
index 2ef8587b559..87ba6f1d66f 100644
--- a/vllm/distributed/eplb/rebalance_execute.py
+++ b/vllm/distributed/eplb/rebalance_execute.py
@@ -13,6 +13,8 @@
 from torch.distributed import (P2POp, ProcessGroup, all_gather,
                                batch_isend_irecv, get_global_rank)
 
+from vllm.platforms import current_platform
+
 
 def idx_local_to_global(
     local_idx: int,
@@ -292,7 +294,7 @@ def rearrange_expert_weights_inplace(
     for layer in range(num_moe_layers):
         # NOTE(bowen): We need this synchronize to run, but I don't know why.
         # If you figure out the reason, please let me know -- thank you!
-        torch.cuda.synchronize()
+        current_platform.synchronize()
         shuffle_layer(
             num_local_physical_experts,
             ep_rank,
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index da772c11155..ec8fbe386e3 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -297,7 +297,7 @@ def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
                 and (weight.stride(-2) * weight.element_size()) % 512 == 0):
             num_pad = 256 // weight.element_size()
             weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
-            torch.cuda.empty_cache()
+            current_platform.empty_cache()
         return weight
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 59db3e6c444..f02f6f73a30 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -324,7 +324,7 @@ def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
                 and (weight.stride(-2) * weight.element_size()) % 512 == 0):
             num_pad = 256 // weight.element_size()
             weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
-            torch.cuda.empty_cache()
+            current_platform.empty_cache()
         return weight
 
     def process_weights_after_loading(self, layer: Module) -> None:
diff --git a/vllm/model_executor/model_loader/bitsandbytes_loader.py b/vllm/model_executor/model_loader/bitsandbytes_loader.py
index d22b1e7b67d..24fb2f2039d 100644
--- a/vllm/model_executor/model_loader/bitsandbytes_loader.py
+++ b/vllm/model_executor/model_loader/bitsandbytes_loader.py
@@ -755,7 +755,7 @@ def load_weights(self, model: nn.Module,
             **stacked_quant_state_dict
         }
         self._bind_quant_states_to_params(model, stacked_quant_state_dict)
-        torch.cuda.empty_cache()
+        current_platform.empty_cache()
 
     def download_model(self, model_config: ModelConfig) -> None:
        self._prepare_weights(model_config.model, model_config.revision)
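Note on the hunks above: they only swap direct `torch.cuda.*` calls for the platform-agnostic `current_platform` equivalents; the surrounding EPLB and weight-loading logic is untouched. As a minimal sketch of the timing pattern the EPLB code relies on, assuming only the `current_platform.synchronize()` hook this patch introduces (the helper name `timed` is hypothetical):

```python
import time

from vllm.platforms import current_platform


def timed(fn, *args, **kwargs):
    """Illustrative only: device kernels run asynchronously, so bracket the
    work with synchronize() to make perf_counter() measure the real cost."""
    current_platform.synchronize()
    start = time.perf_counter()
    out = fn(*args, **kwargs)
    current_platform.synchronize()
    return out, time.perf_counter() - start
```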
diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py
index a0aa981f951..1f4fc8e4d39 100644
--- a/vllm/platforms/cpu.py
+++ b/vllm/platforms/cpu.py
@@ -275,3 +275,30 @@ def default_v1(cls, model_config) -> bool:
         arch = cls.get_cpu_architecture()
         return (cls.supports_v1(model_config) and arch in (
             CpuArchEnum.X86, CpuArchEnum.POWERPC, CpuArchEnum.ARM))
+
+    @classmethod
+    def empty_cache(cls):
+        pass
+
+    @classmethod
+    def reset_peak_memory_stats(cls):
+        pass
+
+    @classmethod
+    def mem_get_info(cls):
+        # FIXME: impl
+        return None
+
+    @classmethod
+    def memory_stats(cls):
+        # FIXME: impl
+        return None
+
+    @classmethod
+    def memory_reserved(cls):
+        # FIXME: impl
+        return None
+
+    @classmethod
+    def synchronize(cls):
+        pass
\ No newline at end of file
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index 878f8f77edf..3411fa223ab 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -443,6 +443,30 @@ def stateless_init_device_torch_dist_pg(
     def device_count(cls) -> int:
         return cuda_device_count_stateless()
 
+    @classmethod
+    def empty_cache(cls):
+        torch.cuda.empty_cache()
+
+    @classmethod
+    def reset_peak_memory_stats(cls):
+        torch.cuda.reset_peak_memory_stats()
+
+    @classmethod
+    def mem_get_info(cls):
+        return torch.cuda.mem_get_info()
+
+    @classmethod
+    def memory_stats(cls):
+        return torch.cuda.memory_stats()
+
+    @classmethod
+    def memory_reserved(cls):
+        return torch.cuda.memory_reserved()
+
+    @classmethod
+    def synchronize(cls):
+        torch.cuda.synchronize()
+
 
 # NVML utils
 # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index ae675bcc8d2..59e5f297a8e 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -549,6 +549,30 @@ def stateless_init_device_torch_dist_pg(
         """
         raise RuntimeError(f"Unsupported torch distributed backend: {backend}")
 
+    @classmethod
+    def empty_cache(cls):
+        raise NotImplementedError
+
+    @classmethod
+    def reset_peak_memory_stats(cls):
+        raise NotImplementedError
+
+    @classmethod
+    def mem_get_info(cls):
+        raise NotImplementedError
+
+    @classmethod
+    def memory_stats(cls):
+        raise NotImplementedError
+
+    @classmethod
+    def memory_reserved(cls):
+        raise NotImplementedError
+
+    @classmethod
+    def synchronize(cls):
+        raise NotImplementedError
+
 
 class UnspecifiedPlatform(Platform):
     _enum = PlatformEnum.UNSPECIFIED
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index 04637f5c7aa..ea15b28eed8 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -459,3 +459,27 @@ def stateless_init_device_torch_dist_pg(
     @classmethod
     def device_count(cls) -> int:
         return cuda_device_count_stateless()
+
+    @classmethod
+    def empty_cache(cls):
+        torch.cuda.empty_cache()
+
+    @classmethod
+    def reset_peak_memory_stats(cls):
+        torch.cuda.reset_peak_memory_stats()
+
+    @classmethod
+    def mem_get_info(cls):
+        return torch.cuda.mem_get_info()
+
+    @classmethod
+    def memory_stats(cls):
+        return torch.cuda.memory_stats()
+
+    @classmethod
+    def memory_reserved(cls):
+        return torch.cuda.memory_reserved()
+
+    @classmethod
+    def synchronize(cls):
+        torch.cuda.synchronize()
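With the hooks above, backend-specific memory queries live behind the `Platform` interface and callers go through `current_platform` instead of importing `torch.cuda` directly. A hedged usage sketch (the `report_free_memory` helper is hypothetical; on `CpuPlatform` the hook is still a stub that returns `None`):

```python
from vllm.logger import init_logger
from vllm.platforms import current_platform

logger = init_logger(__name__)


def report_free_memory() -> None:
    """Illustrative only: platform-agnostic free/total memory report."""
    info = current_platform.mem_get_info()
    if info is None:
        # e.g. CpuPlatform, where mem_get_info() is still a FIXME stub.
        logger.info("mem_get_info() is not implemented for this platform")
        return
    free, total = info
    logger.info("Free device memory: %.2f GiB of %.2f GiB", free / 2**30,
                total / 2**30)
```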
diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py
index c4530c1dfaa..3894e2dca7d 100644
--- a/vllm/platforms/xpu.py
+++ b/vllm/platforms/xpu.py
@@ -194,3 +194,40 @@ def supports_v1(cls, model_config: ModelConfig) -> bool:
     @classmethod
     def device_count(cls) -> int:
         return torch.xpu.device_count()
+
+    @classmethod
+    def empty_cache(cls):
+        torch.xpu.empty_cache()
+
+    @classmethod
+    def reset_peak_memory_stats(cls):
+        torch.xpu.reset_peak_memory_stats()
+
+    @classmethod
+    def mem_get_info(cls):
+        if cls.is_data_center_gpu():
+            return torch.xpu.mem_get_info()
+        else:
+            # `torch.xpu.mem_get_info()` doesn't report the correct
+            # free_gpu_memory on Intel client GPUs, so we have to
+            # calculate/estimate it ourselves.
+            _, total_gpu_memory = torch.xpu.mem_get_info()
+            # FIXME: memory_allocated() doesn't count non-torch allocations,
+            # and we don't have any API to get it, so we assume 128 MiB.
+            used_memory = torch.xpu.memory_allocated()
+            non_torch_allocations = 128 * 1024 * 1024
+            free_gpu_memory = total_gpu_memory - (used_memory +
+                                                  non_torch_allocations)
+            return free_gpu_memory, total_gpu_memory
+
+    @classmethod
+    def memory_stats(cls):
+        return torch.xpu.memory_stats()
+
+    @classmethod
+    def memory_reserved(cls):
+        return torch.xpu.memory_reserved()
+
+    @classmethod
+    def synchronize(cls):
+        torch.xpu.synchronize()
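The client-GPU branch of `XPUPlatform.mem_get_info()` above estimates free memory instead of trusting the driver value. Worked numbers for that arithmetic (all values are made up; the 128 MiB term is the fixed non-torch allowance from the code):

```python
# Illustrative numbers only, mirroring the estimate in mem_get_info() above.
total_gpu_memory = 16 * 1024**3            # 16 GiB reported by torch.xpu.mem_get_info()
used_memory = 3 * 1024**3                  # torch.xpu.memory_allocated()
non_torch_allocations = 128 * 1024 * 1024  # assumed non-torch overhead
free_gpu_memory = total_gpu_memory - (used_memory + non_torch_allocations)
assert free_gpu_memory == 13184 * 1024**2  # ~12.9 GiB left for vLLM
```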
""" # noqa + from vllm.platforms import current_platform + gc.collect() - torch.cuda.empty_cache() - torch.cuda.reset_peak_memory_stats() + current_platform.empty_cache() + current_platform.reset_peak_memory_stats() result = MemoryProfilingResult() @@ -2673,7 +2677,7 @@ def memory_profiling( yield result gc.collect() - torch.cuda.empty_cache() + current_platform.empty_cache() result.after_profile.measure() diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 4551cb2df98..3f5efbfe80c 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -38,6 +38,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.multimodal.utils import group_mm_inputs_by_modality +from vllm.platforms import current_platform from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingType from vllm.sequence import IntermediateTensors @@ -345,12 +346,16 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: def _init_device_properties(self) -> None: """Initialize attributes from torch.cuda.get_device_properties """ - self.device_properties = torch.cuda.get_device_properties(self.device) - self.num_sms = self.device_properties.multi_processor_count + if current_platform.is_cuda(): + self.device_properties = torch.cuda.get_device_properties( + self.device) + self.num_sms = self.device_properties.multi_processor_count + else: + self.num_sms = None # Note: used for model runner override. def _sync_device(self) -> None: - torch.cuda.synchronize() + current_platform.synchronize() def _update_states(self, scheduler_output: "SchedulerOutput") -> None: """Update the cached states and the persistent batch with the scheduler @@ -2270,7 +2275,7 @@ def capture_model(self) -> None: compilation_counter.num_gpu_runner_capture_triggers += 1 start_time = time.perf_counter() - start_free_gpu_memory = torch.cuda.mem_get_info()[0] + start_free_gpu_memory = current_platform.mem_get_info()[0] # Trigger CUDA graph capture for specific shapes. # Capture the large shapes first so that the smaller shapes @@ -2296,7 +2301,7 @@ def capture_model(self) -> None: skip_eplb=True) end_time = time.perf_counter() - end_free_gpu_memory = torch.cuda.mem_get_info()[0] + end_free_gpu_memory = current_platform.mem_get_info()[0] elapsed_time = end_time - start_time cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory # This usually takes 5~20 seconds. diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 6458b55777a..944e71a11a1 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -83,7 +83,7 @@ def __init__( def sleep(self, level: int = 1) -> None: from vllm.device_allocator.cumem import CuMemAllocator - free_bytes_before_sleep = torch.cuda.mem_get_info()[0] + free_bytes_before_sleep = current_platform.mem_get_info()[0] # Save the buffers before level 2 sleep if level == 2: @@ -95,7 +95,7 @@ def sleep(self, level: int = 1) -> None: allocator = CuMemAllocator.get_instance() allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple()) - free_bytes_after_sleep, total = torch.cuda.mem_get_info() + free_bytes_after_sleep, total = current_platform.mem_get_info() freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep used_bytes = total - free_bytes_after_sleep assert freed_bytes >= 0, "Memory usage increased after sleeping." 
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index 6458b55777a..944e71a11a1 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -83,7 +83,7 @@ def __init__(
     def sleep(self, level: int = 1) -> None:
         from vllm.device_allocator.cumem import CuMemAllocator
 
-        free_bytes_before_sleep = torch.cuda.mem_get_info()[0]
+        free_bytes_before_sleep = current_platform.mem_get_info()[0]
 
         # Save the buffers before level 2 sleep
         if level == 2:
@@ -95,7 +95,7 @@ def sleep(self, level: int = 1) -> None:
         allocator = CuMemAllocator.get_instance()
         allocator.sleep(
             offload_tags=("weights", ) if level == 1 else tuple())
-        free_bytes_after_sleep, total = torch.cuda.mem_get_info()
+        free_bytes_after_sleep, total = current_platform.mem_get_info()
         freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
         used_bytes = total - free_bytes_after_sleep
         assert freed_bytes >= 0, "Memory usage increased after sleeping."
@@ -135,12 +135,13 @@ def init_device(self):
 
             # This env var set by Ray causes exceptions with graph building.
             os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
-            self.device = torch.device(f"cuda:{self.local_rank}")
+            self.device = torch.device(
+                f"{current_platform.device_name}:{self.local_rank}")
             current_platform.set_device(self.device)
 
             _check_if_gpu_supports_dtype(self.model_config.dtype)
             gc.collect()
-            torch.cuda.empty_cache()
+            current_platform.empty_cache()
 
             # take current memory snapshot
             self.init_snapshot = MemorySnapshot()
@@ -209,8 +210,8 @@ def determine_available_memory(self) -> int:
         You may limit the usage of GPU memory
         by adjusting the `gpu_memory_utilization` parameter.
         """
-        torch.cuda.empty_cache()
-        torch.cuda.reset_peak_memory_stats()
+        current_platform.empty_cache()
+        current_platform.reset_peak_memory_stats()
         GiB = lambda b: b / GiB_bytes
 
         # Execute a forward pass with dummy inputs to profile the memory usage
@@ -285,7 +286,7 @@ def compile_or_warm_up_model(self) -> None:
         # sampling related tensors of max possible shape to avoid memory
         # fragmentation issue.
         # NOTE: This is called after `capture_model` on purpose to prevent
-        # memory buffers from being cleared by `torch.cuda.empty_cache`.
+        # memory buffers from being cleared by `current_platform.empty_cache`.
         if get_pp_group().is_last_rank:
             max_num_reqs = min(self.scheduler_config.max_num_seqs,
                                self.scheduler_config.max_num_batched_tokens)
diff --git a/vllm/v1/worker/xpu_model_runner.py b/vllm/v1/worker/xpu_model_runner.py
index 59f8d0fcf5b..3f0c0959f10 100644
--- a/vllm/v1/worker/xpu_model_runner.py
+++ b/vllm/v1/worker/xpu_model_runner.py
@@ -25,9 +25,3 @@ def __init__(
         super().__init__(vllm_config, device)
         # FIXME: To be verified.
         self.cascade_attn_enabled = False
-
-    def _init_device_properties(self) -> None:
-        self.num_sms = None
-
-    def _sync_device(self) -> None:
-        torch.xpu.synchronize()
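In `Worker.init_device()` above, the device string is no longer hard-coded to "cuda"; it is built from `current_platform.device_name`, so the same worker path can serve other backends. A minimal sketch of that pattern (the `local_rank` value is a placeholder):

```python
import torch

from vllm.platforms import current_platform

local_rank = 0  # placeholder; normally supplied by the distributed launcher
device = torch.device(f"{current_platform.device_name}:{local_rank}")
current_platform.set_device(device)
current_platform.empty_cache()
```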
diff --git a/vllm/v1/worker/xpu_worker.py b/vllm/v1/worker/xpu_worker.py
index da271b2159a..aea0b159048 100644
--- a/vllm/v1/worker/xpu_worker.py
+++ b/vllm/v1/worker/xpu_worker.py
@@ -10,6 +10,7 @@
 from vllm.logger import init_logger
 from vllm.model_executor import set_random_seed
 from vllm.platforms import current_platform
+from vllm.utils import MemorySnapshot
 from vllm.v1.worker.gpu_worker import (Worker,
                                        init_worker_distributed_environment)
 from vllm.v1.worker.xpu_model_runner import XPUModelRunner
@@ -51,91 +52,17 @@ def __init__(
         else:
             self.profiler = None
 
-    # we provide this function due to `torch.xpu.mem_get_info()` doesn't
-    # return correct free_gpu_memory on intel client GPU. We need to
-    # calculate/estiamte it.
-    def xpu_get_mem_info(self):
-        if current_platform.is_data_center_gpu():
-            return torch.xpu.mem_get_info()
-        else:
-            _, total_gpu_memory = torch.xpu.mem_get_info()
-            # FIXME: memory_allocated() doesn't count non-torch allocations,
-            # and we don't have any API to get it. so we mark it as 128MB.
-            used_memory = torch.xpu.memory_allocated()
-            non_torch_allocations = 128 * 1024 * 1024
-            free_gpu_memory = total_gpu_memory - (used_memory +
-                                                  non_torch_allocations)
-            return free_gpu_memory, total_gpu_memory
-
-    @torch.inference_mode()
-    def determine_available_memory(self) -> int:
-        """Profiles the peak memory usage of the model to determine how many
-        KV blocks may be allocated without OOMs.
-        The engine will first conduct a profiling of the existing memory usage.
-        Then, it calculate the maximum possible number of GPU and CPU blocks
-        that can be allocated with the remaining free memory.
-        .. tip::
-            You may limit the usage of GPU memory
-            by adjusting the `gpu_memory_utilization` parameter.
-        """
-        # Profile the memory usage of the model and get the maximum number of
-        # cache blocks that can be allocated with the remaining free memory.
-        torch.xpu.empty_cache()
-        torch.xpu.reset_peak_memory_stats()
-
-        free_gpu_memory, total_gpu_memory = torch.xpu.mem_get_info()
-        current_allocated_bytes = torch.xpu.memory_allocated()
-        msg = ("Before memory profiling run, "
-               f"total GPU memory: {total_gpu_memory / 1024**2:.2f} MB, "
-               f"model load takes {current_allocated_bytes / 1024**2:.2f} MB, "
-               f"free gpu memory is {free_gpu_memory / 1024**2:.2f} MB.")
-        logger.info(msg)
-        # Execute a forward pass with dummy inputs to profile the memory usage
-        # of the model.
-        self.model_runner.profile_run()
-
-        free_gpu_memory, _ = self.xpu_get_mem_info()
-        # NOTE(woosuk): Here we assume that the other processes using the same
-        # GPU did not change their memory usage during the profiling.
-        assert self.init_gpu_memory > free_gpu_memory, (
-            "Error in memory profiling. "
-            f"Initial free memory {self.init_gpu_memory}, current free memory"
-            f" {free_gpu_memory}. This happens when the GPU memory was "
-            "not properly cleaned up before initializing the vLLM instance.")
-
-        # Get the peak memory allocation recorded by torch
-        peak_memory = torch.xpu.memory_stats()["allocated_bytes.all.peak"]
-
-        torch.xpu.empty_cache()
-        torch_allocated_bytes = torch.xpu.memory_stats(
-        )["allocated_bytes.all.current"]
-        total_allocated_bytes = self.xpu_get_mem_info(
-        )[1] - self.xpu_get_mem_info()[0]
-
-        non_torch_allocations = total_allocated_bytes - torch_allocated_bytes
-        if non_torch_allocations > 0:
-            peak_memory += non_torch_allocations
-        available_kv_cache_memory = (
-            total_gpu_memory * self.cache_config.gpu_memory_utilization -
-            peak_memory)
-
-        msg = ("After memory profiling run, "
-               f"peak memory usage is {peak_memory / 1024**2:.2f} MB,"
-               f"torch mem is {torch_allocated_bytes / 1024**2:.2f} MB, "
-               f"non-torch mem is {non_torch_allocations / 1024**2:.2f} MB, "
-               f"free gpu memory is {free_gpu_memory / 1024**2:.2f} MB.")
-        logger.info(msg)
-
-        return int(available_kv_cache_memory)
-
     def init_device(self):
         if self.device_config.device.type == "xpu" and current_platform.is_xpu(
         ):
             self.device = torch.device(f"xpu:{self.local_rank}")
             current_platform.set_device(self.device)
-            torch.xpu.empty_cache()
-            self.init_gpu_memory = torch.xpu.get_device_properties(
-                self.local_rank).total_memory
+            current_platform.empty_cache()
+            self.init_gpu_memory = current_platform.get_device_total_memory(
+                self.local_rank)
+            self.init_snapshot = MemorySnapshot()
+            self.requested_memory = (self.init_snapshot.total_memory *
+                                     self.cache_config.gpu_memory_utilization)
         else:
             raise RuntimeError(
                 f"Not support device type: {self.device_config.device}")
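With the XPU-specific `xpu_get_mem_info()` and `determine_available_memory()` overrides removed above, `XPUWorker` falls back to the shared `Worker` implementation, which now only uses the `current_platform` hooks. A quick, hedged smoke-test sketch for those hooks when porting another backend (not part of the patch):

```python
from vllm.platforms import current_platform

current_platform.reset_peak_memory_stats()
current_platform.synchronize()

stats = current_platform.memory_stats()
peak = stats.get("allocated_bytes.all.peak", 0) if stats else None
print("peak torch allocation:", peak)
print("reserved by torch:", current_platform.memory_reserved())
print("free/total:", current_platform.mem_get_info())

current_platform.empty_cache()
```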