diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py
index 65c177f8c5a..89f256fd6ac 100644
--- a/vllm/model_executor/models/gemma3_mm.py
+++ b/vllm/model_executor/models/gemma3_mm.py
@@ -30,6 +30,7 @@
 # yapf: enable
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
+from vllm.platforms import current_platform
 
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                          SupportsMultiModal, SupportsPP)
@@ -639,6 +640,11 @@ def prepare_attn_masks(
         **kwargs,
     ):
         kwargs["has_images"] = True
+
+        if current_platform.is_hpu():
+            input_ids = input_ids.flatten()
+            positions = positions.flatten()
+
         # NOTE(woosuk): Here, we distinguish the sequences by the position id 0.
         # This is a HACK. Fix this.
         start_idices = (positions == 0).cpu().nonzero()
diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py
index 4c36e947a88..3a8acfc7278 100755
--- a/vllm/worker/hpu_model_runner.py
+++ b/vllm/worker/hpu_model_runner.py
@@ -1269,6 +1269,17 @@ def add_vision_buckets_to_mrope_models(self):
         model = self.get_model()
         model.vision_buckets = VisionBuckets()
 
+    def _get_position_pad(self) -> int:
+        """
+        For gemma3 models, the hack in
+        Gemma3ForConditionalGeneration::prepare_attn_masks means '0' cannot
+        be used as the pad value for the input position tensor: any '0's
+        added for bucketing would be counted as the start of a new sequence
+        in prepare_attn_masks(), which is wrong.
+        """
+        model_type = getattr(self.model_config.hf_config, 'model_type', '')
+        return -1 if model_type == 'gemma3' else 0
+
     def _prepare_prompt(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
@@ -1506,11 +1517,11 @@ def _prepare_prompt(
                 make_mrope_positions_tensor_with_pad(input_positions=input_positions,
                                              input_mrope_positions=input_mrope_positions,
                                              max_prompt_len=max_prompt_len,
-                                             pad=0)
+                                             pad=self._get_position_pad())
             else:
                 input_positions = make_cpu_tensor(input_positions,
                                                   max_len=max_prompt_len,
-                                                  pad=0,
+                                                  pad=self._get_position_pad(),
                                                   dtype=torch.long,
                                                   flat=self.use_merged_prefill)
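A minimal sketch (not part of the diff) of why a pad value of 0 is unsafe here: `prepare_attn_masks()` detects sequence starts with the `positions == 0` check shown in the hunk above, so zero-padding added for HPU bucketing would be miscounted as extra sequences, while `-1` padding never matches. The helper `count_sequence_starts` and the example tensors are hypothetical and only illustrate the check.

```python
import torch

def count_sequence_starts(positions: torch.Tensor) -> int:
    # Same detection idea as prepare_attn_masks(): a position id of 0
    # marks the beginning of a sequence.
    return int((positions == 0).sum())

# Two prompts of lengths 3 and 2, bucketed (padded) to length 4 each.
padded_with_zero = torch.tensor([[0, 1, 2, 0], [0, 1, 0, 0]]).flatten()
padded_with_minus_one = torch.tensor([[0, 1, 2, -1], [0, 1, -1, -1]]).flatten()

print(count_sequence_starts(padded_with_zero))       # 5 -> padding miscounted as starts
print(count_sequence_starts(padded_with_minus_one))  # 2 -> correct
```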