
Commit 4978811

fix bug in merge_multimodal_embeddings on HPU (HabanaAI#1436)
This PR aims to fix bugs in HabanaAI#1433, which does not work in the following cases:
1) placeholder_token_id is a list
2) multimodal_embeddings is not a tensor

Tested models:
PT_HPU_LAZY_MODE=1 VLLM_SKIP_WARMUP=true python examples/offline_inference/vision_language.py -m glm4v
PT_HPU_LAZY_MODE=1 VLLM_SKIP_WARMUP=true python examples/offline_inference/vision_language.py -m qwen_vl
PT_HPU_LAZY_MODE=1 VLLM_SKIP_WARMUP=true python examples/offline_inference/vision_language.py -m qwen2_vl
PT_HPU_LAZY_MODE=1 VLLM_SKIP_WARMUP=true python examples/offline_inference/vision_language.py -m qwen2_5_vl
PT_HPU_LAZY_MODE=1 VLLM_SKIP_WARMUP=true python examples/offline_inference/vision_language.py -m qwen2_5_omni
1 parent b8ebf29 commit 4978811
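For context, a minimal standalone sketch of the merge this commit fixes (toy shapes and a made-up placeholder id, not code from the commit): placeholder positions in input_ids are located with a boolean mask and overwritten with the image embeddings.

import torch

IMAGE_PAD_ID = 9999                                   # hypothetical placeholder id
batch_size, seq_len, hidden = 2, 6, 4
input_ids = torch.randint(0, 100, (batch_size, seq_len))
input_ids[0, 2:4] = IMAGE_PAD_ID                      # pretend image placeholder tokens
input_ids[1, 1:3] = IMAGE_PAD_ID
inputs_embeds = torch.zeros(batch_size, seq_len, hidden)

# Image-encoder output for the four placeholder slots, given as one tensor
# (the "multimodal_embeddings is not a tensor" case is handled further below).
num_mm = int((input_ids == IMAGE_PAD_ID).sum())
multimodal_embeddings = torch.randn(num_mm, hidden)

# Flatten batch and sequence dims, assign through the boolean mask, restore shape.
is_multimodal = (input_ids == IMAGE_PAD_ID).reshape(-1)
flat = inputs_embeds.reshape(-1, hidden)
flat[is_multimodal] = multimodal_embeddings
inputs_embeds = flat.reshape(batch_size, seq_len, hidden)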

2 files changed: +14 -39 lines
  vllm/model_executor/models/qwen_vl.py
  vllm/model_executor/models/utils.py


vllm/model_executor/models/qwen_vl.py

Lines changed: 1 addition & 1 deletion
@@ -774,7 +774,7 @@ def get_input_embeddings(
         inputs_embeds = self.transformer.get_input_embeddings(input_ids)
 
         if multimodal_embeddings is not None:
-            inputs_embeds = self._merge_multimodal_embeddings(
+            inputs_embeds = merge_multimodal_embeddings(
                 input_ids, inputs_embeds, multimodal_embeddings,
                 self.transformer.visual.image_pad_id)
 
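With this one-line change, the model calls the shared merge_multimodal_embeddings utility instead of a model-local helper, so the HPU handling patched in utils.py below also covers this model. A rough sketch of the resulting call pattern (the wrapper and its names are illustrative, not code from the commit):

# Illustrative only: mirrors how get_input_embeddings is expected to use the
# shared helper after this change.
from vllm.model_executor.models.utils import merge_multimodal_embeddings

def get_input_embeddings_sketch(model, input_ids, multimodal_embeddings=None):
    inputs_embeds = model.transformer.get_input_embeddings(input_ids)
    if multimodal_embeddings is not None:
        inputs_embeds = merge_multimodal_embeddings(
            input_ids, inputs_embeds, multimodal_embeddings,
            model.transformer.visual.image_pad_id)
    return inputs_embeds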

vllm/model_executor/models/utils.py

Lines changed: 13 additions & 38 deletions
@@ -343,7 +343,6 @@ def _flatten_embeddings(embeddings: NestedTensors) -> torch.Tensor:
     if isinstance(embeddings, torch.Tensor):
         # Flatten all but the last dimension.
         return embeddings.flatten(0, -2)
-
     return torch.cat(tuple(_flatten_embeddings(t) for t in embeddings))
 
 
@@ -391,8 +390,19 @@ def _merge_multimodal_embeddings(
     """
     # skip check for HPU, the number of tokens is a cpu fallback during HPU lazy
     if current_platform.is_hpu():
-        flattened = _flatten_embeddings(multimodal_embeddings)
-        inputs_embeds[is_multimodal] = flattened
+
+        if isinstance(multimodal_embeddings, torch.Tensor):
+            is_multimodal = is_multimodal.reshape(-1)
+            batch_size, seq_length, hidden_size = inputs_embeds.shape
+            inputs_embeds = inputs_embeds.reshape(-1, hidden_size)
+            flattened = multimodal_embeddings.reshape(-1, hidden_size)
+            inputs_embeds[is_multimodal] = flattened
+            inputs_embeds = inputs_embeds.reshape(batch_size, seq_length,
+                                                  hidden_size)
+        else:
+            flattened = _flatten_embeddings(multimodal_embeddings)
+            inputs_embeds[is_multimodal] = flattened
+
         return inputs_embeds
 
     num_expected_tokens = is_multimodal.sum().item()
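To see what the two new branches handle, here is a small self-contained toy run of both cases (toy shapes; the torch.cat at the end stands in for what _flatten_embeddings roughly does for nested inputs and is not the vLLM implementation):

import torch

hidden = 4
inputs_embeds = torch.zeros(2, 5, hidden)             # (batch, seq_len, hidden)
is_multimodal = torch.zeros(2, 5, dtype=torch.bool)
is_multimodal[0, 1:3] = True                           # two placeholder positions
is_multimodal[1, 2:4] = True                           # two more placeholder positions

# Tensor case: multimodal_embeddings arrives as one tensor, e.g.
# (num_images, tokens_per_image, hidden); flatten everything and mask-assign.
mm_tensor = torch.randn(2, 2, hidden)
flat_mask = is_multimodal.reshape(-1)
flat_embeds = inputs_embeds.reshape(-1, hidden)
flat_embeds[flat_mask] = mm_tensor.reshape(-1, hidden)
merged = flat_embeds.reshape(2, 5, hidden)

# Nested case: a list of per-image tensors has no .reshape, so the embeddings
# are concatenated first and then assigned through the mask.
mm_list = [torch.randn(2, hidden), torch.randn(2, hidden)]
inputs_embeds2 = torch.zeros(2, 5, hidden)
inputs_embeds2[is_multimodal] = torch.cat(mm_list)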
@@ -476,14 +486,6 @@ def merge_multimodal_embeddings(
     Note:
         This updates ``inputs_embeds`` in place.
     """
-    if current_platform.is_hpu():
-        return _hpu_merge_multimodal_embeddings(
-            input_ids,
-            inputs_embeds,
-            multimodal_embeddings,
-            placeholder_token_id,
-        )
-
     if isinstance(placeholder_token_id, list):
         placeholder_token_id = torch.tensor(placeholder_token_id,
                                             device=input_ids.device)
@@ -492,7 +494,6 @@ def merge_multimodal_embeddings(
             torch.isin(input_ids, placeholder_token_id),
             multimodal_embeddings,
         )
-
     return _merge_multimodal_embeddings(
         inputs_embeds,
         (input_ids == placeholder_token_id),
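For the list case from the commit message (case 1), the branch kept above converts the list of placeholder ids to a tensor and matches them with torch.isin, while a single id only needs an equality comparison. A toy illustration (values are made up):

import torch

input_ids = torch.tensor([[5, 7, 7, 9, 5],
                          [9, 5, 7, 7, 5]])

# Single placeholder id: an equality comparison is enough.
mask_single = (input_ids == 7)

# List of placeholder ids: build a tensor and use torch.isin to match any of them.
placeholder_ids = torch.tensor([7, 9], device=input_ids.device)
mask_multi = torch.isin(input_ids, placeholder_ids)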
@@ -712,7 +713,6 @@ def extract_layer_index(layer_name: str) -> int:
                         " only contain one integer")
     return int_vals[0]
 
-
 def get_input_mask(hidden_states: torch.Tensor,
                    valid_len: torch.Tensor) -> torch.Tensor:
     """
@@ -727,7 +727,6 @@ def get_input_mask(hidden_states: torch.Tensor,
     mask = mask.to(hidden_states.dtype)
     return mask
 
-
 def cast_overflow_tensors(
     tensors: torch.Tensor,
     offset: float = 1000,
@@ -745,27 +744,3 @@ def fast_topk(values, topk, dim):
     else:
         # Use topk for efficiency with larger k values
         return torch.topk(values, topk, dim=dim)
-
-
-def _hpu_merge_multimodal_embeddings(
-    input_ids: torch.Tensor,
-    inputs_embeds: torch.Tensor,
-    multimodal_embeddings: NestedTensors,
-    placeholder_token_id: torch.tensor,
-) -> torch.Tensor:
-    """
-    Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
-    positions in ``inputs_embeds`` corresponding to placeholder tokens in
-    ``input_ids``.
-    merge_multimodal_embeddings on HPU to avoid dynamicity.
-    Note:
-        This updates ``inputs_embeds`` in place.
-    """
-    batch_size, seq_length, hidden_size = inputs_embeds.shape
-    inputs_embeds = inputs_embeds.reshape(-1, hidden_size)
-    multimodal_embeddings = multimodal_embeddings.reshape(-1, hidden_size)
-    placeholder_token_id = torch.tensor(placeholder_token_id,
-                                        device=input_ids.device)
-    mask = torch.isin(input_ids.reshape(-1), placeholder_token_id)
-    inputs_embeds.index_put_((mask, ), multimodal_embeddings)
-    inputs_embeds = inputs_embeds.reshape(batch_size, seq_length, hidden_size)
-    return inputs_embeds
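The helper removed above assumed multimodal_embeddings is a single tensor: calling .reshape on a list of per-image tensors (case 2 in the commit message) raises AttributeError, which is one reason the HPU handling was folded into _merge_multimodal_embeddings instead. A minimal reproduction of that assumption (toy data, not vLLM code):

import torch

hidden = 4
mm_as_tensor = torch.randn(3, hidden)
mm_as_list = [torch.randn(2, hidden), torch.randn(1, hidden)]

mm_as_tensor.reshape(-1, hidden)       # fine: a tensor has .reshape
try:
    mm_as_list.reshape(-1, hidden)     # AttributeError: a Python list has no .reshape
except AttributeError:
    flattened = torch.cat(mm_as_list)  # the nested case needs concatenation instead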
