@@ -476,6 +476,14 @@ def merge_multimodal_embeddings(
     Note:
         This updates ``inputs_embeds`` in place.
     """
+    if current_platform.is_hpu():
+        return _hpu_merge_multimodal_embeddings(
+            input_ids,
+            inputs_embeds,
+            multimodal_embeddings,
+            placeholder_token_id,
+        )
+
     if isinstance(placeholder_token_id, list):
         placeholder_token_id = torch.tensor(placeholder_token_id,
                                             device=input_ids.device)
@@ -737,3 +745,27 @@ def fast_topk(values, topk, dim):
     else:
         # Use topk for efficiency with larger k values
         return torch.topk(values, topk, dim=dim)
+
+def _hpu_merge_multimodal_embeddings(
+    input_ids: torch.Tensor,
+    inputs_embeds: torch.Tensor,
+    multimodal_embeddings: NestedTensors,
+    placeholder_token_id: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
+    positions in ``inputs_embeds`` corresponding to placeholder tokens in
+    ``input_ids``.
+    HPU variant of ``merge_multimodal_embeddings``, avoiding dynamic shapes.
+    Note:
+        This updates ``inputs_embeds`` in place.
+    """
+    batch_size, seq_length, hidden_size = inputs_embeds.shape
+    inputs_embeds = inputs_embeds.reshape(-1, hidden_size)
+    multimodal_embeddings = multimodal_embeddings.reshape(-1, hidden_size)
+    placeholder_token_id = torch.tensor(placeholder_token_id,
+                                        device=input_ids.device)
+    mask = torch.isin(input_ids.reshape(-1), placeholder_token_id)
+    inputs_embeds.index_put_((mask, ), multimodal_embeddings)
+    inputs_embeds = inputs_embeds.reshape(batch_size, seq_length, hidden_size)
+    return inputs_embeds
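
For illustration, here is a minimal standalone sketch of the mask-and-`index_put_` pattern used by `_hpu_merge_multimodal_embeddings`. The token ids, shapes, and values are made up for the example; the point is that the boolean mask always has shape `batch_size * seq_length`, regardless of how many placeholder tokens appear, which is what keeps the merge free of dynamic shapes on HPU.

```python
import torch

# Toy setup: batch of 1 sequence with 6 tokens, hidden size 4.
# Token id 32000 stands in for the multimodal placeholder (illustrative value).
placeholder_id = 32000
input_ids = torch.tensor([[1, 32000, 32000, 32000, 2, 3]])
inputs_embeds = torch.zeros(1, 6, 4)
# Three "vision" embeddings, one per placeholder position.
multimodal_embeds = torch.arange(12, dtype=torch.float32).reshape(3, 4)

batch_size, seq_length, hidden_size = inputs_embeds.shape
flat_embeds = inputs_embeds.reshape(-1, hidden_size)
placeholder = torch.tensor([placeholder_id], device=input_ids.device)

# Boolean mask over the flattened sequence; its shape is always
# (batch_size * seq_length,), independent of the placeholder count.
mask = torch.isin(input_ids.reshape(-1), placeholder)

# Overwrite the placeholder rows in place (equivalent to
# flat_embeds[mask] = multimodal_embeds).
flat_embeds.index_put_((mask, ), multimodal_embeds)

merged = flat_embeds.reshape(batch_size, seq_length, hidden_size)
print(merged[0, 1:4])  # positions 1-3 now hold the multimodal embeddings
```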