Commit 54b9ca3

Merge remote-tracking branch 'upstream/aice/v1.21.0' into gyou/aice/v1.21.0/qwen3

2 parents: 5f02426 + 4978811

5 files changed, +18 -8 lines changed

vllm/model_executor/layers/pooler.py

Lines changed: 1 addition & 0 deletions
@@ -241,6 +241,7 @@ def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
         dimensions_list = [
             pooling_param.dimensions
             for _, pooling_param in pooling_metadata.seq_groups
+            if pooling_param is not None
         ]
         if any(d is not None for d in dimensions_list):
             # change the output dimension
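The guard matters because a sequence group may carry None instead of a PoolingParams object, and the old comprehension would raise AttributeError on it. A minimal sketch with hypothetical stand-in data (SimpleNamespace instead of the real PoolingParams class):

from types import SimpleNamespace

# Toy stand-ins for (seq_ids, PoolingParams) pairs in pooling_metadata.seq_groups.
seq_groups = [
    ([0], SimpleNamespace(dimensions=128)),  # group with pooling params
    ([1], None),                             # group without pooling params
]

# Old comprehension: `p.dimensions` raises AttributeError on the None entry.
# dims = [p.dimensions for _, p in seq_groups]
# Patched comprehension: entries without pooling params are skipped.
dims = [p.dimensions for _, p in seq_groups if p is not None]
print(dims)  # [128]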

vllm/model_executor/layers/vocab_parallel_embedding.py

Lines changed: 1 addition & 1 deletion
@@ -395,7 +395,7 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
             padded_weight = torch.cat([
                 loaded_weight,
                 torch.zeros(param.shape[0] - loaded_weight.shape[0],
-                            *loaded_weight.shape[1:])
+                            *loaded_weight.shape[1:], device=loaded_weight.device)
             ])
         else:
             padded_weight = loaded_weight
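torch.zeros allocates on the CPU when no device argument is given, so the old padding tensor could live on a different device than loaded_weight and torch.cat would raise a device-mismatch error. A small sketch of the fixed call, using CUDA as a stand-in for whatever accelerator holds the shard (assumption; falls back to CPU if none is available):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
loaded_weight = torch.randn(100, 64, device=device)
target_rows = 128  # stands in for param.shape[0]

# Passing device=loaded_weight.device keeps the padding on the same device
# as the checkpoint shard, so torch.cat no longer mixes CPU and accelerator tensors.
padding = torch.zeros(target_rows - loaded_weight.shape[0],
                      *loaded_weight.shape[1:],
                      device=loaded_weight.device)
padded_weight = torch.cat([loaded_weight, padding])
print(padded_weight.shape, padded_weight.device)  # torch.Size([128, 64]) on `device`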

vllm/model_executor/models/qwen_vl.py

Lines changed: 1 addition & 1 deletion
@@ -774,7 +774,7 @@ def get_input_embeddings(
         inputs_embeds = self.transformer.get_input_embeddings(input_ids)

         if multimodal_embeddings is not None:
-            inputs_embeds = self._merge_multimodal_embeddings(
+            inputs_embeds = merge_multimodal_embeddings(
                 input_ids, inputs_embeds, multimodal_embeddings,
                 self.transformer.visual.image_pad_id)

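This switches the model to the module-level merge_multimodal_embeddings helper from vllm/model_executor/models/utils.py (modified below). A toy illustration, with hypothetical ids and sizes, of the masked scatter that helper performs:

import torch

image_pad_id = 151859  # assumption: arbitrary placeholder id for this sketch
input_ids = torch.tensor([11, image_pad_id, image_pad_id, 12])
inputs_embeds = torch.zeros(4, 8)      # text embeddings, hidden_size = 8
multimodal_embeds = torch.ones(2, 8)   # one row per placeholder token

# Positions whose token id equals the placeholder id are overwritten
# with the vision embeddings; all other positions keep the text embeddings.
inputs_embeds[input_ids == image_pad_id] = multimodal_embeds
print(inputs_embeds.sum(dim=-1))  # tensor([0., 8., 8., 0.])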
vllm/model_executor/models/utils.py

Lines changed: 13 additions & 6 deletions
@@ -343,7 +343,6 @@ def _flatten_embeddings(embeddings: NestedTensors) -> torch.Tensor:
     if isinstance(embeddings, torch.Tensor):
         # Flatten all but the last dimension.
         return embeddings.flatten(0, -2)
-
     return torch.cat(tuple(_flatten_embeddings(t) for t in embeddings))

@@ -391,8 +390,19 @@ def _merge_multimodal_embeddings(
     """
     # skip check for HPU, the number of tokens is a cpu fallback during HPU lazy
     if current_platform.is_hpu():
-        flattened = _flatten_embeddings(multimodal_embeddings)
-        inputs_embeds[is_multimodal] = flattened
+
+        if isinstance(multimodal_embeddings, torch.Tensor):
+            is_multimodal = is_multimodal.reshape(-1)
+            batch_size, seq_length, hidden_size = inputs_embeds.shape
+            inputs_embeds = inputs_embeds.reshape(-1, hidden_size)
+            flattened = multimodal_embeddings.reshape(-1, hidden_size)
+            inputs_embeds[is_multimodal] = flattened
+            inputs_embeds = inputs_embeds.reshape(batch_size, seq_length,
+                                                  hidden_size)
+        else:
+            flattened = _flatten_embeddings(multimodal_embeddings)
+            inputs_embeds[is_multimodal] = flattened
+
         return inputs_embeds

     num_expected_tokens = is_multimodal.sum().item()

@@ -492,7 +502,6 @@ def merge_multimodal_embeddings(
             torch.isin(input_ids, placeholder_token_id),
             multimodal_embeddings,
         )
-
     return _merge_multimodal_embeddings(
         inputs_embeds,
         (input_ids == placeholder_token_id),

@@ -712,7 +721,6 @@ def extract_layer_index(layer_name: str) -> int:
                                 " only contain one integer")
     return int_vals[0]

-
 def get_input_mask(hidden_states: torch.Tensor,
                    valid_len: torch.Tensor) -> torch.Tensor:
     """

@@ -727,7 +735,6 @@ def get_input_mask(hidden_states: torch.Tensor,
     mask = mask.to(hidden_states.dtype)
     return mask

-
 def cast_overflow_tensors(
     tensors: torch.Tensor,
     offset: float = 1000,
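The new HPU branch handles the case where the multimodal embeddings arrive as one batched tensor: it flattens the batch and sequence dimensions, scatters in 2D, and reshapes back, instead of boolean-mask indexing the 3D tensor directly. A CPU sketch with hypothetical shapes showing the two routes agree:

import torch

batch_size, seq_length, hidden_size = 2, 4, 8
inputs_embeds = torch.zeros(batch_size, seq_length, hidden_size)
is_multimodal = torch.zeros(batch_size, seq_length, dtype=torch.bool)
is_multimodal[0, 1:3] = True
multimodal_embeddings = torch.ones(int(is_multimodal.sum()), hidden_size)

# Reference: boolean-mask assignment directly on the 3D tensor (the old path).
reference = inputs_embeds.clone()
reference[is_multimodal] = multimodal_embeddings

# HPU-style path: flatten batch and sequence dims, scatter in 2D, reshape back.
flat_mask = is_multimodal.reshape(-1)
flat = inputs_embeds.reshape(-1, hidden_size)
flat[flat_mask] = multimodal_embeddings.reshape(-1, hidden_size)
merged = flat.reshape(batch_size, seq_length, hidden_size)

assert torch.equal(merged, reference)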

vllm/worker/hpu_model_runner.py

Lines changed: 2 additions & 0 deletions
@@ -396,6 +396,8 @@ def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device,
             len_mask_v = len_mask.view(batch_size, 1, seq_len, 1)
             mask = attn_mask.logical_or(len_mask).logical_or(len_mask_v)
             off_value = -3E38  # small number, avoid nan and overflow
+            if dtype == torch.float16:
+                off_value = -63000  # a small value close to float16.min
         else:
             mask = attn_mask.logical_or(
                 len_mask)  # no need for len_mask_v as decode overwrites it
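A quick check of the dtype reasoning behind the float16 special case: -3E38 lies far outside float16's range, so it saturates to -inf and can turn the bias arithmetic into NaN, while -63000 stays finite. This snippet is only an illustration, not part of the commit:

import torch

print(torch.finfo(torch.float16).min)              # -65504.0
print(torch.tensor(-3e38, dtype=torch.float16))    # tensor(-inf, dtype=torch.float16)
print(torch.tensor(0.0, dtype=torch.float16) *
      torch.tensor(-3e38, dtype=torch.float16))    # tensor(nan, dtype=torch.float16)
print(torch.isfinite(torch.tensor(-63000.0,
                                  dtype=torch.float16)))  # tensor(True)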
