Commit cc867be

[V1] Reuse V0's memory_profiling util for gpu worker memory profiling (#19312)
Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com>
1 parent 3a7cd62 commit cc867be

File tree: 2 files changed, +51 -53 lines

vllm/utils.py

Lines changed: 16 additions & 2 deletions

@@ -2269,6 +2269,8 @@ def kill_process_tree(pid: int):
 class MemorySnapshot:
     """Memory snapshot."""
     torch_peak: int = 0
+    free_memory: int = 0
+    total_memory: int = 0
     cuda_memory: int = 0
     torch_memory: int = 0
     non_torch_memory: int = 0
@@ -2288,8 +2290,8 @@ def measure(self):
         self.torch_peak = torch.cuda.memory_stats().get(
             "allocated_bytes.all.peak", 0)

-        self.cuda_memory = torch.cuda.mem_get_info(
-        )[1] - torch.cuda.mem_get_info()[0]
+        self.free_memory, self.total_memory = torch.cuda.mem_get_info()
+        self.cuda_memory = self.total_memory - self.free_memory

         # torch.cuda.memory_reserved() is how many bytes
         # PyTorch gets from cuda (by calling cudaMalloc, etc.)
@@ -2302,6 +2304,8 @@ def measure(self):
     def __sub__(self, other: MemorySnapshot) -> MemorySnapshot:
         return MemorySnapshot(
             torch_peak=self.torch_peak - other.torch_peak,
+            free_memory=self.free_memory - other.free_memory,
+            total_memory=self.total_memory - other.total_memory,
             cuda_memory=self.cuda_memory - other.cuda_memory,
             torch_memory=self.torch_memory - other.torch_memory,
             non_torch_memory=self.non_torch_memory - other.non_torch_memory,
@@ -2323,6 +2327,16 @@ class MemoryProfilingResult:
     after_profile: MemorySnapshot = field(default_factory=MemorySnapshot)
     profile_time: float = 0.0

+    def __repr__(self) -> str:
+        return (f"Memory profiling takes {self.profile_time:.2f} seconds. "
+                f"Total non KV cache memory: "
+                f"{(self.non_kv_cache_memory / GiB_bytes):.2f}GiB; "
+                f"torch peak memory increase: "
+                f"{(self.torch_peak_increase / GiB_bytes):.2f}GiB; "
+                f"non-torch forward increase memory: "
+                f"{(self.non_torch_increase / GiB_bytes):.2f}GiB; "
+                f"weights memory: {(self.weights_memory / GiB_bytes):.2f}GiB.")
+

 @contextlib.contextmanager
 def memory_profiling(
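
For orientation, here is a minimal standalone sketch of what the two new fields add to the snapshot arithmetic. SimpleSnapshot and the byte values below are illustrative stand-ins for the real vllm.utils.MemorySnapshot, assuming only the field-wise subtraction shown in the diff above.

from dataclasses import dataclass

GiB = 1 << 30  # bytes per GiB, mirroring vllm.utils.GiB_bytes


@dataclass
class SimpleSnapshot:
    """Illustrative stand-in for the extended MemorySnapshot fields."""
    free_memory: int = 0   # free bytes reported by torch.cuda.mem_get_info()
    total_memory: int = 0  # total device bytes from the same call
    cuda_memory: int = 0   # total_memory - free_memory, i.e. bytes in use

    def __sub__(self, other: "SimpleSnapshot") -> "SimpleSnapshot":
        # Field-wise difference, as the patched __sub__ does, so that
        # "after - before" yields the memory delta of a profiled phase.
        return SimpleSnapshot(
            free_memory=self.free_memory - other.free_memory,
            total_memory=self.total_memory - other.total_memory,
            cuda_memory=self.cuda_memory - other.cuda_memory,
        )


# Hypothetical numbers on an 80 GiB device: 70 GiB free at startup,
# 40 GiB free after loading weights and running the profiling pass.
before = SimpleSnapshot(free_memory=70 * GiB, total_memory=80 * GiB,
                        cuda_memory=10 * GiB)
after = SimpleSnapshot(free_memory=40 * GiB, total_memory=80 * GiB,
                       cuda_memory=40 * GiB)
delta = after - before
print(f"cuda_memory increase: {delta.cuda_memory / GiB:.2f} GiB")  # 30.00 GiB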

vllm/v1/worker/gpu_worker.py

Lines changed: 35 additions & 51 deletions

@@ -22,7 +22,7 @@
 from vllm.model_executor import set_random_seed
 from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
-from vllm.utils import GiB_bytes
+from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.utils import report_usage_stats
@@ -130,20 +130,22 @@ def init_device(self):
             _check_if_gpu_supports_dtype(self.model_config.dtype)
             gc.collect()
             torch.cuda.empty_cache()
-            self.init_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
-            requested_memory = (total_gpu_memory *
-                                self.cache_config.gpu_memory_utilization)
-            if self.init_gpu_memory < requested_memory:
+
+            # take current memory snapshot
+            self.init_snapshot = MemorySnapshot()
+            self.requested_memory = (self.init_snapshot.total_memory *
+                                     self.cache_config.gpu_memory_utilization)
+            if self.init_snapshot.free_memory < self.requested_memory:
                 GiB = lambda b: round(b / GiB_bytes, 2)
                 raise ValueError(
-                    f"Free memory on device ({GiB(self.init_gpu_memory)}/"
-                    f"{GiB(total_gpu_memory)} GiB) on startup is less than "
-                    f"desired GPU memory utilization "
+                    f"Free memory on device "
+                    f"({GiB(self.init_snapshot.free_memory)}/"
+                    f"{GiB(self.init_snapshot.total_memory)} GiB) on startup "
+                    f"is less than desired GPU memory utilization "
                     f"({self.cache_config.gpu_memory_utilization}, "
-                    f"{GiB(requested_memory)} GiB). Decrease GPU memory "
+                    f"{GiB(self.requested_memory)} GiB). Decrease GPU memory "
                     f"utilization or reduce GPU memory used by other processes."
                 )
-
         else:
             raise RuntimeError(
                 f"Not support device type: {self.device_config.device}")
@@ -192,57 +194,39 @@ def determine_available_memory(self) -> int:
         """
         torch.cuda.empty_cache()
         torch.cuda.reset_peak_memory_stats()
+        GiB = lambda b: b / GiB_bytes

-        _, total_gpu_memory = torch.cuda.mem_get_info()
         # Execute a forward pass with dummy inputs to profile the memory usage
         # of the model.
-        self.model_runner.profile_run()
+        with memory_profiling(
+                self.init_snapshot,
+                weights_memory=int(
+                    self.model_runner.model_memory_usage)) as profile_result:
+            self.model_runner.profile_run()

-        free_gpu_memory, _ = torch.cuda.mem_get_info()
+        free_gpu_memory = profile_result.after_profile.free_memory
         # NOTE(woosuk): Here we assume that the other processes using the same
         # GPU did not change their memory usage during the profiling.
-        assert self.init_gpu_memory > free_gpu_memory, (
+        assert self.init_snapshot.free_memory > free_gpu_memory, (
            "Error in memory profiling. "
-            f"Initial free memory {self.init_gpu_memory/GiB_bytes} GiB, "
-            f"current free memory {free_gpu_memory/GiB_bytes} GiB. "
-            f"This happens when the GPU memory was not properly cleaned up "
-            f"before initializing the vLLM instance.")
-
-        # Get the peak memory allocation recorded by torch
-        peak_torch_memory = torch.cuda.memory_stats(
-        )["allocated_bytes.all.peak"]
-
-        # Check for any memory left around that may have been allocated on the
-        # gpu outside of `torch`. NCCL operations, for example, can use a few
-        # GB during a forward pass.
-        torch.cuda.empty_cache()
-        torch_allocated_bytes = torch.cuda.memory_stats(
-        )["allocated_bytes.all.current"]
-
-        # Reset after emptying torch cache
-        free_gpu_memory = torch.cuda.mem_get_info()[0]
+            f"Initial free memory {GiB(self.init_snapshot.free_memory)} GiB, "
+            f"current free memory {GiB(free_gpu_memory)} GiB. "
+            "This happens when other processes sharing the same container "
+            "release GPU memory while vLLM is profiling during initialization. "
+            "To fix this, ensure consistent GPU memory allocation or "
+            "isolate vLLM in its own container.")
+        available_kv_cache_memory = self.requested_memory \
+            - profile_result.non_kv_cache_memory

-        # Total forward allocation (current) is equal to the diff in free memory
-        fwd_alloc_bytes = self.init_gpu_memory - free_gpu_memory
-        # We assume current non-torch allocation is equal to peak
-        non_torch_alloc_bytes = max(0, fwd_alloc_bytes - torch_allocated_bytes)
-        # Total forward allocation (peak) is peak torch + non-torch
-        peak_memory = peak_torch_memory + non_torch_alloc_bytes
-
-        available_kv_cache_memory = (
-            total_gpu_memory * self.cache_config.gpu_memory_utilization -
-            peak_memory)
-
-        GiB = lambda b: b / GiB_bytes
         logger.debug(
             "Initial free memory: %.2f GiB, free memory: %.2f GiB, "
-            "total GPU memory: %.2f GiB", GiB(self.init_gpu_memory),
-            GiB(free_gpu_memory), GiB(total_gpu_memory))
-        logger.debug(
-            "Peak torch memory: %.2f GiB, non-torch forward-pass memory: "
-            "%.2f GiB, available KVCache memory: %.2f GiB",
-            GiB(peak_torch_memory), GiB(non_torch_alloc_bytes),
-            GiB(available_kv_cache_memory))
+            "requested GPU memory: %.2f GiB",
+            GiB(self.init_snapshot.free_memory), GiB(free_gpu_memory),
+            GiB(self.requested_memory))
+        logger.debug(profile_result)
+        logger.info("Available KV cache memory: %.2f GiB",
+                    GiB(available_kv_cache_memory))
+        gc.collect()

         return int(available_kv_cache_memory)
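
A simplified sketch of the resulting KV-cache budget arithmetic follows. The function and the byte values are hypothetical; the decomposition of non_kv_cache_memory mirrors the components listed by MemoryProfilingResult.__repr__ in the vllm/utils.py diff above (weights memory, torch peak increase, non-torch increase).

GiB_bytes = 1 << 30


def available_kv_cache_memory(requested_memory: int, weights_memory: int,
                              torch_peak_increase: int,
                              non_torch_increase: int) -> int:
    """Hypothetical helper showing the accounting determine_available_memory
    now delegates to memory_profiling: everything that is not KV cache is
    subtracted from the requested memory budget."""
    non_kv_cache_memory = (weights_memory + torch_peak_increase +
                           non_torch_increase)
    return requested_memory - non_kv_cache_memory


# Hypothetical run: 72 GiB requested (0.9 * 80 GiB), 14 GiB of weights,
# 6 GiB torch activation peak, 2 GiB of non-torch (e.g. NCCL) buffers.
kv_bytes = available_kv_cache_memory(72 * GiB_bytes, 14 * GiB_bytes,
                                     6 * GiB_bytes, 2 * GiB_bytes)
print(f"Available KV cache memory: {kv_bytes / GiB_bytes:.2f} GiB")  # 50.00 GiB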
