From 2a9bbc10c1950ba1dece9896735efb9ab287bbef Mon Sep 17 00:00:00 2001
From: 22quinn <33176974+22quinn@users.noreply.github.com>
Date: Wed, 25 Jun 2025 14:19:39 -0700
Subject: [PATCH 1/3] Separate model and weights loading RPC

Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
---
 vllm/v1/worker/gpu_model_runner.py | 19 +++++++---------
 vllm/v1/worker/gpu_worker.py       | 36 ++++++++++++++++--------------
 vllm/v1/worker/tpu_model_runner.py | 18 +++++++--------
 vllm/v1/worker/tpu_worker.py       |  3 +++
 4 files changed, 38 insertions(+), 38 deletions(-)

diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 40639fdf243..c2a821d61b9 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1697,17 +1697,9 @@ def load_model(self) -> None:
         with DeviceMemoryProfiler() as m:  # noqa: SIM117
             time_before_load = time.perf_counter()
             model_loader = get_model_loader(self.load_config)
-            if not hasattr(self, "model"):
-                logger.info("Loading model from scratch...")
-                self.model = model_loader.load_model(
-                    vllm_config=self.vllm_config,
-                    model_config=self.model_config)
-            else:
-                logger.info(
-                    "Model was already initialized. Loading weights inplace..."
-                )
-                model_loader.load_weights(self.model,
-                                          model_config=self.model_config)
+            logger.info("Loading model from scratch...")
+            self.model = model_loader.load_model(
+                vllm_config=self.vllm_config, model_config=self.model_config)
             if has_step_pooler(self.model):
                 self.input_batch.logits_processing_needs_token_ids = True
             if self.lora_config:
@@ -1729,6 +1721,11 @@ def load_model(self) -> None:
                     time_after_load - time_before_load)
         prepare_communication_buffer_for_model(self.model)

+    def reload_weights(self) -> None:
+        model_loader = get_model_loader(self.load_config)
+        logger.info("Reloading weights inplace...")
+        model_loader.load_weights(self.model, model_config=self.model_config)
+
     def save_tensorized_model(
         self,
         tensorizer_config: "TensorizerConfig",
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index b0f80c70132..c2820bde38f 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -3,6 +3,7 @@
 """A GPU worker class."""
 import gc
 import os
+from contextlib import AbstractContextManager, nullcontext
 from typing import TYPE_CHECKING, Optional

 import torch
@@ -112,6 +113,18 @@ def wake_up(self, tags: Optional[list[str]] = None) -> None:
                     buffer.data.copy_(self._sleep_saved_buffers[name].data)
             self._sleep_saved_buffers = {}

+    def _maybe_get_memory_pool_context(self,
+                                       tag: str) -> AbstractContextManager:
+        if self.vllm_config.model_config.enable_sleep_mode:
+            allocator = CuMemAllocator.get_instance()
+            assert allocator.get_current_usage() == 0, (
+                "Sleep mode can only be "
+                "used for one instance per process.")
+            context = allocator.use_memory_pool(tag=tag)
+        else:
+            context = nullcontext()
+        return context
+
     def initialize_cache(self, num_gpu_blocks: int,
                          num_cpu_blocks: int) -> None:
         self.cache_config.num_gpu_blocks = num_gpu_blocks
@@ -172,18 +185,13 @@ def init_device(self):
     # FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
     # to hijack tensor allocation.
     def load_model(self) -> None:
-        if self.vllm_config.model_config.enable_sleep_mode:
-            allocator = CuMemAllocator.get_instance()
-            assert allocator.get_current_usage() == 0, (
-                "Sleep mode can only be "
-                "used for one instance per process.")
-            context = allocator.use_memory_pool(tag="weights")
-        else:
-            from contextlib import nullcontext
-            context = nullcontext()
-        with context:
+        with self._maybe_get_memory_pool_context(tag="weights"):
             self.model_runner.load_model()

+    def reload_weights(self) -> None:
+        with self._maybe_get_memory_pool_context(tag="weights"):
+            self.model_runner.reload_weights()
+
     @torch.inference_mode()
     def determine_available_memory(self) -> int:
         """Profiles the peak memory usage of the model to determine how much
@@ -240,13 +248,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:

     def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
         """Allocate GPU KV cache with the specified kv_cache_config."""
-        if self.vllm_config.model_config.enable_sleep_mode:
-            allocator = CuMemAllocator.get_instance()
-            context = allocator.use_memory_pool(tag="kv_cache")
-        else:
-            from contextlib import nullcontext
-            context = nullcontext()
-        with context:
+        with self._maybe_get_memory_pool_context(tag="kv_cache"):
             self.model_runner.initialize_kv_cache(kv_cache_config)

     def compile_or_warm_up_model(self) -> None:
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index 774caa1a3d9..e06110a0201 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -995,16 +995,9 @@ def load_model(self) -> None:
         else:
             # model = get_model(vllm_config=self.vllm_config)
             model_loader = get_model_loader(self.load_config)
-            if not hasattr(self, "model"):
-                logger.info("Loading model from scratch...")
-                model = model_loader.load_model(
-                    vllm_config=self.vllm_config,
-                    model_config=self.model_config)
-            else:
-                logger.info("Model was already initialized. \
-                    Loading weights inplace...")
-                model_loader.load_weights(self.model,
-                                          model_config=self.model_config)
+            logger.info("Loading model from scratch...")
+            model = model_loader.load_model(vllm_config=self.vllm_config,
+                                            model_config=self.model_config)
         if self.lora_config is not None:
             model = self.load_lora_model(model, self.model_config,
                                          self.scheduler_config,
@@ -1019,6 +1012,11 @@ def load_model(self) -> None:
         self.model = model
         self.sampler = TPUSampler()

+    def reload_weights(self) -> None:
+        model_loader = get_model_loader(self.load_config)
+        logger.info("Reloading weights inplace...")
+        model_loader.load_weights(self.model, model_config=self.model_config)
+
     @torch.no_grad()
     def _dummy_run(self, num_tokens: int) -> None:
         if self.is_multimodal_model:
diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py
index 87af8e47670..30e50f0e507 100644
--- a/vllm/v1/worker/tpu_worker.py
+++ b/vllm/v1/worker/tpu_worker.py
@@ -248,6 +248,9 @@ def add_lora(self, lora_request: LoRARequest) -> bool:
     def load_model(self) -> None:
         self.model_runner.load_model()

+    def reload_weights(self) -> None:
+        self.model_runner.reload_weights()
+
     def compile_or_warm_up_model(self) -> None:
         if not self.model_config.enforce_eager:
             self.model_runner.capture_model()

From 0e88ddfeb7f345b8b0ea9379ceac2bf5149c1cfc Mon Sep 17 00:00:00 2001
From: 22quinn <33176974+22quinn@users.noreply.github.com>
Date: Wed, 25 Jun 2025 14:31:31 -0700
Subject: [PATCH 2/3] unit test

Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
---
 tests/v1/worker/test_gpu_model_runner.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py
index 583a88d8e6e..e79a53956cd 100644
--- a/tests/v1/worker/test_gpu_model_runner.py
+++ b/tests/v1/worker/test_gpu_model_runner.py
@@ -443,7 +443,7 @@ def test_load_model_weights_inplace(dist_init, model_runner, model_runner_2):
     assert str(model_runner.get_model().state_dict()) != str(
         model_runner_2.get_model().state_dict())
     model_runner_2.load_config.load_format = original_load_format
-    model_runner_2.load_model()  # Load real weights inplace
+    model_runner_2.reload_weights()  # Load real weights inplace
     assert str(model_runner.get_model().state_dict()) == str(
         model_runner_2.get_model().state_dict())

From a8a36fe0bb094f18dcf2d9be14108668502d5902 Mon Sep 17 00:00:00 2001
From: 22quinn <33176974+22quinn@users.noreply.github.com>
Date: Wed, 25 Jun 2025 17:45:52 -0700
Subject: [PATCH 3/3] fix ci

Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
---
 vllm/v1/worker/gpu_worker.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index c2820bde38f..22f4990a0bf 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -117,9 +117,10 @@ def _maybe_get_memory_pool_context(self,
                                        tag: str) -> AbstractContextManager:
         if self.vllm_config.model_config.enable_sleep_mode:
             allocator = CuMemAllocator.get_instance()
-            assert allocator.get_current_usage() == 0, (
-                "Sleep mode can only be "
-                "used for one instance per process.")
+            if tag == "weights":
+                assert allocator.get_current_usage() == 0, (
+                    "Sleep mode can only be "
+                    "used for one instance per process.")
             context = allocator.use_memory_pool(tag=tag)
         else:
             context = nullcontext()
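
Usage note: together, the three patches split the worker interface in two.
load_model() now always constructs the model and loads weights from scratch,
while the new reload_weights() RPC copies fresh weights into an
already-initialized model in place, reusing the "weights" memory pool when
sleep mode is enabled (patch 3 relaxes the allocator assertion so that
reloading into an existing pool does not trip the one-instance check). Below
is a minimal sketch of how a caller might drive the new RPC; it assumes the
worker method is reachable through vLLM's collective_rpc plumbing, and the
model name is illustrative only:

    from vllm import LLM

    # Workers run load_model() once during engine construction.
    llm = LLM(model="facebook/opt-125m")

    # ... an external process (e.g. an RL trainer) overwrites the checkpoint
    # at the original load path ...

    # Dispatch Worker.reload_weights() on every worker: the configured model
    # loader re-reads the weights and copies them into the existing model,
    # without re-allocating or re-initializing the module structure.
    llm.collective_rpc("reload_weights")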