Commit 643e6f5
[Bugfix] Fix accuracy problem caused by mask pollution (#1678)
### What this PR does / why we need it?

If a small batch of short requests is sent first, forming a chunk with a length < 128, it corrupts the `attn_mask_cache`, causing subsequent requests that do not form a chunk to have accuracy issues. The root cause is the use of in-place multiplication; switching to out-of-place multiplication resolves the accuracy problem.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Yes, with new unit tests in `tests/ut/attention/test_attention_mask.py`.

- vLLM version: v0.9.2
- vLLM main: vllm-project/vllm@ad6c2e1

---------

Signed-off-by: ApsarasX <apsarax@outlook.com>
1 parent 60519c7 commit 643e6f5
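To see why the in-place multiply pollutes the cache, here is a minimal, self-contained sketch. The `cache` tensor and `get_mask` helper are hypothetical stand-ins (not code from this repository); the point is only that `*=` writes through any view that aliases the cached mask, whereas an out-of-place multiply allocates a fresh tensor and leaves the cache untouched.

```python
import torch

# Hypothetical stand-in for a cached attention mask.
cache = torch.triu(torch.ones(4, 4), diagonal=1)


def get_mask(n: int) -> torch.Tensor:
    # Returns a view into the cache, which is the aliasing scenario
    # that makes in-place updates dangerous.
    return cache[:n, :n]


mask = get_mask(4)
mask = mask * -10000                 # out-of-place: allocates a new tensor
assert cache.max().item() == 1       # cache is still intact

mask = get_mask(4)
mask *= -10000                       # in-place: writes through the view
assert cache.min().item() == -10000  # cache is now polluted for later requests
```

The fix below applies exactly this change at both call sites: `attn_mask *= -10000` becomes `attn_mask = attn_mask * -10000`.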

File tree

3 files changed: +53 -2 lines changed


tests/ut/attention/test_attention_mask.py

Lines changed: 49 additions & 0 deletions

@@ -105,3 +105,52 @@ def test_get_splitfuse_attn_mask(self):
             device=torch.device("cpu"),
         )
         self.assertEqual(attn_mask.shape, (1, 512))
+
+    def test_use_multiple_masks(self):
+        max_seq_lens = [128, 512, 1024]
+        dtypes = [torch.float16, torch.bfloat16, torch.int8]
+        for max_seq_len, dtype in zip(max_seq_lens, dtypes):
+            with self.subTest(max_seq_len=max_seq_len, dtype=dtype):
+                self._test_use_multiple_masks(max_seq_len, dtype)
+
+    def _test_use_multiple_masks(self, max_seq_len, dtype):
+        expected_mask_value = torch.finfo(
+            torch.float32).min if dtype == torch.float16 else 1
+        if dtype == torch.float16:
+            expected_splitfuse_mask_value = expected_mask_value
+        elif dtype == torch.bfloat16:
+            expected_splitfuse_mask_value = -10000
+        else:
+            assert dtype == torch.int8, "Unsupported dtype for attention mask"
+            expected_splitfuse_mask_value = -16
+
+        attention_mask_builder = AttentionMaskBuilder(max_seq_len=max_seq_len,
+                                                      dtype=dtype)
+
+        splitfuse_attn_mask = attention_mask_builder.get_splitfuse_attn_mask(
+            seq_lens=[max_seq_len],
+            query_lens=[max_seq_len],
+            position=torch.tensor([0]),
+            dtype=dtype,
+            device=torch.device("cpu"),
+        )
+        self.assertEqual(splitfuse_attn_mask.shape, (1, max_seq_len))
+        self.assertEqual(
+            splitfuse_attn_mask[0][-1],
+            torch.tensor(expected_splitfuse_mask_value, dtype=dtype))
+        self.assertEqual(attention_mask_builder._seq_len_cached, max_seq_len)
+        self.assertEqual(attention_mask_builder.attn_mask_cache.shape,
+                         (max_seq_len, max_seq_len))
+        self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1],
+                         torch.tensor(expected_mask_value, dtype=dtype))
+
+        attn_mask = attention_mask_builder.get_attn_mask(
+            max_seq_len=max_seq_len, dtype=dtype, device=torch.device("cpu"))
+        self.assertEqual(attn_mask.shape, (max_seq_len, max_seq_len))
+        self.assertEqual(attn_mask[0][-1],
+                         torch.tensor(expected_mask_value, dtype=dtype))
+        self.assertEqual(attention_mask_builder._seq_len_cached, max_seq_len)
+        self.assertEqual(attention_mask_builder.attn_mask_cache.shape,
+                         (max_seq_len, max_seq_len))
+        self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1],
+                         torch.tensor(expected_mask_value, dtype=dtype))
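For reference, the new cases can be run on their own with a standard pytest invocation such as `pytest tests/ut/attention/test_attention_mask.py -k use_multiple_masks`; the exact command used by the project's CI is not shown in this commit.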

vllm_ascend/attention/attention.py

Lines changed: 2 additions & 1 deletion

@@ -572,7 +572,8 @@ def build(
         attn_mask = AscendMetadataBuilder._attn_mask_builder.get_attn_mask(  # type: ignore
             max_seq_len, dtype, device)
         if attn_mask.numel() > 1 and attn_mask[0][1] > 0:
-            attn_mask *= -10000
+            # Do not use in-place multiplication to avoid modifying `attn_mask_cache`!
+            attn_mask = attn_mask * -10000
         chunk_mask_list = []
         for i, seq_len in enumerate(seq_lens):
             context_len = self.context_lens[i]

vllm_ascend/attention/attention_mask.py

Lines changed: 2 additions & 1 deletion

@@ -68,7 +68,8 @@ def get_splitfuse_attn_mask(
         ) > 1 and self.attn_mask_cache[0][1] > 0:
             attn_mask = self.get_attn_mask(  # type: ignore
                 max_seq_len, dtype, device)
-            attn_mask *= -10000
+            # Do not use in-place multiplication to avoid modifying `self.attn_mask_cache`!
+            attn_mask = attn_mask * -10000
         else:
             attn_mask = self.attn_mask_cache
         return torch.index_select(attn_mask, dim=0,
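A brief note on the design choice here: the final `torch.index_select` always materializes a new tensor (per PyTorch's documented semantics it does not share storage with its input), so the mask handed back to callers is never the hazard. The aliasing happens one step earlier, when `get_attn_mask` can return a tensor backed by `self.attn_mask_cache`, which is exactly what the in-place `*=` corrupted and what the out-of-place multiply now avoids. A quick sanity check of the `index_select` behavior (illustrative only, not code from the repository):

```python
import torch

cache = torch.ones(3, 3)
selected = torch.index_select(cache, dim=0, index=torch.tensor([0, 1]))

# index_select copies; mutating the result does not touch the source tensor.
selected *= -10000
assert cache.max().item() == 1
```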
