Skip to content

Commit 4ac9c33

Browse files
authored
[Bugfix] Fix handling of Tensorizer arguments for LoadConfig (#20643)
Signed-off-by: Sanger Steel <sangersteel@gmail.com>
1 parent efe73d0 commit 4ac9c33

File tree

4 files changed

+21
-52
lines changed

4 files changed

+21
-52
lines changed

tests/tensorizer_loader/test_tensorizer.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -103,25 +103,6 @@ def write_keyfile(keyfile_path: str):
103103
f.write(encryption_params.key)
104104

105105

106-
@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
107-
def test_can_deserialize_s3(vllm_runner):
108-
model_ref = "EleutherAI/pythia-1.4b"
109-
tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors"
110-
111-
with vllm_runner(model_ref,
112-
load_format="tensorizer",
113-
model_loader_extra_config=TensorizerConfig(
114-
tensorizer_uri=tensorized_path,
115-
num_readers=1,
116-
s3_endpoint="object.ord1.coreweave.com",
117-
)) as loaded_hf_model:
118-
deserialized_outputs = loaded_hf_model.generate(
119-
prompts, sampling_params)
120-
# noqa: E501
121-
122-
assert deserialized_outputs
123-
124-
125106
@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
126107
def test_deserialized_encrypted_vllm_model_has_same_outputs(
127108
model_ref, vllm_runner, tmp_path, model_path):

vllm/engine/arg_utils.py

Lines changed: 15 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1003,41 +1003,27 @@ def create_model_config(self) -> ModelConfig:
10031003
override_attention_dtype=self.override_attention_dtype,
10041004
)
10051005

1006-
def valid_tensorizer_config_provided(self) -> bool:
1007-
"""
1008-
Checks if a parseable TensorizerConfig was passed to
1009-
self.model_loader_extra_config. It first checks if the config passed
1010-
is a dict or a TensorizerConfig object directly, and if the latter is
1011-
true (by checking that the object has TensorizerConfig's
1012-
.to_serializable() method), converts it in to a serializable dict
1013-
format
1014-
"""
1015-
if self.model_loader_extra_config:
1016-
if hasattr(self.model_loader_extra_config, "to_serializable"):
1017-
self.model_loader_extra_config = (
1018-
self.model_loader_extra_config.to_serializable())
1019-
for allowed_to_pass in ["tensorizer_uri", "tensorizer_dir"]:
1020-
try:
1021-
self.model_loader_extra_config[allowed_to_pass]
1022-
return False
1023-
except KeyError:
1024-
pass
1025-
return True
1006+
def validate_tensorizer_args(self):
1007+
from vllm.model_executor.model_loader.tensorizer import (
1008+
TensorizerConfig)
1009+
for key in self.model_loader_extra_config:
1010+
if key in TensorizerConfig._fields:
1011+
self.model_loader_extra_config["tensorizer_config"][
1012+
key] = self.model_loader_extra_config[key]
10261013

10271014
def create_load_config(self) -> LoadConfig:
10281015

10291016
if self.quantization == "bitsandbytes":
10301017
self.load_format = "bitsandbytes"
10311018

1032-
if (self.load_format == "tensorizer"
1033-
and self.valid_tensorizer_config_provided()):
1034-
logger.info("Inferring Tensorizer args from %s", self.model)
1035-
self.model_loader_extra_config = {"tensorizer_dir": self.model}
1036-
else:
1037-
logger.info(
1038-
"Using Tensorizer args from --model-loader-extra-config. "
1039-
"Note that you can now simply pass the S3 directory in the "
1040-
"model tag instead of providing the JSON string.")
1019+
if self.load_format == "tensorizer":
1020+
if hasattr(self.model_loader_extra_config, "to_serializable"):
1021+
self.model_loader_extra_config = (
1022+
self.model_loader_extra_config.to_serializable())
1023+
self.model_loader_extra_config["tensorizer_config"] = {}
1024+
self.model_loader_extra_config["tensorizer_config"][
1025+
"tensorizer_dir"] = self.model
1026+
self.validate_tensorizer_args()
10411027

10421028
return LoadConfig(
10431029
load_format=self.load_format,

vllm/model_executor/model_loader/tensorizer.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -223,9 +223,11 @@ def __post_init__(self):
223223
and re.search(r'%0\dd', self.tensorizer_uri) is not None
224224

225225
if self.tensorizer_dir and self.tensorizer_uri:
226-
raise ValueError(
227-
"Either tensorizer_dir or tensorizer_uri must be provided, "
228-
"not both.")
226+
logger.warning_once(
227+
"Provided both tensorizer_dir and tensorizer_uri. "
228+
"Inferring tensorizer_dir from tensorizer_uri as the "
229+
"latter takes precedence.")
230+
self.tensorizer_dir = os.path.dirname(self.tensorizer_uri)
229231
if self.tensorizer_dir and self.lora_dir:
230232
raise ValueError(
231233
"Only one of tensorizer_dir or lora_dir may be specified. "

vllm/model_executor/model_loader/tensorizer_loader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def __init__(self, load_config: LoadConfig):
4343
else:
4444
validate_config(load_config.model_loader_extra_config)
4545
self.tensorizer_config = TensorizerConfig(
46-
**load_config.model_loader_extra_config)
46+
**load_config.model_loader_extra_config["tensorizer_config"])
4747

4848
def _verify_config(self, model_config: ModelConfig,
4949
parallel_config: ParallelConfig):

0 commit comments

Comments (0)