
Commit 69d54e2

get triton tests to pass
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Parent: fcf4be1

2 files changed (+6, -6)

tests/v1/attention/test_attention_backends.py

Lines changed: 5 additions & 5 deletions
@@ -187,10 +187,10 @@ def create_and_prepopulate_kv_cache(
 class MockAttentionLayer:
     """A mock attention layer for testing."""
 
-    def __init__(self):
-        self._q_scale = torch.tensor(1.0)
-        self._k_scale = torch.tensor(1.0)
-        self._v_scale = torch.tensor(1.0)
+    def __init__(self, device: torch.device):
+        self._q_scale = torch.tensor(1.0, device=device)
+        self._k_scale = torch.tensor(1.0, device=device)
+        self._v_scale = torch.tensor(1.0, device=device)
         # Add float versions for flashinfer
         self._k_scale_float = 1.0
         self._v_scale_float = 1.0
@@ -258,7 +258,7 @@ def mock_get_per_layer_parameters(vllm_config):
     )
 
     # Create mock layer and output buffer
-    mock_layer = MockAttentionLayer()
+    mock_layer = MockAttentionLayer(device)
    output = torch.empty_like(query)
 
    # Run forward pass
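The likely motivation (an assumption; the commit message only says the Triton tests now pass) is that the Triton backend reads the per-layer scale tensors on the same device as the query/key/value tensors, so CPU-resident scalars created by torch.tensor(1.0) would cause a device mismatch. A minimal, runnable sketch of the patched mock; the driver lines at the end are illustrative, not part of the test suite:

import torch

class MockAttentionLayer:
    """A mock attention layer for testing (as patched above)."""

    def __init__(self, device: torch.device):
        # Scale tensors are created directly on the target device so that
        # device-side kernels can read them without a host round-trip.
        self._q_scale = torch.tensor(1.0, device=device)
        self._k_scale = torch.tensor(1.0, device=device)
        self._v_scale = torch.tensor(1.0, device=device)
        # Add float versions for flashinfer
        self._k_scale_float = 1.0
        self._v_scale_float = 1.0

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
layer = MockAttentionLayer(device)
assert layer._q_scale.device.type == device.type  # colocated with q/k/v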

tests/v1/attention/utils.py

Lines changed: 1 addition & 1 deletion
@@ -114,7 +114,7 @@ def get_attention_backend(backend_name: _Backend):
         _Backend.FLEX_ATTENTION:
         "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend",
         _Backend.TRITON_ATTN_VLLM_V1:
-        "vllm.v1.attention.backends.triton_attn.TritonAttnBackend",
+        "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend",
     }
 
     if backend_name not in backend_map:
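For context, the values in backend_map are dotted "module.ClassName" strings resolved at runtime, so each must name a class that actually exists; the one-word fix suggests the class exported by triton_attn is TritonAttentionBackend, not TritonAttnBackend. A hedged sketch of how such a string can be resolved, using a hypothetical resolve_backend helper rather than whatever utility vLLM's test code actually calls:

import importlib

def resolve_backend(qualname: str):
    # Split "pkg.module.ClassName" into module path and class name,
    # import the module, and return the class object.
    module_name, _, class_name = qualname.rpartition(".")
    return getattr(importlib.import_module(module_name), class_name)

# Before this commit the map pointed at a nonexistent name, so resolving
# it would raise AttributeError (assumption based on the fix). After the
# rename it resolves cleanly (requires vllm to be installed):
# resolve_backend(
#     "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend")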
