
Commit 5f02426

Optimized merge_multimodal_embeddings on Gaudi
Signed-off-by: gyou2021 <ganmei.you@intel.com>
1 parent 522b46e commit 5f02426

File tree

1 file changed (+32, −0)


vllm/model_executor/models/utils.py

Lines changed: 32 additions & 0 deletions
@@ -476,6 +476,14 @@ def merge_multimodal_embeddings(
     Note:
         This updates ``inputs_embeds`` in place.
     """
+    if current_platform.is_hpu():
+        return _hpu_merge_multimodal_embeddings(
+            input_ids,
+            inputs_embeds,
+            multimodal_embeddings,
+            placeholder_token_id,
+        )
+
     if isinstance(placeholder_token_id, list):
         placeholder_token_id = torch.tensor(placeholder_token_id,
                                             device=input_ids.device)
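
The early-return above routes HPU to a fixed-shape merge. Judging by the commit message ("avoid dynamicity"), the motivation is that the generic path below derives index tensors whose sizes depend on how many placeholder tokens the input contains, which Gaudi's static-graph compilation handles poorly. A loose standalone sketch of the two flavours (toy tensors, hypothetical placeholder id 9; not the commit's code):

    import torch

    seq = torch.tensor([1, 9, 9, 2])   # 9 = hypothetical placeholder id
    vision = torch.ones(2, 4)          # one embedding row per placeholder
    embeds = torch.zeros(4, 4)

    # Dynamic flavour: nonzero() yields an index tensor whose length varies
    # with the input, so each distinct placeholder count is a new graph shape.
    idx = (seq == 9).nonzero(as_tuple=True)[0]
    embeds[idx] = vision

    # Static flavour (what the HPU helper does): the boolean mask always has
    # shape (seq_len,), so the traced shapes never change.
    mask = torch.isin(seq, torch.tensor([9]))
    embeds.index_put_((mask,), vision)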
@@ -737,3 +745,27 @@ def fast_topk(values, topk, dim):
     else:
         # Use topk for efficiency with larger k values
         return torch.topk(values, topk, dim=dim)
+
+def _hpu_merge_multimodal_embeddings(
+    input_ids: torch.Tensor,
+    inputs_embeds: torch.Tensor,
+    multimodal_embeddings: NestedTensors,
+    placeholder_token_id: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
+    positions in ``inputs_embeds`` corresponding to placeholder tokens in
+    ``input_ids``.
+    HPU variant of ``merge_multimodal_embeddings``; avoids dynamic shapes.
+    Note:
+        This updates ``inputs_embeds`` in place.
+    """
+    batch_size, seq_length, hidden_size = inputs_embeds.shape
+    inputs_embeds = inputs_embeds.reshape(-1, hidden_size)
+    multimodal_embeddings = multimodal_embeddings.reshape(-1, hidden_size)
+    placeholder_token_id = torch.tensor(placeholder_token_id,
+                                        device=input_ids.device)
+    mask = torch.isin(input_ids.reshape(-1), placeholder_token_id)
+    inputs_embeds.index_put_((mask,), multimodal_embeddings)
+    inputs_embeds = inputs_embeds.reshape(batch_size, seq_length, hidden_size)
+    return inputs_embeds
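
For illustration, a minimal end-to-end sketch of the helper's behaviour (toy shapes and a hypothetical placeholder id; it assumes ``multimodal_embeddings`` arrives as a single tensor whose row count equals the number of placeholder tokens, which is what the ``reshape`` above relies on):

    import torch

    IMG = 32000  # hypothetical placeholder token id
    input_ids = torch.tensor([[1, IMG, IMG, 2],
                              [3, IMG, 4, 5]])
    inputs_embeds = torch.zeros(2, 4, 3)
    # Three placeholder tokens across the batch -> three embedding rows.
    mm_embeds = torch.arange(9, dtype=torch.float).reshape(3, 3)

    batch, seq_len, hidden = inputs_embeds.shape
    flat = inputs_embeds.reshape(-1, hidden)  # view: writes land in inputs_embeds
    mask = torch.isin(input_ids.reshape(-1), torch.tensor([IMG]))
    flat.index_put_((mask,), mm_embeds)       # one fixed-shape masked scatter
    merged = flat.reshape(batch, seq_len, hidden)
    print(merged[0, 1])  # tensor([0., 1., 2.]) - the first multimodal row

Note that ``index_put_`` with a boolean mask fills the selected rows in sequence order, so the multimodal rows must already be ordered to match the placeholders' positions in the flattened batch.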
