diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py
index 9a239c94bdf..60b203df87a 100644
--- a/vllm/attention/backends/hpu_attn.py
+++ b/vllm/attention/backends/hpu_attn.py
@@ -135,6 +135,12 @@ class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata):
     cross_block_groups: Optional[torch.Tensor] = None
     cross_block_usage: Optional[torch.Tensor] = None
     cross_attn_bias: Optional[torch.Tensor] = None
+    window_block_list: Optional[torch.Tensor] = None
+    window_slot_mapping: Optional[torch.Tensor] = None
+    window_block_mapping: Optional[torch.Tensor] = None
+    window_block_groups: Optional[torch.Tensor] = None
+    window_block_usage: Optional[torch.Tensor] = None
+    window_attn_bias: Optional[torch.Tensor] = None
 
 
 @dataclass
@@ -542,6 +548,18 @@ def forward(
             block_list = attn_metadata.block_list if attn_metadata \
                     and attn_metadata.block_list is not None else None
 
+            common_args = self.common_attention_args(block_list, key_cache,
+                                                     value_cache,
+                                                     attn_metadata.block_size)
+
+            # TODO: Ideally we want to create this sliding_window_bias mask only
+            # once in the model_runner or the gemma model file and retrieve it here.
+            if self.sliding_window:
+                attn_bias = _make_sliding_window_bias(
+                    batch_size, seq_len, attn_metadata.seq_lens_tensor,
+                    self.sliding_window, query.dtype)
+                common_args['pad'] = 'left'
+
             out = ops.prompt_attention(
                 impl=self.prefill_impl,
                 query=query.view(query_shape),
@@ -551,12 +569,16 @@ def forward(
                 attn_bias=attn_bias,
                 position_bias=position_bias,
                 valid_seq_lengths=attn_metadata.seq_lens_tensor,
-                **self.common_attention_args(block_list, key_cache,
-                                             value_cache,
-                                             attn_metadata.block_size))
+                **common_args)
+
             output = out.reshape(batch_size, seq_len, hidden_size)
         else:
             # Decoding run.
+            block_list = attn_metadata.block_list if not self.sliding_window else attn_metadata.window_block_list
+            block_groups = attn_metadata.block_groups if not self.sliding_window else attn_metadata.window_block_groups
+            block_mapping = attn_metadata.block_mapping if not self.sliding_window else attn_metadata.window_block_mapping
+            attn_bias = attn_metadata.attn_bias if not self.sliding_window else attn_metadata.window_attn_bias
+
             self.position_bias = None
             alibi_blocks = getattr(attn_metadata, 'alibi_blocks', None)
             if self.alibi_slopes is not None and alibi_blocks is not None:
@@ -572,12 +594,12 @@ def forward(
 
             output = HPUPagedAttention.forward_decode(
                 query=query,
-                block_mapping=attn_metadata.block_mapping,
-                block_bias=attn_metadata.attn_bias,
-                block_groups=attn_metadata.block_groups,
+                block_mapping=block_mapping,
+                block_bias=attn_bias,
+                block_groups=block_groups,
                 position_bias=self.position_bias,
-                **self.common_attention_args(attn_metadata.block_list,
-                                             key_cache, value_cache,
+                **self.common_attention_args(block_list, key_cache,
+                                             value_cache,
                                              attn_metadata.block_size))
         # Reshape the output tensor.
         return output.view(batch_size, seq_len, hidden_size)
@@ -768,3 +790,41 @@ def _make_decode_alibi_bias(
     per_head_bias.mul_(alibi_slopes[None, :, None])
 
     return per_head_bias
+
+
+def _make_sliding_window_bias(
+    batch_size: int,
+    seq_len: int,
+    query_lens_t: torch.Tensor,
+    window_size: int,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+
+    shift = 0
+    device = query_lens_t.device
+
+    # TODO: this is not performant as of now. Needs further investigation
+    # once a FusedSDPA kernel with sliding causal mask support is available.
+
+    # causal + sliding window (LEFT PADDING)
+    tensor = torch.full((batch_size, 1, seq_len, seq_len),
+                        device=device,
+                        dtype=dtype,
+                        fill_value=1)
+    mask = torch.tril(tensor, diagonal=shift)
+    mask = torch.triu(mask, diagonal=shift - window_size + 1)
+    attn_bias = torch.log(mask)
+    '''
+    # TODO: Accuracy issue needs to be debugged.
+    # causal + sliding window + query_len (LEFT PADDING: needs kernel support)
+    tensor = torch.full((batch_size, 1, seq_len, seq_len),
+                        device=device, dtype=dtype, fill_value=1)
+    mask = torch.tril(tensor, diagonal=shift)
+    len_mask = torch.arange(0, seq_len, device=device,
+                            dtype=torch.int32).view(seq_len, 1)
+    len_mask = len_mask.ge(query_lens_t.unsqueeze(-1)).view(
+        batch_size, 1, seq_len, 1)
+    len_mask = torch.where(len_mask == False, 1, 0)
+    mask = mask.logical_and(len_mask)
+    mask = torch.triu(mask, diagonal=shift - window_size + 1)
+    attn_bias = torch.where(mask, 0, -math.inf)
+    '''
+
+    return attn_bias
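The bias built by `_make_sliding_window_bias` above is a banded causal mask expressed in log space: positions a query may attend to get 0, everything else gets -inf, and the band width equals the sliding window. A minimal standalone sketch of the same math with toy sizes (CPU tensors, values chosen only for illustration):

```python
import torch

# Toy reproduction of the mask math above: batch=1, seq_len=6, window=3.
# torch.log(1) == 0 keeps a position, torch.log(0) == -inf masks it.
batch_size, seq_len, window_size, shift = 1, 6, 3, 0

ones = torch.full((batch_size, 1, seq_len, seq_len), 1.0)
mask = torch.tril(ones, diagonal=shift)                    # causal part
mask = torch.triu(mask, diagonal=shift - window_size + 1)  # limit lookback
attn_bias = torch.log(mask)

print(attn_bias[0, 0])
# Row i allows keys j with i - window_size < j <= i, e.g. row 4 -> keys 2..4.
```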
diff --git a/vllm/model_executor/models/gemma3.py b/vllm/model_executor/models/gemma3.py
index 4e0d4f84ca6..255d200556b 100644
--- a/vllm/model_executor/models/gemma3.py
+++ b/vllm/model_executor/models/gemma3.py
@@ -233,44 +233,34 @@ def naive_attn_with_masks(
         out: torch.Tensor,
         **kwargs,
     ) -> torch.Tensor:
-        # NOTE(woosuk): As described in the comment above, this code is not
-        # meant to be performant. It is only meant to be correct.
-        q = q.view(-1, self.num_heads, self.head_dim)
-        # Expand the key and value to handle GQA.
+
+        s = q.shape[1]
         num_queries_per_kv = self.num_heads // self.num_kv_heads
-        k = k.view(-1, self.num_kv_heads, self.head_dim)
-        k = k.repeat_interleave(num_queries_per_kv, dim=-2)
-        v = v.view(-1, self.num_kv_heads, self.head_dim)
-        v = v.repeat_interleave(num_queries_per_kv, dim=-2)
+        query = q.view(-1, s, self.num_heads, self.head_dim)
+        key = k.view(-1, s, self.num_kv_heads, self.head_dim)
+        key = key.repeat_interleave(num_queries_per_kv, dim=-2)
+        value = v.view(-1, s, self.num_kv_heads, self.head_dim)
+        value = value.repeat_interleave(num_queries_per_kv, dim=-2)
 
         if self.is_sliding:
             attn_masks = kwargs["local_attn_masks"]
         else:
             attn_masks = kwargs["global_attn_masks"]
 
-        seq_lens = kwargs["seq_lens"]
-        start_idx = 0
-        for seq_len, attn_mask in zip(seq_lens, attn_masks):
-            end_idx = start_idx + seq_len
-            query = q[start_idx:end_idx].unsqueeze(0)
-            key = k[start_idx:end_idx].unsqueeze(0)
-            value = v[start_idx:end_idx].unsqueeze(0)
-
-            # Transpose.
-            query = query.transpose(1, 2)
-            key = key.transpose(1, 2)
-            value = value.transpose(1, 2)
-
-            output = F.scaled_dot_product_attention(
-                query,
-                key,
-                value,
-                attn_mask,
-                self.scaling,
-            )
-            output = output.transpose(1, 2).flatten(-2, -1)
-            out[start_idx:end_idx] = output
-            start_idx = end_idx
+        query = query.transpose(1, 2)
+        key = key.transpose(1, 2)
+        value = value.transpose(1, 2)
+
+        output = F.scaled_dot_product_attention(
+            query,
+            key,
+            value,
+            attn_masks,
+            self.scaling,
+        )
+
+        out = output.transpose(1, 2).flatten(-2, -1)
+        return out
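The rewritten `naive_attn_with_masks` replaces the per-sequence Python loop with a single batched SDPA call: tensors stay `[bs, seq, heads, head_dim]`, keys/values are expanded for GQA with `repeat_interleave`, and one `[bs, 1, seq, seq]` additive mask is broadcast over all heads. A shape walk-through with toy dimensions (random data, sizes are illustrative only; the scale is passed here via the `scale` keyword):

```python
import torch
import torch.nn.functional as F

bs, s, num_heads, num_kv_heads, head_dim = 2, 5, 8, 4, 16
scaling = head_dim**-0.5

q = torch.randn(bs, s, num_heads * head_dim)
k = torch.randn(bs, s, num_kv_heads * head_dim)
v = torch.randn(bs, s, num_kv_heads * head_dim)

# Expand K/V heads for GQA, mirroring the patched code path.
num_queries_per_kv = num_heads // num_kv_heads
query = q.view(-1, s, num_heads, head_dim)
key = k.view(-1, s, num_kv_heads, head_dim).repeat_interleave(
    num_queries_per_kv, dim=-2)
value = v.view(-1, s, num_kv_heads, head_dim).repeat_interleave(
    num_queries_per_kv, dim=-2)

# One additive mask per batch element, broadcast over heads: [bs, 1, s, s].
causal = torch.triu(torch.ones(s, s), diagonal=1).bool()
attn_masks = torch.zeros(bs, 1, s, s).masked_fill(causal, float("-inf"))

out = F.scaled_dot_product_attention(query.transpose(1, 2),
                                     key.transpose(1, 2),
                                     value.transpose(1, 2),
                                     attn_masks,
                                     scale=scaling)
out = out.transpose(1, 2).flatten(-2, -1)
print(out.shape)  # torch.Size([2, 5, 128])
```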
diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py
index 65c177f8c5a..fc502eea528 100644
--- a/vllm/model_executor/models/gemma3_mm.py
+++ b/vllm/model_executor/models/gemma3_mm.py
@@ -566,6 +566,7 @@ def _process_image_input(
             self.vision_tower,
             pixel_values,
         )
+
         image_embeds = self.multi_modal_projector(image_features)
 
         return [
@@ -610,6 +611,7 @@ def forward(self,
         # NOTE: In v1, inputs_embeds is always generated at model runner, this
         # condition is for v0 compatibility.
         elif inputs_embeds is None:
+            assert False, "hpu_model_runner should be computing inputs_embeds"
             vision_embeddings = self.get_multimodal_embeddings(**kwargs)
             inputs_embeds = self.get_input_embeddings(input_ids,
@@ -639,58 +641,53 @@ def prepare_attn_masks(
         **kwargs,
     ):
         kwargs["has_images"] = True
-        # NOTE(woosuk): Here, we distinguish the sequences by the position id 0.
-        # This is a HACK. Fix this.
-        start_idices = (positions == 0).cpu().nonzero()
-        num_seqs = len(start_idices)
-        seq_lens = []
-        for i in range(num_seqs):
-            start_idx = start_idices[i].item()
-            if i < num_seqs - 1:
-                end_idx = start_idices[i + 1].item()
-            else:
-                end_idx = len(input_ids)
-            seq_lens.append(end_idx - start_idx)
-        kwargs["seq_lens"] = seq_lens
-
-        global_attn_masks = []
-        local_attn_masks = []
-        start_idx = 0
-        for seq_len in seq_lens:
-            end_idx = start_idx + seq_len
-            input_token_ids = input_ids[start_idx:end_idx]
-            start_idx = end_idx
-            # Create a global causal mask.
-            global_attn_mask = torch.empty(
-                1,
-                1,
-                seq_len,
-                seq_len,
-                dtype=mask_dtype,
-                device=input_ids.device,
-            )
-            global_attn_mask.fill_(float("-inf"))
-            # Fill the lower triangle with 0.
-            global_attn_mask = global_attn_mask.triu(diagonal=1)
-
-            # Consider the bidirectional attention between image tokens.
-            img_mask = torch.zeros_like(global_attn_mask)
-            img_pos = (input_token_ids == self.config.image_token_index)
-            img_mask[:, :, :, img_pos] += 1
-            img_mask[:, :, img_pos, :] += 1
-            global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
-            global_attn_masks.append(global_attn_mask)
-
-            if self.sliding_window is not None:
-                # Create a local causal mask with sliding window (1024).
-                local_attn_mask = torch.ones_like(global_attn_mask)
-                local_attn_mask = torch.tril(local_attn_mask,
-                                             diagonal=-self.sliding_window)
-                local_attn_mask = torch.where(local_attn_mask == 0,
-                                              global_attn_mask, float("-inf"))
-                local_attn_masks.append(local_attn_mask)
-        kwargs["global_attn_masks"] = global_attn_masks
-        kwargs["local_attn_masks"] = local_attn_masks
+        IMG_TOKENS = 256
+        seq_len = input_ids.shape[1]
+        bs = input_ids.shape[0]
+        kwargs["seq_lens"] = [seq_len] * bs
+
+        global_attn_mask = torch.empty(
+            bs,
+            1,
+            seq_len,
+            seq_len,
+            dtype=mask_dtype,
+            device=input_ids.device,
+        )
+        global_attn_mask.fill_(float("-inf"))
+        global_attn_mask = global_attn_mask.triu(diagonal=1)
+
+        img_mask = torch.zeros_like(global_attn_mask)
+        img_pos = (input_ids == self.config.image_token_index)
+
+        img_mask[img_pos.unsqueeze(1)] += 1
+        img_mask = img_mask.permute(0, 1, 3, 2)
+        img_mask[img_pos.unsqueeze(1)] += 1
+        img_mask = img_mask.permute(0, 1, 3, 2)
+
+        img_pos_cum = torch.cumsum(img_pos, 1)
+        img_causal = torch.arange(seq_len, device=input_ids.device).unsqueeze(
+            0) - img_pos_cum + (img_pos_cum // IMG_TOKENS + 1) * IMG_TOKENS + 1
+        img_causal = torch.cat((img_causal[:, 0:1] - 1, img_causal[:, :-1]),
+                               dim=1)
+        img_causal = img_causal.clamp_(min=0, max=seq_len -
+                                       1).unsqueeze(1).unsqueeze(3)
+        ind = torch.arange(
+            seq_len,
+            device=input_ids.device).unsqueeze(0).unsqueeze(1).unsqueeze(2)
+        img_mask[ind < img_causal] += 1
+        global_attn_mask = torch.where(img_mask == 3, 0, global_attn_mask)
+
+        if self.sliding_window is not None:
+            # Create a local causal mask with sliding window (1024).
+            local_attn_mask = torch.ones_like(global_attn_mask)
+            local_attn_mask = torch.tril(local_attn_mask,
+                                         diagonal=-self.sliding_window)
+            local_attn_mask = torch.where(local_attn_mask == 0,
+                                          global_attn_mask, float("-inf"))
+
+        kwargs["global_attn_masks"] = global_attn_mask
+        kwargs["local_attn_masks"] = local_attn_mask
         return kwargs
 
     def compute_logits(
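The vectorized `img_mask == 3` construction above packs, into one batched mask, the same idea the removed loop expressed per sequence: start from a causal mask and let the tokens belonging to the same image attend to each other bidirectionally. A readable, unvectorized reference of that intent on a toy sequence (small `IMG_TOKENS`, made-up token ids; the patch above hardcodes 256 image tokens per image):

```python
import torch

IMG_TOKENS = 4   # toy value; the patch assumes 256 tokens per image
IMG_ID = 99      # stand-in for config.image_token_index
seq = torch.tensor([1, IMG_ID, IMG_ID, IMG_ID, IMG_ID, 7, 8])
seq_len = seq.numel()

# Global causal mask: -inf above the diagonal, 0 elsewhere.
mask = torch.zeros(seq_len, seq_len)
mask.masked_fill_(torch.triu(torch.ones(seq_len, seq_len), 1).bool(),
                  float("-inf"))

# Unblock query/key pairs that belong to the same image.
img_positions = (seq == IMG_ID).nonzero().flatten().tolist()
for start in range(0, len(img_positions), IMG_TOKENS):
    block = img_positions[start:start + IMG_TOKENS]
    for i in block:
        for j in block:
            mask[i, j] = 0.0

print(mask)  # rows 1-4 see all of tokens 1-4; everything else stays causal
```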
diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py
index 0e97e93a3e0..2358aa55332 100644
--- a/vllm/model_executor/models/utils.py
+++ b/vllm/model_executor/models/utils.py
@@ -394,6 +394,8 @@ def _merge_multimodal_embeddings(
     if current_platform.is_hpu():
         htcore.mark_step()
     flattened = _flatten_embeddings(multimodal_embeddings)
+    # TODO: dynamic; maybe torch.where? multimodal_embeddings is a list of
+    # varying length, but torch.where might still beat boolean indexing.
     inputs_embeds[is_multimodal] = flattened
     return inputs_embeds
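On the TODO above: one static-shape-friendly way to express the merge with `torch.where` is to gather the flattened multimodal rows with a cumulative-sum index and then select per position. This is only a sketch of the idea hinted at in the comment (toy tensors, names invented here); whether it actually beats boolean indexing on HPU would have to be measured:

```python
import torch

# Toy shapes: 6 token positions, hidden size 4, 3 of them are multimodal.
hidden = 4
inputs_embeds = torch.zeros(6, hidden)
is_multimodal = torch.tensor([False, True, True, False, True, False])
flattened = torch.arange(3 * hidden, dtype=torch.float32).view(3, hidden)

# For every position, index of the multimodal row that would land there.
gather_idx = (torch.cumsum(is_multimodal.to(torch.long), dim=0) - 1).clamp(min=0)
candidate = flattened[gather_idx]            # [6, hidden]; junk on text rows
merged = torch.where(is_multimodal.unsqueeze(-1), candidate, inputs_embeds)

assert torch.equal(merged[is_multimodal], flattened)
```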
diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py
index b3b6bb4749d..731934bff2f 100755
--- a/vllm/worker/hpu_model_runner.py
+++ b/vllm/worker/hpu_model_runner.py
@@ -129,6 +129,10 @@ def __call__(cls, *args, **kwargs):
         return cls._instances[cls]
 
 
+def is_gemma3(model):
+    return 'Gemma3ForConditionalGeneration' in str(type(model))
+
+
 def pad_flat_tensor(tensor, desired_size):
     assert tensor.dim() == 1, 'Only flat tensors are supported'
     padding_needed = desired_size - tensor.size(0)
@@ -354,6 +358,51 @@ def __init__(self, model, vllm_config, layer_names, is_causal, sampler):
                 self.model.visual = htorch.hpu.wrap_in_hpu_graph(
                     self.model.visual, disable_tensor_cache=True)
 
+        # TODO: for now this is enabled with only gemma3 in mind
+        if htorch.utils.internal.is_lazy() and is_gemma3(self.model):
+            logger.info("[Multimodal] Wrapping Visual Model")
+            self.model.vision_tower = htorch.hpu.wrap_in_hpu_graph(
+                self.model.vision_tower, disable_tensor_cache=True)
+            self.model.multi_modal_projector = htorch.hpu.wrap_in_hpu_graph(
+                self.model.multi_modal_projector, disable_tensor_cache=True)
+
+    # Copied from PR 1163; needs cleanup / a unified approach later.
+    def compute_input_embeddings_for_gemma(self, **kwargs):
+
+        if 'inputs_embeds' in kwargs:
+            # inputs_embeds were already computed; nothing to do.
+            return kwargs
+
+        # TODO: may or may not be needed for gemma3, check
+        compile_only_mode_context_false = functools.partial(
+            bc.env_setting, "PT_COMPILE_ONLY_MODE", False)
+
+        input_ids = kwargs['input_ids']
+        #with compile_only_mode_context_false():
+        vision_embeddings = self.model.get_multimodal_embeddings(**kwargs)
+        inputs_embeds = self.model.get_input_embeddings(
+            input_ids, vision_embeddings)
+
+        if vision_embeddings is not None:
+            input_ids = kwargs['input_ids']
+            positions = kwargs['positions']
+            kwargs = self.model.prepare_attn_masks(
+                mask_dtype=self.dtype,
+                **kwargs,
+            )
+            kwargs['input_ids'] = input_ids
+            kwargs['positions'] = positions
+
+        kwargs.update({'inputs_embeds': inputs_embeds})
+        # done computing the visual tokens
+        kwargs.pop('pixel_values', None)
+        return kwargs
+
     def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device,
                        dtype):
         if (attn_metadata is None
@@ -414,37 +463,55 @@ def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device,
                                        attn_bias=attn_bias)
         return attn_metadata
 
-    def _set_block_mapping(self, metadata, batch_size, device, dtype):
+    def _set_block_mapping(self, metadata, batch_size, device, dtype,
+                           is_window_block):
+
+        block_usage = metadata.block_usage \
+            if not is_window_block else metadata.window_block_usage
+        block_groups = metadata.block_groups \
+            if not is_window_block else metadata.window_block_groups
+
         mask = torch.arange(0,
                             self.block_size,
                             device=device,
                             dtype=torch.int32).unsqueeze(0)
-        mask = mask >= metadata.block_usage.unsqueeze(-1)
+        mask = mask >= block_usage.unsqueeze(-1)
         attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_(
             mask, -math.inf))
 
         if not is_fake_hpu():
-            block_mapping = torch.nn.functional.one_hot(metadata.block_groups,
+            block_mapping = torch.nn.functional.one_hot(block_groups,
                                                         num_classes=batch_size)
         else:
             # Unfortunately one_hot on CPU
             # doesn't handle out of bounds classes so we need to convert
             # all negative values to 0 (block_mapping) or bs (block_groups)
-            block_groups = metadata.block_groups.to(torch.long)
+            block_groups = block_groups.to(torch.long)
             block_mapping = torch.nn.functional.relu(block_groups)
             block_mapping = torch.nn.functional.one_hot(block_mapping,
                                                         num_classes=batch_size)
             oob_values = block_groups.lt(0)
             block_mapping.masked_fill_(oob_values.unsqueeze(-1), 0)
             block_groups.masked_fill_(oob_values, batch_size)
+            if not is_window_block:
+                metadata = custom_tuple_replace(metadata,
+                                                "TrimmedAttentionMetadata",
+                                                block_groups=block_groups)
+            else:
+                metadata = custom_tuple_replace(
+                    metadata,
+                    "TrimmedAttentionMetadata",
+                    window_block_groups=block_groups)
+
+        block_mapping = block_mapping.to(dtype)
+        if not is_window_block:
             metadata = custom_tuple_replace(metadata,
                                             "TrimmedAttentionMetadata",
-                                            block_groups=block_groups)
-        block_mapping = block_mapping.to(dtype)
-        metadata = custom_tuple_replace(metadata,
-                                        "TrimmedAttentionMetadata",
-                                        block_mapping=block_mapping,
-                                        attn_bias=attn_bias)
+                                            block_mapping=block_mapping,
+                                            attn_bias=attn_bias)
+        else:
+            metadata = custom_tuple_replace(metadata,
+                                            "TrimmedAttentionMetadata",
+                                            window_block_mapping=block_mapping,
+                                            window_attn_bias=attn_bias)
         return metadata
 
     def forward_update_meta_only(self, *args, **kwargs):
@@ -467,7 +534,10 @@ def _update_metadata(self, attn_metadata, batch_size, seq_len, device,
                                                   seq_len, device, dtype)
         else:
             attn_metadata = self._set_block_mapping(attn_metadata, batch_size,
-                                                    device, dtype)
+                                                    device, dtype, False)
+            if attn_metadata.window_block_list is not None:
+                attn_metadata = self._set_block_mapping(attn_metadata, batch_size,
+                                                        device, dtype, True)
         return attn_metadata
 
     def _prepare_cos_sin(self, positions):
@@ -537,6 +607,7 @@ def forward(self, *args, **kwargs):
         virtual_engine = 0
         if 'virtual_engine' in kwargs:
             virtual_engine = kwargs.pop('virtual_engine')
+        input_ids = kwargs['input_ids']
         kwargs['attn_metadata'] = self._update_metadata(
             kwargs['attn_metadata'], input_ids.size(0), input_ids.size(1),
@@ -798,6 +869,11 @@ def __init__(
         self.sliding_window = (self.model_config.get_sliding_window()
                                if self.model_config is not None else None)
+
+        self.interleaved_sliding_window = getattr(
+            self.model_config.hf_text_config, "interleaved_sliding_window",
+            None)
+
         self.device_config = (self.device_config
                               if self.device_config is not None else
                               DeviceConfig())
         if is_fake_hpu():
@@ -1304,6 +1380,17 @@ def add_vision_buckets_to_mrope_models(self):
             model = self.get_model()
             model.vision_buckets = VisionBuckets()
 
+    def _get_position_pad(self) -> int:
+        """
+        For gemma3 models, '0' cannot be used as the pad value for the input
+        position tensor because of the hack in
+        Gemma3ForConditionalGeneration::prepare_attn_masks, which treats
+        position id 0 as the start of a new sequence. Pads of '0' added for
+        bucketing would therefore be miscounted as new sequences.
+        """
+        model_type = getattr(self.model_config.hf_config, 'model_type', '')
+        return -1 if model_type == 'gemma3' else 0
+
     def _prepare_prompt(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
@@ -1550,11 +1637,11 @@ def _prepare_prompt(
                 make_mrope_positions_tensor_with_pad(input_positions=input_positions,
                                                      input_mrope_positions=input_mrope_positions,
                                                      max_prompt_len=max_prompt_len,
-                                                     pad=0)
+                                                     pad=self._get_position_pad())
         else:
             input_positions = make_cpu_tensor(input_positions,
                                               max_len=max_prompt_len,
-                                              pad=0,
+                                              pad=self._get_position_pad(),
                                               dtype=torch.long,
                                               flat=self.use_merged_prefill)
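For context on `_set_block_mapping`: `block_groups` (or `window_block_groups`) assigns every padded KV-cache block to the sequence that owns it, the one-hot over `batch_size` turns that into the `block_mapping` matrix consumed by the paged-attention kernel, and `block_usage` becomes a per-slot bias that masks the padded tail of each block. A toy sketch of that transformation (illustrative sizes only; the window variant just swaps in the `window_*` tensors):

```python
import torch

batch_size, block_size = 2, 4
# Three blocks after padding: the first two belong to sequence 0, the last to 1.
block_groups = torch.tensor([0, 0, 1])
block_usage = torch.tensor([4, 2, 3])  # number of valid slots per block

block_mapping = torch.nn.functional.one_hot(block_groups,
                                            num_classes=batch_size).float()
# -> [[1., 0.], [1., 0.], [0., 1.]]

mask = torch.arange(block_size).unsqueeze(0) >= block_usage.unsqueeze(-1)
attn_bias = torch.zeros_like(mask, dtype=torch.float32).masked_fill_(
    mask, float("-inf"))
# Block 1 has only 2 valid slots, so its last two positions get -inf.
print(block_mapping)
print(attn_bias)
```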
@@ -1673,6 +1760,7 @@ def _prepare_decode(
         encoder_seq_lens: List[int] = []
         cross_block_tables: List[List[int]] = []
         block_tables: List[List[int]] = []
+        window_block_tables: List[List[int]] = []
         lora_index_mapping: List[List[int]] = []
         lora_prompt_mapping: List[List[int]] = []
         lora_requests: Set[LoRARequest] = set()
@@ -1727,6 +1815,7 @@
                     for idx in range(3):
                         input_mrope_positions[idx].extend(pos_for_mrope[idx])
 
+                #logger.info(f"Decode: seq_len:{seq_len}, sliding_window{self.sliding_window}")
                 seq_len = seq_len if self.sliding_window is None else min(
                     seq_len, self.sliding_window)
                 seq_lens.append(seq_len)
@@ -1754,6 +1843,12 @@
                         block_table = block_table[-sliding_window_blocks:]
                 block_tables.append(block_table)
 
+                if self.interleaved_sliding_window is not None:
+                    sliding_window_blocks = (self.interleaved_sliding_window //
+                                             self.block_size)
+                    window_block_table = block_table[-sliding_window_blocks:]
+                    window_block_tables.append(window_block_table)
+
         if output is None:
             input_tokens = torch.tensor(input_tokens,
                                         dtype=torch.long,
@@ -1784,6 +1879,25 @@
             assert len(block_list) == len(block_groups)
             assert len(block_list) == len(block_usage)
 
+            if self.interleaved_sliding_window is not None:
+                window_block_groups = [[i] * len(bt)
+                                       for i, bt in enumerate(window_block_tables)]
+                window_block_usage = [
+                    [self.block_size] * (len(bt) - 1) + [lbu]
+                    for bt, lbu in zip(block_tables, last_block_usage) if bt
+                ]
+
+                window_block_list = flatten(window_block_tables)
+                window_block_groups = flatten(window_block_groups)
+                window_block_usage = flatten(window_block_usage)
+
+                assert len(window_block_list) == len(window_block_groups)
+                assert len(window_block_list) == len(window_block_usage)
+            else:
+                window_block_list = None
+                window_block_groups = None
+                window_block_usage = None
+
         if is_enc_dec_model:
             last_cross_block_usage = [
                 (encoder_seq_len - 1) % self.block_size + 1
@@ -1826,6 +1940,14 @@
                 indices[bid] = i
             padding_fn = lambda tensor, pad_value: gather_list(
                 tensor, indices, pad_value)
+            if self.interleaved_sliding_window is not None:
+                window_indices: List[Any]
+                window_indices = [None] * block_bucket_size
+                for i, bid in enumerate(window_block_list):
+                    window_indices[bid] = i
+                window_padding_fn = lambda tensor, pad_value: gather_list(
+                    tensor, window_indices, pad_value)
+
         else:
             block_bucket_size = self.bucketing_ctx.get_padded_decode_num_blocks(
                 len(block_list))
@@ -1843,6 +1965,17 @@
         block_groups = padding_fn(block_groups, -1)
         block_usage = padding_fn(block_usage, 1)
 
+        if self.interleaved_sliding_window is not None:
+            window_block_list = window_padding_fn(window_block_list,
+                                                  _PAD_BLOCK_ID)
+            window_block_groups = window_padding_fn(window_block_groups, -1)
+            #window_block_usage = window_padding_fn(window_block_usage, 1)
+            window_block_usage = [
+                [1] if i == 0 else [block_usage[idx]]
+                for idx, (i,
+                          j) in enumerate(zip(window_block_list, block_usage))
+            ]
+
         if is_enc_dec_model:
             if self.use_contiguous_pa:
                 cross_block_bucket_size = max(
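The decode-side bookkeeping above keeps, per sequence, only the last `interleaved_sliding_window // block_size` blocks for the sliding-window layers, and rebuilds groups and usage for that truncated view. A toy illustration of the truncation (hypothetical sizes; the real values come from the model config and the block manager):

```python
# Toy numbers: a 1024-token window with 128-token blocks keeps the last 8 blocks.
interleaved_sliding_window = 1024
block_size = 128
sliding_window_blocks = interleaved_sliding_window // block_size  # 8

# Hypothetical per-sequence block tables for two decoding sequences.
block_tables = [list(range(100, 112)), list(range(200, 205))]  # 12 and 5 blocks

window_block_tables = [bt[-sliding_window_blocks:] for bt in block_tables]
window_block_groups = [[i] * len(bt) for i, bt in enumerate(window_block_tables)]

print(window_block_tables)  # the first sequence drops its 4 oldest blocks
print(window_block_groups)  # [[0]*8, [1]*5]
```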
@@ -1934,6 +2067,24 @@
             encoder_seq_lens_tensor = encoder_seq_lens_tensor.to(  # type: ignore
                 self.device, non_blocking=True)
 
+        if self.interleaved_sliding_window is not None:
+            window_block_list = torch.tensor(window_block_list,
+                                             dtype=torch.int,
+                                             device='cpu')
+            window_block_groups = torch.tensor(window_block_groups,
+                                               dtype=torch.int,
+                                               device='cpu')
+            window_block_usage = torch.tensor(window_block_usage,
+                                              dtype=self.model_config.dtype,
+                                              device='cpu')
+
+            window_block_list = window_block_list.to(  # type: ignore
+                self.device, non_blocking=True)
+            window_block_groups = window_block_groups.to(  # type: ignore
+                self.device, non_blocking=True)
+            window_block_usage = window_block_usage.to(  # type: ignore
+                self.device, non_blocking=True)
+
         attn_metadata = self.attn_backend.make_metadata(
             is_prompt=False,
             block_size=self.block_size,
@@ -1941,6 +2092,10 @@
             block_list=block_list,
             block_mapping=None,
             block_usage=block_usage,
             block_groups=block_groups,
+            window_block_list=window_block_list,
+            window_block_mapping=None,
+            window_block_usage=window_block_usage,
+            window_block_groups=window_block_groups,
             attn_bias=None,
             seq_lens_tensor=None,
             encoder_seq_lens=encoder_seq_lens,
@@ -2337,6 +2492,11 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object:
             'block_groups',
             'input_positions',
             'alibi_blocks',
+            'window_block_list',
+            'window_block_mapping',
+            'window_block_usage',
+            'window_block_groups',
+            'window_attn_bias',
         ])
        return attention_metadata
 
@@ -3482,6 +3642,12 @@ def try_revert_dummy_output_tokens():
                 }
 
                 if not bypass_model_exec:
+                    if is_gemma3(self.model.model):
+                        execute_model_kwargs = \
+                            self.model.compute_input_embeddings_for_gemma(
+                                **execute_model_kwargs
+                            )
+
                     with self.profiler.record_event('internal',
                                                     model_event_name,
                                                     args=profiler_args):
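`trim_attn_metadata` rebuilds a trimmed metadata tuple containing only the fields the compiled graph needs, which is why every new `window_*` field must be listed there and why `_set_block_mapping` writes results back via `custom_tuple_replace` instead of mutating the metadata in place. As a rough analogy only (not the repo's actual helper), the same copy-on-write pattern with a frozen dataclass and `dataclasses.replace`:

```python
from dataclasses import dataclass, replace
from typing import Optional
import torch

@dataclass(frozen=True)
class TrimmedMetadata:
    block_mapping: Optional[torch.Tensor] = None
    attn_bias: Optional[torch.Tensor] = None
    window_block_mapping: Optional[torch.Tensor] = None
    window_attn_bias: Optional[torch.Tensor] = None

meta = TrimmedMetadata()
# Global-attention layers fill the regular fields...
meta = replace(meta, block_mapping=torch.eye(2), attn_bias=torch.zeros(2, 4))
# ...and a second pass fills the window_* fields for sliding-window layers.
meta = replace(meta, window_block_mapping=torch.eye(2),
               window_attn_bias=torch.zeros(2, 4))
```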