Skip to content

Commit c5b41dc

Browse files
committed
Missing parameter for sdpa
Default value for the layout
Missing boundary check
1 parent cfda5b3 commit c5b41dc

File tree

3 files changed

+5
-3
lines changed

3 files changed

+5
-3
lines changed

vllm/attention/backends/rocm_flash_attn.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -814,6 +814,7 @@ def forward(
814814
self.num_heads,
815815
self.head_size,
816816
self.scale,
817+
causal_mask,
817818
attn_masks,
818819
)
819820
else:

vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -234,8 +234,9 @@ def rocm_aiter_topk_softmax(topk_weights: torch.Tensor,
     return topk_weights, topk_indices
 
 
-def shuffle_weights(*tensors: torch.Tensor,
-                    layout: tuple[int, int]) -> tuple[torch.Tensor, ...]:
+def shuffle_weights(
+    *tensors: torch.Tensor, layout: tuple[int, int] = (16, 16)
+) -> tuple[torch.Tensor, ...]:
     """
     Applies shuffle_weight function from AITER to each
     input tensor and returns them.

vllm/model_executor/layers/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def rocm_unquantized_gemm(x: torch.Tensor,
     m = weight.shape[0]
     cu_count = current_platform.get_cu_count()
 
-    if m > 8 and n < 4:
+    if m > 8 and 0 < n < 4:
         out = ops.wvSplitK(weight, x_view, cu_count)
         return out.view(*x.shape[:-1], weight.shape[0])
     elif m % 4 == 0 and n == 1 and k <= 8192:

0 commit comments

Comments
 (0)