
Commit 7890945

refactor platform

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>

1 parent cc876d0

8 files changed: +119 −103 lines

vllm/platforms/cuda.py

Lines changed: 24 additions & 0 deletions
@@ -424,6 +424,30 @@ def stateless_init_device_torch_dist_pg(
     def device_count(cls) -> int:
         return cuda_device_count_stateless()

+    @classmethod
+    def empty_cache(cls, ):
+        torch.cuda.empty_cache()
+
+    @classmethod
+    def reset_peak_memory_stats(cls):
+        torch.cuda.reset_peak_memory_stats()
+
+    @classmethod
+    def mem_get_info(cls):
+        return torch.cuda.mem_get_info()
+
+    @classmethod
+    def memory_stats(cls):
+        return torch.cuda.memory_stats()
+
+    @classmethod
+    def memory_reserved(cls):
+        return torch.cuda.memory_reserved()
+
+    @classmethod
+    def synchronize(cls):
+        return torch.cuda.synchronize()
+

     # NVML utils
     # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,

vllm/platforms/interface.py

Lines changed: 24 additions & 0 deletions
@@ -548,6 +548,30 @@ def stateless_init_device_torch_dist_pg(
         """
         raise RuntimeError(f"Unsupported torch distributed backend: {backend}")

+    @classmethod
+    def empty_cache(cls, ):
+        raise NotImplementedError
+
+    @classmethod
+    def reset_peak_memory_stats(cls):
+        raise NotImplementedError
+
+    @classmethod
+    def mem_get_info(cls):
+        raise NotImplementedError
+
+    @classmethod
+    def memory_stats(cls):
+        raise NotImplementedError
+
+    @classmethod
+    def memory_reserved(cls):
+        raise NotImplementedError
+
+    @classmethod
+    def synchronize(cls):
+        torch.accelerator.synchronize()
+

 class UnspecifiedPlatform(Platform):
     _enum = PlatformEnum.UNSPECIFIED
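Taken together, the platform files follow one pattern: the `Platform` base class above declares device-memory hooks that default to `NotImplementedError` (only `synchronize` falls back to `torch.accelerator.synchronize()`), each backend overrides them with its native `torch.cuda.*` or `torch.xpu.*` calls, and device-agnostic code reaches them through `current_platform`. A minimal, self-contained sketch of that dispatch pattern; the class names and the availability check below are illustrative stand-ins, not vLLM's actual hierarchy or selection logic:

# Sketch of the platform-dispatch pattern introduced by this commit.
# "Platform", "CudaPlatform" and the module-level "current_platform" below are
# illustrative stand-ins, not vLLM's real classes.
import torch


class Platform:
    @classmethod
    def empty_cache(cls):
        raise NotImplementedError

    @classmethod
    def mem_get_info(cls):
        raise NotImplementedError


class CudaPlatform(Platform):
    @classmethod
    def empty_cache(cls):
        torch.cuda.empty_cache()

    @classmethod
    def mem_get_info(cls):
        return torch.cuda.mem_get_info()


# vLLM resolves the concrete platform at import time; this check stands in for that.
current_platform = CudaPlatform if torch.cuda.is_available() else Platform

if torch.cuda.is_available():
    free_bytes, total_bytes = current_platform.mem_get_info()
    current_platform.empty_cache()
    print(f"free: {free_bytes / 1024**2:.1f} MiB of {total_bytes / 1024**2:.1f} MiB")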

vllm/platforms/xpu.py

Lines changed: 37 additions & 0 deletions
@@ -194,3 +194,40 @@ def supports_v1(cls, model_config: ModelConfig) -> bool:
     @classmethod
     def device_count(cls) -> int:
         return torch.xpu.device_count()
+
+    @classmethod
+    def empty_cache(cls, ):
+        torch.xpu.empty_cache()
+
+    @classmethod
+    def reset_peak_memory_stats(cls):
+        torch.xpu.reset_peak_memory_stats()
+
+    @classmethod
+    def mem_get_info(cls):
+        if cls.is_data_center_gpu():
+            return torch.xpu.mem_get_info()
+        else:
+            # we provide this function due to `torch.xpu.mem_get_info()` doesn't
+            # return correct free_gpu_memory on intel client GPU. We need to
+            # calculate/estiamte it.
+            _, total_gpu_memory = torch.xpu.mem_get_info()
+            # FIXME: memory_allocated() doesn't count non-torch allocations,
+            # and we don't have any API to get it. so we mark it as 128MB.
+            used_memory = torch.xpu.memory_allocated()
+            non_torch_allocations = 128 * 1024 * 1024
+            free_gpu_memory = total_gpu_memory - (used_memory +
+                                                  non_torch_allocations)
+            return free_gpu_memory, total_gpu_memory
+
+    @classmethod
+    def memory_stats(cls):
+        return torch.xpu.memory_stats()
+
+    @classmethod
+    def memory_reserved(cls):
+        return torch.xpu.memory_reserved()
+
+    @classmethod
+    def synchronize(cls):
+        return torch.xpu.synchronize()
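The client-GPU branch of `mem_get_info` above estimates free memory instead of trusting `torch.xpu.mem_get_info()`, which does not report free memory correctly on Intel client GPUs: it subtracts the torch-allocated bytes plus a fixed 128 MiB allowance for non-torch allocations from the total. A small numeric illustration of that estimate, using invented values rather than real measurements:

# Invented numbers, only to show the arithmetic of the client-GPU fallback.
total_gpu_memory = 16 * 1024**3            # hypothetical 16 GiB client GPU
used_memory = 6 * 1024**3                  # what torch.xpu.memory_allocated() might report
non_torch_allocations = 128 * 1024 * 1024  # fixed 128 MiB allowance from the diff

free_gpu_memory = total_gpu_memory - (used_memory + non_torch_allocations)
print(free_gpu_memory / 1024**3)           # 9.875 -> roughly 9.9 GiB estimated free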

vllm/utils/__init__.py

Lines changed: 10 additions & 6 deletions
@@ -2557,16 +2557,18 @@ def measure(self):
         # After `torch.cuda.reset_peak_memory_stats()`,
         # `torch.cuda.memory_reserved()` will keep growing, and only shrink
         # when we call `torch.cuda.empty_cache()` or OOM happens.
-        self.torch_peak = torch.cuda.memory_stats().get(
+        from vllm.platforms import current_platform
+
+        self.torch_peak = current_platform.memory_stats().get(
             "allocated_bytes.all.peak", 0)

-        self.free_memory, self.total_memory = torch.cuda.mem_get_info()
+        self.free_memory, self.total_memory = current_platform.mem_get_info()
         self.cuda_memory = self.total_memory - self.free_memory

         # torch.cuda.memory_reserved() is how many bytes
         # PyTorch gets from cuda (by calling cudaMalloc, etc.)
         # this is used to measure the non-torch memory usage
-        self.torch_memory = torch.cuda.memory_reserved()
+        self.torch_memory = current_platform.memory_reserved()

         self.non_torch_memory = self.cuda_memory - self.torch_memory
         self.timestamp = time.time()

@@ -2658,9 +2660,11 @@ def memory_profiling(

     The increase of `non_torch_memory` from creating the current vLLM instance until after profiling to get (c.).
     """ # noqa
+    from vllm.platforms import current_platform
+
     gc.collect()
-    torch.cuda.empty_cache()
-    torch.cuda.reset_peak_memory_stats()
+    current_platform.empty_cache()
+    current_platform.reset_peak_memory_stats()

     result = MemoryProfilingResult()

@@ -2673,7 +2677,7 @@ def memory_profiling(
     yield result

     gc.collect()
-    torch.cuda.empty_cache()
+    current_platform.empty_cache()

     result.after_profile.measure()
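The snapshot fields above relate by simple subtraction: everything the device reports as not free is in use, the part reserved by PyTorch's caching allocator comes from `memory_reserved()`, and the remainder is attributed to non-torch allocations (NCCL buffers, driver overhead, and so on). A minimal sketch of that bookkeeping through the new platform hooks; the dataclass is an illustrative stand-in, though the field names mirror the real snapshot:

# Sketch of the measure() arithmetic; current_platform is the object this
# commit routes through, while SimpleSnapshot itself is illustrative.
from dataclasses import dataclass

from vllm.platforms import current_platform


@dataclass
class SimpleSnapshot:
    free_memory: int = 0
    total_memory: int = 0
    cuda_memory: int = 0       # total - free: everything in use on the device
    torch_memory: int = 0      # bytes reserved by the torch caching allocator
    non_torch_memory: int = 0  # the rest: NCCL, driver, other libraries

    def measure(self) -> None:
        self.free_memory, self.total_memory = current_platform.mem_get_info()
        self.cuda_memory = self.total_memory - self.free_memory
        self.torch_memory = current_platform.memory_reserved()
        self.non_torch_memory = self.cuda_memory - self.torch_memory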

vllm/v1/worker/gpu_model_runner.py

Lines changed: 10 additions & 5 deletions
@@ -38,6 +38,7 @@
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
 from vllm.multimodal.utils import group_mm_inputs_by_modality
+from vllm.platforms import current_platform
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors

@@ -345,12 +346,16 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
     def _init_device_properties(self) -> None:
         """Initialize attributes from torch.cuda.get_device_properties
         """
-        self.device_properties = torch.cuda.get_device_properties(self.device)
-        self.num_sms = self.device_properties.multi_processor_count
+        if current_platform.is_cuda():
+            self.device_properties = torch.cuda.get_device_properties(
+                self.device)
+            self.num_sms = self.device_properties.multi_processor_count
+        else:
+            self.num_sms = None

     # Note: used for model runner override.
     def _sync_device(self) -> None:
-        torch.cuda.synchronize()
+        current_platform.synchronize()

     def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         """Update the cached states and the persistent batch with the scheduler

@@ -2264,7 +2269,7 @@ def capture_model(self) -> None:
         compilation_counter.num_gpu_runner_capture_triggers += 1

         start_time = time.perf_counter()
-        start_free_gpu_memory = torch.cuda.mem_get_info()[0]
+        start_free_gpu_memory = current_platform.mem_get_info()[0]

         # Trigger CUDA graph capture for specific shapes.
         # Capture the large shapes first so that the smaller shapes

@@ -2288,7 +2293,7 @@ def capture_model(self) -> None:
                                     skip_eplb=True)

         end_time = time.perf_counter()
-        end_free_gpu_memory = torch.cuda.mem_get_info()[0]
+        end_free_gpu_memory = current_platform.mem_get_info()[0]
         elapsed_time = end_time - start_time
         cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory
         # This usually takes 5~20 seconds.

vllm/v1/worker/gpu_worker.py

Lines changed: 7 additions & 6 deletions
@@ -83,7 +83,7 @@ def __init__(
     def sleep(self, level: int = 1) -> None:
         from vllm.device_allocator.cumem import CuMemAllocator

-        free_bytes_before_sleep = torch.cuda.mem_get_info()[0]
+        free_bytes_before_sleep = current_platform.mem_get_info()[0]

         # Save the buffers before level 2 sleep
         if level == 2:

@@ -95,7 +95,7 @@ def sleep(self, level: int = 1) -> None:

         allocator = CuMemAllocator.get_instance()
         allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple())
-        free_bytes_after_sleep, total = torch.cuda.mem_get_info()
+        free_bytes_after_sleep, total = current_platform.mem_get_info()
         freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
         used_bytes = total - free_bytes_after_sleep
         assert freed_bytes >= 0, "Memory usage increased after sleeping."

@@ -135,12 +135,13 @@ def init_device(self):

             # This env var set by Ray causes exceptions with graph building.
             os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
-            self.device = torch.device(f"cuda:{self.local_rank}")
+            self.device = torch.device(
+                f"{current_platform.device_name}:{self.local_rank}")
             current_platform.set_device(self.device)

             _check_if_gpu_supports_dtype(self.model_config.dtype)
             gc.collect()
-            torch.cuda.empty_cache()
+            current_platform.empty_cache()

             # take current memory snapshot
             self.init_snapshot = MemorySnapshot()

@@ -206,8 +207,8 @@ def determine_available_memory(self) -> int:
         You may limit the usage of GPU memory
         by adjusting the `gpu_memory_utilization` parameter.
         """
-        torch.cuda.empty_cache()
-        torch.cuda.reset_peak_memory_stats()
+        current_platform.empty_cache()
+        current_platform.reset_peak_memory_stats()
         GiB = lambda b: b / GiB_bytes

         # Execute a forward pass with dummy inputs to profile the memory usage
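The `sleep()` hunks above derive their report from two `mem_get_info` readings only: bytes released by sleeping and bytes still resident afterwards. A worked example with invented readings:

# Invented before/after readings to show the sleep() bookkeeping.
GiB = 1024**3
free_bytes_before_sleep = 2 * GiB                    # free memory before sleeping
free_bytes_after_sleep, total = 70 * GiB, 80 * GiB   # free and total after sleeping

freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep  # 68 GiB released
used_bytes = total - free_bytes_after_sleep                     # 10 GiB still resident
assert freed_bytes >= 0, "Memory usage increased after sleeping."
print(f"slept: freed {freed_bytes / GiB:.0f} GiB, {used_bytes / GiB:.0f} GiB still in use")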

vllm/v1/worker/xpu_model_runner.py

Lines changed: 0 additions & 6 deletions
@@ -25,9 +25,3 @@ def __init__(
         super().__init__(vllm_config, device)
         # FIXME: To be verified.
         self.cascade_attn_enabled = False
-
-    def _init_device_properties(self) -> None:
-        self.num_sms = None
-
-    def _sync_device(self) -> None:
-        torch.xpu.synchronize()

vllm/v1/worker/xpu_worker.py

Lines changed: 7 additions & 80 deletions
@@ -10,6 +10,7 @@
 from vllm.logger import init_logger
 from vllm.model_executor import set_random_seed
 from vllm.platforms import current_platform
+from vllm.utils import MemorySnapshot
 from vllm.v1.worker.gpu_worker import (Worker,
                                        init_worker_distributed_environment)
 from vllm.v1.worker.xpu_model_runner import XPUModelRunner

@@ -51,91 +52,17 @@ def __init__(
         else:
             self.profiler = None

-    # we provide this function due to `torch.xpu.mem_get_info()` doesn't
-    # return correct free_gpu_memory on intel client GPU. We need to
-    # calculate/estiamte it.
-    def xpu_get_mem_info(self):
-        if current_platform.is_data_center_gpu():
-            return torch.xpu.mem_get_info()
-        else:
-            _, total_gpu_memory = torch.xpu.mem_get_info()
-            # FIXME: memory_allocated() doesn't count non-torch allocations,
-            # and we don't have any API to get it. so we mark it as 128MB.
-            used_memory = torch.xpu.memory_allocated()
-            non_torch_allocations = 128 * 1024 * 1024
-            free_gpu_memory = total_gpu_memory - (used_memory +
-                                                  non_torch_allocations)
-            return free_gpu_memory, total_gpu_memory
-
-    @torch.inference_mode()
-    def determine_available_memory(self) -> int:
-        """Profiles the peak memory usage of the model to determine how many
-        KV blocks may be allocated without OOMs.
-        The engine will first conduct a profiling of the existing memory usage.
-        Then, it calculate the maximum possible number of GPU and CPU blocks
-        that can be allocated with the remaining free memory.
-        .. tip::
-            You may limit the usage of GPU memory
-            by adjusting the `gpu_memory_utilization` parameter.
-        """
-        # Profile the memory usage of the model and get the maximum number of
-        # cache blocks that can be allocated with the remaining free memory.
-        torch.xpu.empty_cache()
-        torch.xpu.reset_peak_memory_stats()
-
-        free_gpu_memory, total_gpu_memory = torch.xpu.mem_get_info()
-        current_allocated_bytes = torch.xpu.memory_allocated()
-        msg = ("Before memory profiling run, "
-               f"total GPU memory: {total_gpu_memory / 1024**2:.2f} MB, "
-               f"model load takes {current_allocated_bytes / 1024**2:.2f} MB, "
-               f"free gpu memory is {free_gpu_memory / 1024**2:.2f} MB.")
-        logger.info(msg)
-        # Execute a forward pass with dummy inputs to profile the memory usage
-        # of the model.
-        self.model_runner.profile_run()
-
-        free_gpu_memory, _ = self.xpu_get_mem_info()
-        # NOTE(woosuk): Here we assume that the other processes using the same
-        # GPU did not change their memory usage during the profiling.
-        assert self.init_gpu_memory > free_gpu_memory, (
-            "Error in memory profiling. "
-            f"Initial free memory {self.init_gpu_memory}, current free memory"
-            f" {free_gpu_memory}. This happens when the GPU memory was "
-            "not properly cleaned up before initializing the vLLM instance.")
-
-        # Get the peak memory allocation recorded by torch
-        peak_memory = torch.xpu.memory_stats()["allocated_bytes.all.peak"]
-
-        torch.xpu.empty_cache()
-        torch_allocated_bytes = torch.xpu.memory_stats(
-        )["allocated_bytes.all.current"]
-        total_allocated_bytes = self.xpu_get_mem_info(
-        )[1] - self.xpu_get_mem_info()[0]
-
-        non_torch_allocations = total_allocated_bytes - torch_allocated_bytes
-        if non_torch_allocations > 0:
-            peak_memory += non_torch_allocations
-        available_kv_cache_memory = (
-            total_gpu_memory * self.cache_config.gpu_memory_utilization -
-            peak_memory)
-
-        msg = ("After memory profiling run, "
-               f"peak memory usage is {peak_memory / 1024**2:.2f} MB,"
-               f"torch mem is {torch_allocated_bytes / 1024**2:.2f} MB, "
-               f"non-torch mem is {non_torch_allocations / 1024**2:.2f} MB, "
-               f"free gpu memory is {free_gpu_memory / 1024**2:.2f} MB.")
-        logger.info(msg)
-
-        return int(available_kv_cache_memory)
-
     def init_device(self):
         if self.device_config.device.type == "xpu" and current_platform.is_xpu(
         ):
             self.device = torch.device(f"xpu:{self.local_rank}")
             current_platform.set_device(self.device)
-            torch.xpu.empty_cache()
-            self.init_gpu_memory = torch.xpu.get_device_properties(
-                self.local_rank).total_memory
+            current_platform.empty_cache()
+            self.init_gpu_memory = current_platform.get_device_total_memory(
+                self.local_rank)
+            self.init_snapshot = MemorySnapshot()
+            self.requested_memory = (self.init_snapshot.total_memory *
+                                     self.cache_config.gpu_memory_utilization)
         else:
             raise RuntimeError(
                 f"Not support device type: {self.device_config.device}")
