Skip to content

Commit 97a0908

Browse files
committed
Address lint
Signed-off-by: Yikun Jiang <yikunkero@gmail.com>
1 parent e989be5 commit 97a0908

File tree

2 files changed

+11
-12
lines changed

2 files changed

+11
-12
lines changed

tests/singlecard/ops/test_rotary_embedding.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,7 @@ def test_rotary_embedding_quant_with_leading_dim(
198198
atol=DEFAULT_ATOL,
199199
rtol=DEFAULT_RTOL)
200200

201+
201202
# test rope with npu_mrope interface with leading dimension and merge seqlen and batch_size as num_tokens
202203
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
203204
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@@ -244,14 +245,13 @@ def test_npu_mrope_quant_with_leading_dim(
244245
ref_query, ref_key = rope.forward_native(positions, query, key)
245246

246247
query, key = torch_npu.npu_mrope(
247-
positions,
248-
query,
249-
key,
250-
rope.cos_sin_cache,
251-
rope.head_size,
252-
mrope_section=[0,0,0],
253-
rotary_mode='half' if rope.is_neox_style else 'interleave'
254-
)
248+
positions,
249+
query,
250+
key,
251+
rope.cos_sin_cache,
252+
rope.head_size,
253+
mrope_section=[0, 0, 0],
254+
rotary_mode='half' if rope.is_neox_style else 'interleave')
255255

256256
# Compare the results.
257257
torch.testing.assert_close(query.view(ref_query.size()),
@@ -261,4 +261,4 @@ def test_npu_mrope_quant_with_leading_dim(
261261
torch.testing.assert_close(key.view(ref_key.size()),
262262
ref_key,
263263
atol=DEFAULT_ATOL,
264-
rtol=DEFAULT_RTOL)
264+
rtol=DEFAULT_RTOL)

vllm_ascend/ops/rotary_embedding.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,8 @@ def rope_forward_oot(
7070
key,
7171
self.cos_sin_cache,
7272
self.head_size,
73-
mrope_section=[0,0,0],
74-
rotary_mode='half' if neox_style else 'interleave'
75-
)
73+
mrope_section=[0, 0, 0],
74+
rotary_mode='half' if neox_style else 'interleave')
7675
return query.view(query_shape), key.view(key_shape)
7776

7877

0 commit comments

Comments
 (0)