 from vllm.model_executor import set_random_seed
 from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
-from vllm.utils import GiB_bytes
+from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.utils import report_usage_stats
@@ -130,20 +130,22 @@ def init_device(self):
             _check_if_gpu_supports_dtype(self.model_config.dtype)
             gc.collect()
             torch.cuda.empty_cache()
-            self.init_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
-            requested_memory = (total_gpu_memory *
-                                self.cache_config.gpu_memory_utilization)
-            if self.init_gpu_memory < requested_memory:
+
+            # take current memory snapshot
+            self.init_snapshot = MemorySnapshot()
+            self.requested_memory = (self.init_snapshot.total_memory *
+                                     self.cache_config.gpu_memory_utilization)
+            if self.init_snapshot.free_memory < self.requested_memory:
                 GiB = lambda b: round(b / GiB_bytes, 2)
                 raise ValueError(
-                    f"Free memory on device ({GiB(self.init_gpu_memory)}/"
-                    f"{GiB(total_gpu_memory)} GiB) on startup is less than "
-                    f"desired GPU memory utilization "
+                    f"Free memory on device "
+                    f"({GiB(self.init_snapshot.free_memory)}/"
+                    f"{GiB(self.init_snapshot.total_memory)} GiB) on startup "
+                    f"is less than desired GPU memory utilization "
                     f"({self.cache_config.gpu_memory_utilization}, "
-                    f"{GiB(requested_memory)} GiB). Decrease GPU memory "
+                    f"{GiB(self.requested_memory)} GiB). Decrease GPU memory "
                     f"utilization or reduce GPU memory used by other processes."
                 )
-
         else:
             raise RuntimeError(
                 f"Not support device type: {self.device_config.device}")
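
Note: the new startup check reduces to comparing the device's current free memory against total_memory * gpu_memory_utilization. Below is a minimal standalone sketch of that arithmetic, with hypothetical numbers standing in for the values MemorySnapshot() would read from the GPU; it is an illustration, not the vLLM code itself.

    # Hypothetical numbers; the real values come from MemorySnapshot().
    GiB_bytes = 1 << 30
    free_memory = 70 * GiB_bytes       # assume 70 GiB currently free
    total_memory = 80 * GiB_bytes      # assume an 80 GiB device
    gpu_memory_utilization = 0.9

    requested_memory = total_memory * gpu_memory_utilization  # 72 GiB
    if free_memory < requested_memory:
        # 70 GiB free < 72 GiB requested, so startup fails early,
        # mirroring the ValueError raised in init_device() above.
        raise ValueError(
            f"Free memory ({free_memory / GiB_bytes:.2f} GiB) is less than "
            f"requested ({requested_memory / GiB_bytes:.2f} GiB).")
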
@@ -192,57 +194,39 @@ def determine_available_memory(self) -> int:
         """
         torch.cuda.empty_cache()
         torch.cuda.reset_peak_memory_stats()
+        GiB = lambda b: b / GiB_bytes
 
-        _, total_gpu_memory = torch.cuda.mem_get_info()
         # Execute a forward pass with dummy inputs to profile the memory usage
         # of the model.
-        self.model_runner.profile_run()
+        with memory_profiling(
+                self.init_snapshot,
+                weights_memory=int(
+                    self.model_runner.model_memory_usage)) as profile_result:
+            self.model_runner.profile_run()
 
-        free_gpu_memory, _ = torch.cuda.mem_get_info()
+        free_gpu_memory = profile_result.after_profile.free_memory
         # NOTE(woosuk): Here we assume that the other processes using the same
         # GPU did not change their memory usage during the profiling.
-        assert self.init_gpu_memory > free_gpu_memory, (
+        assert self.init_snapshot.free_memory > free_gpu_memory, (
             "Error in memory profiling. "
-            f"Initial free memory {self.init_gpu_memory/GiB_bytes} GiB, "
-            f"current free memory {free_gpu_memory/GiB_bytes} GiB. "
-            f"This happens when the GPU memory was not properly cleaned up "
-            f"before initializing the vLLM instance.")
-
-        # Get the peak memory allocation recorded by torch
-        peak_torch_memory = torch.cuda.memory_stats(
-        )["allocated_bytes.all.peak"]
-
-        # Check for any memory left around that may have been allocated on the
-        # gpu outside of `torch`. NCCL operations, for example, can use a few
-        # GB during a forward pass.
-        torch.cuda.empty_cache()
-        torch_allocated_bytes = torch.cuda.memory_stats(
-        )["allocated_bytes.all.current"]
-
-        # Reset after emptying torch cache
-        free_gpu_memory = torch.cuda.mem_get_info()[0]
+            f"Initial free memory {GiB(self.init_snapshot.free_memory)} GiB, "
+            f"current free memory {GiB(free_gpu_memory)} GiB. "
+            "This happens when other processes sharing the same container "
+            "release GPU memory while vLLM is profiling during initialization. "
+            "To fix this, ensure consistent GPU memory allocation or "
+            "isolate vLLM in its own container.")
+        available_kv_cache_memory = self.requested_memory \
+            - profile_result.non_kv_cache_memory
 
-        # Total forward allocation (current) is equal to the diff in free memory
-        fwd_alloc_bytes = self.init_gpu_memory - free_gpu_memory
-        # We assume current non-torch allocation is equal to peak
-        non_torch_alloc_bytes = max(0, fwd_alloc_bytes - torch_allocated_bytes)
-        # Total forward allocation (peak) is peak torch + non-torch
-        peak_memory = peak_torch_memory + non_torch_alloc_bytes
-
-        available_kv_cache_memory = (
-            total_gpu_memory * self.cache_config.gpu_memory_utilization -
-            peak_memory)
-
-        GiB = lambda b: b / GiB_bytes
         logger.debug(
             "Initial free memory: %.2f GiB, free memory: %.2f GiB, "
-            "total GPU memory: %.2f GiB", GiB(self.init_gpu_memory),
-            GiB(free_gpu_memory), GiB(total_gpu_memory))
-        logger.debug(
-            "Peak torch memory: %.2f GiB, non-torch forward-pass memory: "
-            "%.2f GiB, available KVCache memory: %.2f GiB",
-            GiB(peak_torch_memory), GiB(non_torch_alloc_bytes),
-            GiB(available_kv_cache_memory))
+            "requested GPU memory: %.2f GiB",
+            GiB(self.init_snapshot.free_memory), GiB(free_gpu_memory),
+            GiB(self.requested_memory))
+        logger.debug(profile_result)
+        logger.info("Available KV cache memory: %.2f GiB",
+                    GiB(available_kv_cache_memory))
+        gc.collect()
 
         return int(available_kv_cache_memory)
 
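
Note: the net effect of this hunk is that the KV-cache budget is now computed as requested_memory - profile_result.non_kv_cache_memory. The internals of memory_profiling() are not shown here; assuming it accounts for the model weights, the activation peak during profile_run(), and allocations made outside the torch allocator (e.g. NCCL buffers), the accounting looks roughly like the standalone sketch below, with hypothetical numbers.

    # Hypothetical numbers; in the real code these come from MemorySnapshot()
    # and the memory_profiling() context manager.
    GiB_bytes = 1 << 30
    total_memory = 80 * GiB_bytes
    gpu_memory_utilization = 0.9
    requested_memory = total_memory * gpu_memory_utilization   # 72 GiB budget

    weights_memory = 14 * GiB_bytes     # model weights loaded before profiling
    activation_peak = 4 * GiB_bytes     # peak torch allocation during profile_run()
    non_torch_overhead = 2 * GiB_bytes  # e.g. NCCL buffers outside the torch allocator

    # Assumption: non_kv_cache_memory aggregates everything that is not KV cache.
    non_kv_cache_memory = weights_memory + activation_peak + non_torch_overhead

    available_kv_cache_memory = requested_memory - non_kv_cache_memory
    print(f"Available KV cache memory: "
          f"{available_kv_cache_memory / GiB_bytes:.2f} GiB")  # 52.00 GiB
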