Skip to content

Commit 9760fd8

Browse files
authored
[Core] Support inplace model weights loading (#18745)
Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
1 parent b9f61e1 commit 9760fd8

File tree

13 files changed

+249
-297
lines changed

13 files changed

+249
-297
lines changed

tests/tensorizer_loader/test_tensorizer.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import os
55
import pathlib
66
import subprocess
7-
from unittest.mock import MagicMock, patch
87

98
import pytest
109
import torch
@@ -16,7 +15,6 @@
1615
from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
1716
TensorSerializer,
1817
is_vllm_tensorized,
19-
load_with_tensorizer,
2018
open_stream,
2119
tensorize_vllm_model)
2220
# yapf: enable
@@ -61,21 +59,6 @@ def write_keyfile(keyfile_path: str):
6159
f.write(encryption_params.key)
6260

6361

64-
@patch('vllm.model_executor.model_loader.tensorizer.TensorizerAgent')
65-
def test_load_with_tensorizer(mock_agent, tensorizer_config):
66-
mock_linear_method = MagicMock()
67-
mock_agent_instance = mock_agent.return_value
68-
mock_agent_instance.deserialize.return_value = MagicMock()
69-
70-
result = load_with_tensorizer(tensorizer_config,
71-
quant_method=mock_linear_method)
72-
73-
mock_agent.assert_called_once_with(tensorizer_config,
74-
quant_method=mock_linear_method)
75-
mock_agent_instance.deserialize.assert_called_once()
76-
assert result == mock_agent_instance.deserialize.return_value
77-
78-
7962
@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
8063
def test_can_deserialize_s3(vllm_runner):
8164
model_ref = "EleutherAI/pythia-1.4b"

tests/v1/worker/test_gpu_model_runner.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,9 @@ def model_runner():
9494
return runner
9595

9696

97+
model_runner_2 = model_runner
98+
99+
97100
def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
98101
new_reqs = []
99102
num_scheduled_tokens = {}
@@ -366,3 +369,18 @@ def rnd_stride_order():
366369
assert all(kv.is_contiguous() for kv in model_runner.kv_caches)
367370
else:
368371
assert all(not kv.is_contiguous() for kv in model_runner.kv_caches)
372+
373+
374+
def test_load_model_weights_inplace(dist_init, model_runner, model_runner_2):
375+
# In this test, model_runner loads model + weights in one go, while
376+
# model_runner_2 loads dummy weights first, then loads real weights inplace
377+
model_runner.load_model()
378+
original_load_format = model_runner_2.load_config.load_format
379+
model_runner_2.load_config.load_format = "dummy"
380+
model_runner_2.load_model() # Initial model loading with dummy weights
381+
assert str(model_runner.get_model().state_dict()) != str(
382+
model_runner_2.get_model().state_dict())
383+
model_runner_2.load_config.load_format = original_load_format
384+
model_runner_2.load_model() # Load real weights inplace
385+
assert str(model_runner.get_model().state_dict()) == str(
386+
model_runner_2.get_model().state_dict())

vllm/model_executor/model_loader/base_loader.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
# SPDX-License-Identifier: Apache-2.0
22
from abc import ABC, abstractmethod
33

4+
import torch
45
import torch.nn as nn
56

67
from vllm.config import LoadConfig, ModelConfig, VllmConfig
8+
from vllm.model_executor.model_loader.utils import (
9+
initialize_model, process_weights_after_loading, set_default_torch_dtype)
710

811

912
class BaseModelLoader(ABC):
@@ -18,7 +21,22 @@ def download_model(self, model_config: ModelConfig) -> None:
1821
raise NotImplementedError
1922

2023
@abstractmethod
21-
def load_model(self, *, vllm_config: VllmConfig,
24+
def load_weights(self, model: nn.Module,
25+
model_config: ModelConfig) -> None:
26+
"""Load weights into a model. This standalone API allows
27+
inplace weights loading for an already-initialized model"""
28+
raise NotImplementedError
29+
30+
def load_model(self, vllm_config: VllmConfig,
2231
model_config: ModelConfig) -> nn.Module:
2332
"""Load a model with the given configurations."""
24-
raise NotImplementedError
33+
device_config = vllm_config.device_config
34+
target_device = torch.device(device_config.device)
35+
with set_default_torch_dtype(model_config.dtype):
36+
with target_device:
37+
model = initialize_model(vllm_config=vllm_config,
38+
model_config=model_config)
39+
# Quantization does not happen in `load_weights` but after it
40+
self.load_weights(model, model_config)
41+
process_weights_after_loading(model, model_config, target_device)
42+
return model.eval()

vllm/model_executor/model_loader/bitsandbytes_loader.py

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from torch import nn
1515
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
1616

17-
from vllm.config import LoadConfig, ModelConfig, VllmConfig
17+
from vllm.config import LoadConfig, ModelConfig
1818
from vllm.distributed import (get_tensor_model_parallel_rank,
1919
get_tensor_model_parallel_world_size)
2020
# yapf: enable
@@ -28,7 +28,6 @@
2828
RowParallelLinear)
2929
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
3030
from vllm.model_executor.model_loader.utils import (ParamMapping,
31-
initialize_model,
3231
set_default_torch_dtype)
3332
from vllm.model_executor.model_loader.weight_utils import (
3433
download_safetensors_index_file_from_hf, download_weights_from_hf,
@@ -408,8 +407,7 @@ def _get_bnb_target_modules(self, model: nn.Module) -> None:
408407
), "vllm currently does not support BNB quantization for"
409408
f" {type(model).__name__}"
410409

411-
def _load_weights(self, model_config: ModelConfig,
412-
model: nn.Module) -> None:
410+
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
413411
if not hasattr(model, "load_weights"):
414412
raise AttributeError(
415413
"The required method 'load_weights' is not defined in class"
@@ -568,15 +566,3 @@ def _load_weights(self, model_config: ModelConfig,
568566

569567
def download_model(self, model_config: ModelConfig) -> None:
570568
self._prepare_weights(model_config.model, model_config.revision)
571-
572-
def load_model(self, vllm_config: VllmConfig,
573-
model_config: ModelConfig) -> nn.Module:
574-
device_config = vllm_config.device_config
575-
with set_default_torch_dtype(model_config.dtype):
576-
with torch.device(device_config.device):
577-
578-
model = initialize_model(vllm_config=vllm_config)
579-
580-
self._load_weights(model_config, model)
581-
582-
return model.eval()

vllm/model_executor/model_loader/default_loader.py

Lines changed: 18 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,9 @@
1212
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
1313

1414
from vllm import envs
15-
from vllm.config import LoadConfig, LoadFormat, ModelConfig, VllmConfig
15+
from vllm.config import LoadConfig, LoadFormat, ModelConfig
1616
from vllm.logger import init_logger
1717
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
18-
from vllm.model_executor.model_loader.utils import (
19-
initialize_model, process_weights_after_loading, set_default_torch_dtype)
2018
from vllm.model_executor.model_loader.weight_utils import (
2119
download_safetensors_index_file_from_hf, download_weights_from_hf,
2220
fastsafetensors_weights_iterator, filter_duplicate_safetensors_files,
@@ -264,32 +262,20 @@ def download_model(self, model_config: ModelConfig) -> None:
264262
fall_back_to_pt=True,
265263
allow_patterns_overrides=None)
266264

267-
def load_model(self, vllm_config: VllmConfig,
268-
model_config: ModelConfig) -> nn.Module:
269-
device_config = vllm_config.device_config
270-
target_device = torch.device(device_config.device)
271-
with set_default_torch_dtype(model_config.dtype):
272-
with target_device:
273-
model = initialize_model(vllm_config=vllm_config,
274-
model_config=model_config)
275-
276-
weights_to_load = {name for name, _ in model.named_parameters()}
277-
loaded_weights = model.load_weights(
278-
self.get_all_weights(model_config, model))
279-
self.counter_after_loading_weights = time.perf_counter()
280-
logger.info(
281-
"Loading weights took %.2f seconds",
282-
self.counter_after_loading_weights -
283-
self.counter_before_loading_weights)
284-
# We only enable strict check for non-quantized models
285-
# that have loaded weights tracking currently.
286-
if model_config.quantization is None and loaded_weights is not None:
287-
weights_not_loaded = weights_to_load - loaded_weights
288-
if weights_not_loaded:
289-
raise ValueError(
290-
"Following weights were not initialized from "
291-
f"checkpoint: {weights_not_loaded}")
292-
293-
process_weights_after_loading(model, model_config, target_device)
294-
295-
return model.eval()
265+
def load_weights(self, model: nn.Module,
266+
model_config: ModelConfig) -> None:
267+
weights_to_load = {name for name, _ in model.named_parameters()}
268+
loaded_weights = model.load_weights(
269+
self.get_all_weights(model_config, model))
270+
self.counter_after_loading_weights = time.perf_counter()
271+
logger.info(
272+
"Loading weights took %.2f seconds",
273+
self.counter_after_loading_weights -
274+
self.counter_before_loading_weights)
275+
# We only enable strict check for non-quantized models
276+
# that have loaded weights tracking currently.
277+
if model_config.quantization is None and loaded_weights is not None:
278+
weights_not_loaded = weights_to_load - loaded_weights
279+
if weights_not_loaded:
280+
raise ValueError("Following weights were not initialized from "
281+
f"checkpoint: {weights_not_loaded}")

vllm/model_executor/model_loader/dummy_loader.py

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,8 @@
11
# SPDX-License-Identifier: Apache-2.0
2-
import torch
32
import torch.nn as nn
43

5-
from vllm.config import LoadConfig, ModelConfig, VllmConfig
4+
from vllm.config import LoadConfig, ModelConfig
65
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
7-
from vllm.model_executor.model_loader.utils import (
8-
initialize_model, process_weights_after_loading, set_default_torch_dtype)
96
from vllm.model_executor.model_loader.weight_utils import (
107
initialize_dummy_weights)
118

@@ -22,16 +19,8 @@ def __init__(self, load_config: LoadConfig):
2219
def download_model(self, model_config: ModelConfig) -> None:
2320
pass # Nothing to download
2421

25-
def load_model(self, vllm_config: VllmConfig,
26-
model_config: ModelConfig) -> nn.Module:
27-
device_config = vllm_config.device_config
28-
target_device = torch.device(device_config.device)
29-
with set_default_torch_dtype(model_config.dtype):
30-
with target_device:
31-
model = initialize_model(vllm_config=vllm_config)
32-
# NOTE(woosuk): For accurate performance evaluation, we assign
33-
# random values to the weights.
34-
initialize_dummy_weights(model)
35-
36-
process_weights_after_loading(model, model_config, target_device)
37-
return model.eval()
22+
def load_weights(self, model: nn.Module,
23+
model_config: ModelConfig) -> None:
24+
# NOTE(woosuk): For accurate performance evaluation, we assign
25+
# random values to the weights.
26+
initialize_dummy_weights(model)

vllm/model_executor/model_loader/gguf_loader.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,13 @@ def _get_weights_iterator(
9292
def download_model(self, model_config: ModelConfig) -> None:
9393
self._prepare_weights(model_config.model)
9494

95+
def load_weights(self, model: nn.Module,
96+
model_config: ModelConfig) -> None:
97+
local_model_path = self._prepare_weights(model_config.model)
98+
gguf_weights_map = self._get_gguf_weights_map(model_config)
99+
model.load_weights(
100+
self._get_weights_iterator(local_model_path, gguf_weights_map))
101+
95102
def load_model(self, vllm_config: VllmConfig,
96103
model_config: ModelConfig) -> nn.Module:
97104
device_config = vllm_config.device_config
@@ -106,8 +113,7 @@ def load_model(self, vllm_config: VllmConfig,
106113
with set_default_torch_dtype(model_config.dtype):
107114
with target_device:
108115
model = initialize_model(vllm_config=vllm_config)
109-
model.load_weights(
110-
self._get_weights_iterator(local_model_path, gguf_weights_map))
116+
self.load_weights(model, model_config)
111117

112118
process_weights_after_loading(model, model_config, target_device)
113119
return model

vllm/model_executor/model_loader/runai_streamer_loader.py

Lines changed: 9 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,8 @@
99
from torch import nn
1010
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
1111

12-
from vllm.config import LoadConfig, ModelConfig, VllmConfig
12+
from vllm.config import LoadConfig, ModelConfig
1313
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
14-
from vllm.model_executor.model_loader.utils import (
15-
initialize_model, process_weights_after_loading, set_default_torch_dtype)
1614
from vllm.model_executor.model_loader.weight_utils import (
1715
download_safetensors_index_file_from_hf, download_weights_from_hf,
1816
runai_safetensors_weights_iterator)
@@ -100,21 +98,11 @@ def download_model(self, model_config: ModelConfig) -> None:
10098
"""Download model if necessary"""
10199
self._prepare_weights(model_config.model, model_config.revision)
102100

103-
def load_model(self, vllm_config: VllmConfig,
104-
model_config: ModelConfig) -> nn.Module:
105-
"""Perform streaming of the model to destination"""
106-
device_config = vllm_config.device_config
107-
target_device = torch.device(device_config.device)
108-
with set_default_torch_dtype(model_config.dtype):
109-
with target_device:
110-
model = initialize_model(vllm_config=vllm_config)
111-
112-
model_weights = model_config.model
113-
if hasattr(model_config, "model_weights"):
114-
model_weights = model_config.model_weights
115-
model.load_weights(
116-
self._get_weights_iterator(model_weights,
117-
model_config.revision))
118-
119-
process_weights_after_loading(model, model_config, target_device)
120-
return model.eval()
101+
def load_weights(self, model: nn.Module,
102+
model_config: ModelConfig) -> None:
103+
"""Load weights into a model."""
104+
model_weights = model_config.model
105+
if hasattr(model_config, "model_weights"):
106+
model_weights = model_config.model_weights
107+
model.load_weights(
108+
self._get_weights_iterator(model_weights, model_config.revision))

0 commit comments

Comments
 (0)