
Commit 028b4d7

drisspg authored and jimpang committed
Fixes IMA for TP w/ flex-attention (vllm-project#19712)
Signed-off-by: drisspg <drisspguessous@gmail.com>
1 parent 8af5a85 commit 028b4d7

2 files changed: +2 -10 lines changed

tests/kernels/test_flex_attention.py

Lines changed: 0 additions & 2 deletions
@@ -51,7 +51,6 @@ def test_flex_attention_vs_default_backend(monkeypatch):
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")
         m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
-        m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
 
         set_seed(seed)
 
@@ -66,7 +65,6 @@ def test_flex_attention_vs_default_backend(monkeypatch):
     # Run with default backend
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")
-        m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
         set_seed(seed)
         llm_default = LLM(
             model_name,
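
For context, the environment variables this test sets select the attention backend before the engine is constructed. The snippet below is a minimal standalone sketch of that selection, not part of the commit: the model name and prompt are placeholder choices, and VLLM_ENABLE_V1_MULTIPROCESSING is left at its default, as the updated test now does.

# Minimal sketch: pick the FlexAttention backend via environment variables
# before constructing the engine. "facebook/opt-125m" and the prompt are
# placeholder choices, not taken from the test.
import os

os.environ["VLLM_USE_V1"] = "1"
os.environ["VLLM_ATTENTION_BACKEND"] = "FLEX_ATTENTION"

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=8))
print(outputs[0].outputs[0].text)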

vllm/v1/attention/backends/flex_attention.py

Lines changed: 2 additions & 8 deletions
@@ -13,7 +13,6 @@
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType,
                                               is_quantized_kv_cache)
-from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
@@ -237,17 +236,13 @@ def final_mask_mod(
 
     def build_block_mask(self) -> BlockMask:
         assert self.mask_mod is not None
-        # FIXME: With TP>1, create_block_mask_compiled will raise
-        # CUDA error: an illegal memory access was encountered
-        create_block_mask_fn = (create_block_mask_compiled
-                                if get_tensor_model_parallel_world_size() == 1
-                                else create_block_mask)
-        return create_block_mask_fn(
+        return create_block_mask_compiled(
             self.mask_mod,
             None,
             None,
             self.num_actual_tokens,
             self.total_cache_tokens,
+            device=self.block_table.device,
         )
 
     def __post_init__(self):
@@ -429,7 +424,6 @@ def forward(
             shape = [num_tokens, num_heads * head_size]
         """
         assert output is not None, "Output tensor must be provided."
-
         if output_scale is not None:
             raise NotImplementedError(
                 "fused output quantization is not yet supported"
