
Enable interleaved sliding_window for gemma3 #1344


Open
wants to merge 34 commits into base: habana_main
Changes from all commits (34 commits)
e92d432
Fix gemma3 workload execution failure
shepark Apr 24, 2025
bd75109
add run scripts
ssarkar2 May 21, 2025
90a3ae1
minor
ssarkar2 May 23, 2025
9434197
Update sliding_window attention
jiminha May 28, 2025
0364786
Update run scripts
ssarkar2 May 28, 2025
8caf102
Update sliding_window mask logic for lazy mode
jiminha May 28, 2025
bff7983
Fix long prompt accuracy issue
jiminha May 30, 2025
dccd67e
Change back to Eager mode for Vision prompt
jiminha May 30, 2025
45aaede
Remove unnecessary files
jiminha May 30, 2025
7df1811
Remove test file
jiminha May 30, 2025
e4b0397
Merge branch 'habana_main' into jha/sliding_window_gemma3
ssarkar2 May 30, 2025
6088039
Enable bs>1
ssarkar2 Jun 3, 2025
8b13980
enable hpu graph model
maktukmak Jun 4, 2025
a9e5a7d
Add temporary test scripts
ssarkar2 Jun 4, 2025
f783955
Fix for missing image
ssarkar2 Jun 5, 2025
1297154
Bring back +1
ssarkar2 Jun 5, 2025
be41114
Switch to lazy+hpugraphs, add v0 mode
ssarkar2 Jun 5, 2025
74e4cfb
Fix masks. Remove cross attn between images
ssarkar2 Jun 6, 2025
347e965
Script for variable batches
ssarkar2 Jun 6, 2025
a29d537
Do vision+combining before text mdoel fwd
ssarkar2 Jun 10, 2025
c9c5757
wrap vision and projector in hpu graphs
ssarkar2 Jun 10, 2025
5af6870
vectorized mask generation
maktukmak Jun 10, 2025
000b4e0
Revert "wrap vision and projector in hpu graphs"
ssarkar2 Jun 11, 2025
658442d
Revert "Do vision+combining before text mdoel fwd"
ssarkar2 Jun 11, 2025
61a3e2f
Fixing the earlier commit which was reverted
ssarkar2 Jun 11, 2025
39e0f52
bring back reverted commit
ssarkar2 Jun 11, 2025
661f59a
Fix accuracy issue with repeat words for long prompts
jiminha Jun 21, 2025
affc7a7
Change parameter check for intereleaved sliding_window
jiminha Jun 23, 2025
994de89
Remove all test files
jiminha Jun 23, 2025
ad492f8
Merge remote-tracking branch 'origin/habana_main' into jha/sliding_wi…
jiminha Jun 23, 2025
1ceca57
Merge branch 'habana_main' into jha/sliding_window_gemma3
jiminha Jun 23, 2025
805df55
Fix error from merge
jiminha Jun 23, 2025
d531412
Fix pre-commit errors
jiminha Jun 23, 2025
f99d76a
Pre-commit fix for the list warning
jiminha Jun 24, 2025
76 changes: 68 additions & 8 deletions vllm/attention/backends/hpu_attn.py
@@ -135,6 +135,12 @@
cross_block_groups: Optional[torch.Tensor] = None
cross_block_usage: Optional[torch.Tensor] = None
cross_attn_bias: Optional[torch.Tensor] = None
window_block_list: Optional[torch.Tensor] = None
window_slot_mapping: Optional[torch.Tensor] = None
window_block_mapping: Optional[torch.Tensor] = None
window_block_groups: Optional[torch.Tensor] = None
window_block_usage: Optional[torch.Tensor] = None
window_attn_bias: Optional[torch.Tensor] = None


@dataclass
@@ -542,6 +548,18 @@
block_list = attn_metadata.block_list if attn_metadata \
and attn_metadata.block_list is not None else None

common_args = self.common_attention_args(block_list, key_cache,
value_cache,
attn_metadata.block_size)

#TODO: Ideally we want to create this sliding_window_bias mask only
#once in the model_runner or gemma model file then only retrieve here.

Check failure (GitHub Actions / pre-commit), Ruff (E501): vllm/attention/backends/hpu_attn.py:556:81: E501 Line too long (82 > 80)
if self.sliding_window:
attn_bias = _make_sliding_window_bias(
batch_size, seq_len, attn_metadata.seq_lens_tensor,
self.sliding_window, query.dtype)
common_args['pad'] = 'left'

out = ops.prompt_attention(
impl=self.prefill_impl,
query=query.view(query_shape),
@@ -551,12 +569,16 @@
attn_bias=attn_bias,
position_bias=position_bias,
valid_seq_lengths=attn_metadata.seq_lens_tensor,
**self.common_attention_args(block_list, key_cache,
value_cache,
attn_metadata.block_size))
**common_args)

output = out.reshape(batch_size, seq_len, hidden_size)
else:
# Decoding run.
block_list = attn_metadata.block_list if not self.sliding_window else attn_metadata.window_block_list

Check failure (GitHub Actions / pre-commit), Ruff (E501): vllm/attention/backends/hpu_attn.py:577:81: E501 Line too long (113 > 80)
block_groups = attn_metadata.block_groups if not self.sliding_window else attn_metadata.window_block_groups

Check failure (GitHub Actions / pre-commit), Ruff (E501): vllm/attention/backends/hpu_attn.py:578:81: E501 Line too long (119 > 80)
block_mapping = attn_metadata.block_mapping if not self.sliding_window else attn_metadata.window_block_mapping

Check failure (GitHub Actions / pre-commit), Ruff (E501): vllm/attention/backends/hpu_attn.py:579:81: E501 Line too long (122 > 80)
attn_bias = attn_metadata.attn_bias if not self.sliding_window else attn_metadata.window_attn_bias

Check failure (GitHub Actions / pre-commit), Ruff (E501): vllm/attention/backends/hpu_attn.py:580:81: E501 Line too long (110 > 80)

self.position_bias = None
alibi_blocks = getattr(attn_metadata, 'alibi_blocks', None)
if self.alibi_slopes is not None and alibi_blocks is not None:
@@ -572,12 +594,12 @@

output = HPUPagedAttention.forward_decode(
query=query,
block_mapping=attn_metadata.block_mapping,
block_bias=attn_metadata.attn_bias,
block_groups=attn_metadata.block_groups,
block_mapping=block_mapping,
block_bias=attn_bias,
block_groups=block_groups,
position_bias=self.position_bias,
**self.common_attention_args(attn_metadata.block_list,
key_cache, value_cache,
**self.common_attention_args(block_list, key_cache,
value_cache,
attn_metadata.block_size))
# Reshape the output tensor.
return output.view(batch_size, seq_len, hidden_size)
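
The four metadata selections above are what make the layers interleave: global-attention layers keep using the regular paged-KV metadata, while sliding-window layers switch to the new window_* fields added to HPUAttentionMetadata. A small sketch of an equivalent helper (hypothetical name, not part of this diff) that would also sidestep the E501 failures flagged above:

def _select_decode_metadata(attn_metadata, sliding_window):
    # Hypothetical helper, not part of this PR: sliding-window layers read the
    # window_* paged-KV metadata, global layers read the regular fields.
    if sliding_window:
        return (attn_metadata.window_block_list,
                attn_metadata.window_block_groups,
                attn_metadata.window_block_mapping,
                attn_metadata.window_attn_bias)
    return (attn_metadata.block_list,
            attn_metadata.block_groups,
            attn_metadata.block_mapping,
            attn_metadata.attn_bias)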
@@ -768,3 +790,41 @@
per_head_bias.mul_(alibi_slopes[None, :, None])

return per_head_bias


def _make_sliding_window_bias(
batch_size: int,
seq_len: int,
query_lens_t: torch.tensor,
window_size: int,
dtype: torch.dtype,
) -> torch.Tensor:

shift = 0
device = query_lens_t.device

# TODO: this is not performant as of now. Need to investigate further
# once FusedSDPA kernel with sliding causal mask support is available.

# causal + sliding window (LEFT PADDING)
tensor = torch.full((batch_size, 1, seq_len, seq_len),
device=device,
dtype=dtype,
fill_value=1)
mask = torch.tril(tensor, diagonal=shift)
mask = torch.triu(mask, diagonal=shift - window_size + 1)
attn_bias = torch.log(mask)
'''
# TODO Accuracy issue need to be debugged.
# causal + sliding window + query_len (LEFT PADDING : Need kernel supports)
tensor = torch.full((batch_size, 1, seq_len, seq_len), device=device,dtype=dtype,fill_value=1)

Check failure (GitHub Actions / pre-commit), Ruff (E501): vllm/attention/backends/hpu_attn.py:820:81: E501 Line too long (98 > 80)
mask = torch.tril(tensor, diagonal=shift)
len_mask = torch.arange(0, seq_len, device=device, dtype=torch.int32).view(seq_len,1)

Check failure (GitHub Actions / pre-commit), Ruff (E501): vllm/attention/backends/hpu_attn.py:822:81: E501 Line too long (89 > 80)
len_mask = len_mask.ge(query_lens_t.unsqueeze(-1)).view(batch_size, 1, seq_len, 1)

Check failure (GitHub Actions / pre-commit), Ruff (E501): vllm/attention/backends/hpu_attn.py:823:81: E501 Line too long (86 > 80)
len_mask = torch.where(len_mask == False, 1, 0)
mask = mask.logical_and(len_mask)
mask = torch.triu(mask, diagonal=shift - window_size + 1)
attn_bias =torch.where(mask,0, -math.inf)
'''

return attn_bias
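
A minimal standalone sketch (toy sizes, illustration only) of the same tril/triu/log construction used by _make_sliding_window_bias shows what the resulting bias looks like: 0 where a query may attend, -inf everywhere else.

import torch

batch_size, seq_len, window_size = 1, 6, 3
ones = torch.ones((batch_size, 1, seq_len, seq_len), dtype=torch.float32)
mask = torch.tril(ones, diagonal=0)                   # causal: keys at or before the query
mask = torch.triu(mask, diagonal=-(window_size - 1))  # keep only the last window_size keys
bias = torch.log(mask)                                # 1 -> 0 (attend), 0 -> -inf (masked)
print(bias[0, 0])

As in the PR code, this is the causal-plus-window mask for left-padded prompts; the commented-out variant additionally masks padding rows using the per-sequence query lengths.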
52 changes: 21 additions & 31 deletions vllm/model_executor/models/gemma3.py
@@ -233,44 +233,34 @@ def naive_attn_with_masks(
out: torch.Tensor,
**kwargs,
) -> torch.Tensor:
# NOTE(woosuk): As described in the comment above, this code is not
# meant to be performant. It is only meant to be correct.
q = q.view(-1, self.num_heads, self.head_dim)
# Expand the key and value to handle GQA.

s = q.shape[1]
num_queries_per_kv = self.num_heads // self.num_kv_heads
k = k.view(-1, self.num_kv_heads, self.head_dim)
k = k.repeat_interleave(num_queries_per_kv, dim=-2)
v = v.view(-1, self.num_kv_heads, self.head_dim)
v = v.repeat_interleave(num_queries_per_kv, dim=-2)
query = q.view(-1, s, self.num_heads, self.head_dim)
key = k.view(-1, s, self.num_kv_heads, self.head_dim)
key = key.repeat_interleave(num_queries_per_kv, dim=-2)
value = v.view(-1, s, self.num_kv_heads, self.head_dim)
value = value.repeat_interleave(num_queries_per_kv, dim=-2)

if self.is_sliding:
attn_masks = kwargs["local_attn_masks"]
else:
attn_masks = kwargs["global_attn_masks"]

seq_lens = kwargs["seq_lens"]
start_idx = 0
for seq_len, attn_mask in zip(seq_lens, attn_masks):
end_idx = start_idx + seq_len
query = q[start_idx:end_idx].unsqueeze(0)
key = k[start_idx:end_idx].unsqueeze(0)
value = v[start_idx:end_idx].unsqueeze(0)

# Transpose.
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)

output = F.scaled_dot_product_attention(
query,
key,
value,
attn_mask,
self.scaling,
)
output = output.transpose(1, 2).flatten(-2, -1)
out[start_idx:end_idx] = output
start_idx = end_idx
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)

output = F.scaled_dot_product_attention(
query,
key,
value,
attn_masks,
self.scaling,
)

out = output.transpose(1, 2).flatten(-2, -1)

return out


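The rewrite above drops the per-sequence Python loop and runs one batched scaled_dot_product_attention call with a single additive mask per layer type. A self-contained sketch of that pattern (shapes and sizes assumed for illustration, not taken verbatim from the diff):

import torch
import torch.nn.functional as F

# Assumed toy dimensions, illustration only.
bs, seq, num_heads, num_kv_heads, head_dim = 2, 8, 8, 4, 16
num_queries_per_kv = num_heads // num_kv_heads
scaling = head_dim ** -0.5

q = torch.randn(bs, seq, num_heads, head_dim)
k = torch.randn(bs, seq, num_kv_heads, head_dim)
v = torch.randn(bs, seq, num_kv_heads, head_dim)

# GQA: repeat the KV heads so every query head has a matching key/value head.
k = k.repeat_interleave(num_queries_per_kv, dim=-2)
v = v.repeat_interleave(num_queries_per_kv, dim=-2)

# One additive mask for the whole batch: 0 = attend, -inf = blocked.
attn_mask = torch.triu(
    torch.full((bs, 1, seq, seq), float("-inf")), diagonal=1)

out = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
    attn_mask, scale=scaling)
out = out.transpose(1, 2).flatten(-2, -1)  # (bs, seq, num_heads * head_dim)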
101 changes: 49 additions & 52 deletions vllm/model_executor/models/gemma3_mm.py
@@ -566,6 +566,7 @@
self.vision_tower,
pixel_values,
)

image_embeds = self.multi_modal_projector(image_features)

return [
@@ -610,6 +611,7 @@
# NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility.
elif inputs_embeds is None:
assert False, "hpu_model_runner should be computing inputs_embeds"

Check failure (GitHub Actions / pre-commit), Ruff (B011): vllm/model_executor/models/gemma3_mm.py:614:20: B011 Do not `assert False` (`python -O` removes these calls), raise `AssertionError()`
vision_embeddings = self.get_multimodal_embeddings(**kwargs)

inputs_embeds = self.get_input_embeddings(input_ids,
@@ -639,58 +641,53 @@
**kwargs,
):
kwargs["has_images"] = True
# NOTE(woosuk): Here, we distinguish the sequences by the position id 0.
# This is a HACK. Fix this.
start_idices = (positions == 0).cpu().nonzero()
num_seqs = len(start_idices)
seq_lens = []
for i in range(num_seqs):
start_idx = start_idices[i].item()
if i < num_seqs - 1:
end_idx = start_idices[i + 1].item()
else:
end_idx = len(input_ids)
seq_lens.append(end_idx - start_idx)
kwargs["seq_lens"] = seq_lens

global_attn_masks = []
local_attn_masks = []
start_idx = 0
for seq_len in seq_lens:
end_idx = start_idx + seq_len
input_token_ids = input_ids[start_idx:end_idx]
start_idx = end_idx
# Create a global causal mask.
global_attn_mask = torch.empty(
1,
1,
seq_len,
seq_len,
dtype=mask_dtype,
device=input_ids.device,
)
global_attn_mask.fill_(float("-inf"))
# Fill the lower triangle with 0.
global_attn_mask = global_attn_mask.triu(diagonal=1)

# Consider the bidirectional attention between image tokens.
img_mask = torch.zeros_like(global_attn_mask)
img_pos = (input_token_ids == self.config.image_token_index)
img_mask[:, :, :, img_pos] += 1
img_mask[:, :, img_pos, :] += 1
global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
global_attn_masks.append(global_attn_mask)

if self.sliding_window is not None:
# Create a local causal mask with sliding window (1024).
local_attn_mask = torch.ones_like(global_attn_mask)
local_attn_mask = torch.tril(local_attn_mask,
diagonal=-self.sliding_window)
local_attn_mask = torch.where(local_attn_mask == 0,
global_attn_mask, float("-inf"))
local_attn_masks.append(local_attn_mask)
kwargs["global_attn_masks"] = global_attn_masks
kwargs["local_attn_masks"] = local_attn_masks
IMG_TOKENS = 256
seq_len = input_ids.shape[1]
bs = input_ids.shape[0]
kwargs["seq_lens"] = [seq_len] * bs

global_attn_mask = torch.empty(
bs,
1,
seq_len,
seq_len,
dtype=mask_dtype,
device=input_ids.device,
)
global_attn_mask.fill_(float("-inf"))
global_attn_mask = global_attn_mask.triu(diagonal=1)

img_mask = torch.zeros_like(global_attn_mask)
img_pos = (input_ids == self.config.image_token_index)

img_mask[img_pos.unsqueeze(1)] += 1
img_mask = img_mask.permute(0, 1, 3, 2)
img_mask[img_pos.unsqueeze(1)] += 1
img_mask = img_mask.permute(0, 1, 3, 2)

img_pos_cum = torch.cumsum(img_pos, 1)
img_causal = torch.arange(seq_len, device=input_ids.device).unsqueeze(
0) - img_pos_cum + (img_pos_cum // IMG_TOKENS + 1) * IMG_TOKENS + 1
img_causal = torch.cat((img_causal[:, 0:1] - 1, img_causal[:, :-1]),
dim=1)
img_causal = img_causal.clamp_(min=0, max=seq_len -
1).unsqueeze(1).unsqueeze(3)
ind = torch.arange(
seq_len,
device=input_ids.device).unsqueeze(0).unsqueeze(1).unsqueeze(2)
img_mask[ind < img_causal] += 1
global_attn_mask = torch.where(img_mask == 3, 0, global_attn_mask)

if self.sliding_window is not None:
# Create a local causal mask with sliding window (1024).
local_attn_mask = torch.ones_like(global_attn_mask)
local_attn_mask = torch.tril(local_attn_mask,
diagonal=-self.sliding_window)
local_attn_mask = torch.where(local_attn_mask == 0,
global_attn_mask, float("-inf"))

kwargs["global_attn_masks"] = global_attn_mask
kwargs["local_attn_masks"] = local_attn_mask
return kwargs

def compute_logits(
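The batched mask preparation above builds one global causal mask per batch (with bidirectional attention inside each image-token block) and, for sliding-window layers, derives a local mask from it. A toy sketch of the local-mask derivation alone (illustration only, small sizes assumed):

import torch

bs, seq_len, window = 1, 6, 2
# Global causal mask: -inf above the diagonal, 0 on and below it.
global_attn_mask = torch.full(
    (bs, 1, seq_len, seq_len), float("-inf")).triu(diagonal=1)

# Local mask: additionally block keys more than `window` positions in the past.
local_attn_mask = torch.ones_like(global_attn_mask)
local_attn_mask = torch.tril(local_attn_mask, diagonal=-window)
local_attn_mask = torch.where(local_attn_mask == 0,
                              global_attn_mask, float("-inf"))
print(local_attn_mask[0, 0])  # each query sees itself plus the previous window - 1 keys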
2 changes: 2 additions & 0 deletions vllm/model_executor/models/utils.py
@@ -394,6 +394,8 @@
if current_platform.is_hpu():
htcore.mark_step()
flattened = _flatten_embeddings(multimodal_embeddings)
#TODO dynamic.. maybe torch.where? however multimodal_embeddings is a list of varying length

Check failure (GitHub Actions / pre-commit), Ruff (E501): vllm/model_executor/models/utils.py:397:81: E501 Line too long (100 > 80)
# still.. torch.where migth be faster than boolean indexing?
inputs_embeds[is_multimodal] = flattened
return inputs_embeds

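On the TODO above, a hedged sketch of the alternative it hints at (not implemented or benchmarked in this PR): masked_scatter produces the same result as the boolean-index assignment without a data-dependent intermediate shape.

import torch

# Toy shapes, illustration only: inputs_embeds is (num_tokens, hidden_size),
# is_multimodal is a (num_tokens,) bool mask, flattened is (num_mm_tokens, hidden_size).
num_tokens, hidden_size = 8, 4
inputs_embeds = torch.zeros(num_tokens, hidden_size)
is_multimodal = torch.tensor([0, 1, 1, 0, 0, 1, 0, 0], dtype=torch.bool)
flattened = torch.randn(int(is_multimodal.sum()), hidden_size)

# Same result as `inputs_embeds[is_multimodal] = flattened`, expressed as a
# single masked_scatter over a broadcast mask (whether this is actually faster
# on HPU is untested here).
merged = inputs_embeds.masked_scatter(is_multimodal.unsqueeze(-1), flattened)
assert torch.equal(merged[is_multimodal], flattened)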