Skip to content

Commit fdfd409

Browse files
authored
[TPU][Core] Make the "model weights exceed HBM" load error more instructive for customers (#20644)
Signed-off-by: Chenyaaang <chenyangli@google.com>
1 parent ffbcc9e commit fdfd409

File tree

1 file changed

+25
-18
lines changed

1 file changed

+25
-18
lines changed

vllm/v1/worker/tpu_model_runner.py

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1128,26 +1128,33 @@ def load_model(self) -> None:
11281128
"vllm.model_executor.layers.vocab_parallel_embedding."
11291129
"get_tensor_model_parallel_rank",
11301130
return_value=xm_tp_rank):
1131-
if self.use_spmd:
1132-
tpu_loader = TPUModelLoader(
1133-
load_config=self.vllm_config.load_config)
1134-
model = tpu_loader.load_model(
1135-
vllm_config=self.vllm_config,
1136-
model_config=self.vllm_config.model_config,
1137-
mesh=self.mesh)
1138-
else:
1139-
# model = get_model(vllm_config=self.vllm_config)
1140-
model_loader = get_model_loader(self.load_config)
1141-
if not hasattr(self, "model"):
1142-
logger.info("Loading model from scratch...")
1143-
model = model_loader.load_model(
1131+
try:
1132+
if self.use_spmd:
1133+
tpu_loader = TPUModelLoader(
1134+
load_config=self.vllm_config.load_config)
1135+
model = tpu_loader.load_model(
11441136
vllm_config=self.vllm_config,
1145-
model_config=self.model_config)
1137+
model_config=self.vllm_config.model_config,
1138+
mesh=self.mesh)
11461139
else:
1147-
logger.info("Model was already initialized. \
1148-
Loading weights inplace...")
1149-
model_loader.load_weights(self.model,
1150-
model_config=self.model_config)
1140+
model_loader = get_model_loader(self.load_config)
1141+
if not hasattr(self, "model"):
1142+
logger.info("Loading model from scratch...")
1143+
model = model_loader.load_model(
1144+
vllm_config=self.vllm_config,
1145+
model_config=self.model_config)
1146+
else:
1147+
logger.info("Model was already initialized. \
1148+
Loading weights inplace...")
1149+
model_loader.load_weights(
1150+
self.model, model_config=self.model_config)
1151+
except RuntimeError as e:
1152+
raise RuntimeError(
1153+
f"Unable to load model, a likely reason is the model is "
1154+
"too large for the current device's HBM memory. "
1155+
"Consider switching to a smaller model "
1156+
"or sharding the weights on more chips. "
1157+
f"See the detailed error: {e}") from e
11511158
if self.lora_config is not None:
11521159
model = self.load_lora_model(model, self.model_config,
11531160
self.scheduler_config,

0 commit comments

Comments (0)