diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py
index 583a88d8e6e..e79a53956cd 100644
--- a/tests/v1/worker/test_gpu_model_runner.py
+++ b/tests/v1/worker/test_gpu_model_runner.py
@@ -443,7 +443,7 @@ def test_load_model_weights_inplace(dist_init, model_runner, model_runner_2):
     assert str(model_runner.get_model().state_dict()) != str(
         model_runner_2.get_model().state_dict())
     model_runner_2.load_config.load_format = original_load_format
-    model_runner_2.load_model()  # Load real weights inplace
+    model_runner_2.reload_weights()  # Load real weights inplace
     assert str(model_runner.get_model().state_dict()) == str(
         model_runner_2.get_model().state_dict())
 
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 3c9de572040..fd5b614ab6b 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1726,17 +1726,9 @@ def load_model(self) -> None:
         with DeviceMemoryProfiler() as m:  # noqa: SIM117
             time_before_load = time.perf_counter()
             model_loader = get_model_loader(self.load_config)
-            if not hasattr(self, "model"):
-                logger.info("Loading model from scratch...")
-                self.model = model_loader.load_model(
-                    vllm_config=self.vllm_config,
-                    model_config=self.model_config)
-            else:
-                logger.info(
-                    "Model was already initialized. Loading weights inplace..."
-                )
-                model_loader.load_weights(self.model,
-                                          model_config=self.model_config)
+            logger.info("Loading model from scratch...")
+            self.model = model_loader.load_model(
+                vllm_config=self.vllm_config, model_config=self.model_config)
             if has_step_pooler(self.model):
                 self.input_batch.logits_processing_needs_token_ids = True
             if self.lora_config:
@@ -1768,6 +1760,11 @@ def load_model(self) -> None:
             self.parallel_config,
         )
 
+    def reload_weights(self) -> None:
+        model_loader = get_model_loader(self.load_config)
+        logger.info("Reloading weights inplace...")
+        model_loader.load_weights(self.model, model_config=self.model_config)
+
     def save_tensorized_model(
         self,
         tensorizer_config: "TensorizerConfig",
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index 9e7e44d0686..1e94a6e1fb2 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -3,6 +3,7 @@
 """A GPU worker class."""
 import gc
 import os
+from contextlib import AbstractContextManager, nullcontext
 from typing import TYPE_CHECKING, Optional
 
 import torch
@@ -112,6 +113,19 @@ def wake_up(self, tags: Optional[list[str]] = None) -> None:
                 buffer.data.copy_(self._sleep_saved_buffers[name].data)
             self._sleep_saved_buffers = {}
 
+    def _maybe_get_memory_pool_context(self,
+                                       tag: str) -> AbstractContextManager:
+        if self.vllm_config.model_config.enable_sleep_mode:
+            allocator = CuMemAllocator.get_instance()
+            if tag == "weights":
+                assert allocator.get_current_usage() == 0, (
+                    "Sleep mode can only be "
+                    "used for one instance per process.")
+            context = allocator.use_memory_pool(tag=tag)
+        else:
+            context = nullcontext()
+        return context
+
     def initialize_cache(self, num_gpu_blocks: int,
                          num_cpu_blocks: int) -> None:
         self.cache_config.num_gpu_blocks = num_gpu_blocks
@@ -172,18 +186,13 @@ def init_device(self):
     # FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
     # to hijack tensor allocation.
     def load_model(self) -> None:
-        if self.vllm_config.model_config.enable_sleep_mode:
-            allocator = CuMemAllocator.get_instance()
-            assert allocator.get_current_usage() == 0, (
-                "Sleep mode can only be "
-                "used for one instance per process.")
-            context = allocator.use_memory_pool(tag="weights")
-        else:
-            from contextlib import nullcontext
-            context = nullcontext()
-        with context:
+        with self._maybe_get_memory_pool_context(tag="weights"):
             self.model_runner.load_model()
 
+    def reload_weights(self) -> None:
+        with self._maybe_get_memory_pool_context(tag="weights"):
+            self.model_runner.reload_weights()
+
     @torch.inference_mode()
     def determine_available_memory(self) -> int:
         """Profiles the peak memory usage of the model to determine how much
@@ -240,13 +249,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
 
     def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
         """Allocate GPU KV cache with the specified kv_cache_config."""
-        if self.vllm_config.model_config.enable_sleep_mode:
-            allocator = CuMemAllocator.get_instance()
-            context = allocator.use_memory_pool(tag="kv_cache")
-        else:
-            from contextlib import nullcontext
-            context = nullcontext()
-        with context:
+        with self._maybe_get_memory_pool_context(tag="kv_cache"):
             self.model_runner.initialize_kv_cache(kv_cache_config)
 
     def compile_or_warm_up_model(self) -> None:
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index bc334419c4c..d501e024a52 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -1131,16 +1131,9 @@ def load_model(self) -> None:
         else:
             # model = get_model(vllm_config=self.vllm_config)
             model_loader = get_model_loader(self.load_config)
-            if not hasattr(self, "model"):
-                logger.info("Loading model from scratch...")
-                model = model_loader.load_model(
-                    vllm_config=self.vllm_config,
-                    model_config=self.model_config)
-            else:
-                logger.info("Model was already initialized. \
-                    Loading weights inplace...")
-                model_loader.load_weights(self.model,
-                                          model_config=self.model_config)
+            logger.info("Loading model from scratch...")
+            model = model_loader.load_model(vllm_config=self.vllm_config,
+                                            model_config=self.model_config)
         if self.lora_config is not None:
             model = self.load_lora_model(model, self.model_config,
                                          self.scheduler_config,
@@ -1155,6 +1148,11 @@ def load_model(self) -> None:
         self.model = model
         self.sampler = TPUSampler()
 
+    def reload_weights(self) -> None:
+        model_loader = get_model_loader(self.load_config)
+        logger.info("Reloading weights inplace...")
+        model_loader.load_weights(self.model, model_config=self.model_config)
+
     @torch.no_grad()
     def _dummy_run(self, num_tokens: int, num_reqs: int,
                    num_blocks: int) -> None:
diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py
index a64ce881fe3..2f4eb485614 100644
--- a/vllm/v1/worker/tpu_worker.py
+++ b/vllm/v1/worker/tpu_worker.py
@@ -259,6 +259,9 @@ def add_lora(self, lora_request: LoRARequest) -> bool:
     def load_model(self) -> None:
         self.model_runner.load_model()
 
+    def reload_weights(self) -> None:
+        self.model_runner.reload_weights()
+
     def compile_or_warm_up_model(self) -> None:
         if not self.model_config.enforce_eager:
             self.model_runner.capture_model()
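
A quick usage sketch of the new API, mirroring the updated test above: `reload_weights()` assumes `load_model()` has already built the model, and swaps fresh weights into the existing module without reallocating it. Here `model_runner` and `original_load_format` stand in for an initialized `GPUModelRunner` and its real checkpoint format, as in the test.

```python
# Build the model once with dummy (randomly initialized) weights,
# which is fast and skips checkpoint I/O.
model_runner.load_config.load_format = "dummy"
model_runner.load_model()

# Point the loader back at the real checkpoint format, then load the
# real weights into the already-allocated model, in place.
model_runner.load_config.load_format = original_load_format
model_runner.reload_weights()
```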