
Commit 861701d

fix lint
Signed-off-by: lidenghui <lidenghui1110@gmail.com>
1 parent 64a6343 commit 861701d

File tree

4 files changed: 24 additions & 21 deletions


vllm_ascend/attention/attention_v1.py

Lines changed: 2 additions & 2 deletions
@@ -274,8 +274,8 @@ def forward(
             shape = [batch_size * seq_len, num_heads, head_size]
         """
         num_tokens = query.shape[0]
-        use_kv_cache_int8 = len(kv_cache
-                                ) > 0 and kv_cache[0].dtype == torch.int8
+        use_kv_cache_int8 = len(
+            kv_cache) > 0 and kv_cache[0].dtype == torch.int8
         if output is None:
             output = torch.empty(num_tokens,
                                  self.num_heads,
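
The rewrapped condition above decides whether the attention path treats the KV cache as int8-quantized. A minimal standalone sketch of that check; the helper name use_int8_kv_cache is hypothetical, only the condition itself comes from the diff:

import torch

def use_int8_kv_cache(kv_cache: tuple) -> bool:
    # Hypothetical helper: the cache counts as int8-quantized only when the
    # tuple is non-empty and its first tensor is stored as torch.int8.
    return len(kv_cache) > 0 and kv_cache[0].dtype == torch.int8

# An empty cache (e.g. during a profiling run) is not quantized.
assert not use_int8_kv_cache(())
assert use_int8_kv_cache((torch.zeros(2, 4, dtype=torch.int8),))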

vllm_ascend/attention/mla_v1.py

Lines changed: 1 addition & 2 deletions
@@ -1022,8 +1022,7 @@ def _forward_decode(
            num_kv_heads=self.num_kv_heads,
            num_heads=self.num_heads,
            scale_value=self.scale,
-           block_table=attn_metadata.decode.
-           block_table,  # type:ignore
+           block_table=attn_metadata.decode.block_table,  # type:ignore
            context_lens=attn_metadata.decode.seq_lens,  # type:ignore
            mla_vheadsize=self.kv_lora_rank,
            out=attn_output)

vllm_ascend/ops/attention.py

Lines changed: 3 additions & 3 deletions
@@ -138,7 +138,8 @@ def vanilla_chunked_prefill(
 def vanilla_chunked_prefill_mla(
         output: torch.Tensor,  # (num_tokens, num_heads, v_head_dim)
         query: torch.Tensor,  # (num_tokens, num_heads, nope_dim + rope_dim)
-        kv_c_and_k_pe_cache: tuple[torch.Tensor],  # (num_blocks, block_size, latent_kv/rope_dim)
+        kv_c_and_k_pe_cache: tuple[
+            torch.Tensor],  # (num_blocks, block_size, latent_kv/rope_dim)
         block_tables: torch.Tensor,  # (batch_size, max_num_blocks_per_seq)
         query_lens: torch.Tensor,  # (batch_size)
         context_lens: torch.Tensor,  # (batch_size)
@@ -156,10 +157,9 @@ def vanilla_chunked_prefill_mla(
     num_heads = query.size(1)
     cache_kv_c = kv_c_and_k_pe_cache[0].squeeze()
     cache_k_pe = kv_c_and_k_pe_cache[1].squeeze()
-
+
     # cached_kv_c: [batch_size, max_context_len, latent_kv]
     # cached_k_pe: [batch_size, max_context_len, rope_dim]
-    batch_size = query_lens.size(0)
     block_size, latent_kv_dim = cache_kv_c.size(1), cache_kv_c.size(-1)
     max_num_blocks_per_seq = block_tables.size(1)
     cache_kv_c = cache_kv_c[block_tables].view(
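
The last context line above gathers each sequence's paged KV blocks via block_tables before the chunked prefill. A standalone sketch of that indexing pattern with made-up sizes; all concrete numbers here are illustrative, only the gather-and-view pattern comes from the diff:

import torch

num_blocks, block_size, latent_kv_dim = 8, 4, 16  # illustrative sizes
batch_size, max_num_blocks_per_seq = 2, 3

cache_kv_c = torch.randn(num_blocks, block_size, latent_kv_dim)
block_tables = torch.randint(0, num_blocks,
                             (batch_size, max_num_blocks_per_seq))

# Indexing the paged cache with the per-sequence block table gathers each
# sequence's blocks; the view flattens them into one contiguous context of
# shape (batch_size, max_num_blocks_per_seq * block_size, latent_kv_dim).
gathered = cache_kv_c[block_tables].view(batch_size, -1, latent_kv_dim)
assert gathered.shape == (batch_size,
                          max_num_blocks_per_seq * block_size,
                          latent_kv_dim)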

vllm_ascend/worker/model_runner_v1.py

Lines changed: 18 additions & 14 deletions
@@ -1961,19 +1961,23 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
                     nope_dim = self.model_config.hf_text_config.kv_lora_rank
                     assert head_dim == rope_dim + nope_dim, \
                         f"head_dim({head_dim}) != rope_dim({rope_dim}) + nope_dim({nope_dim})"
-                    nope_cache_shape = (num_blocks, block_size,
-                                        num_kv_heads, nope_dim)
-                    rope_cache_shape = (num_blocks, block_size,
-                                        num_kv_heads, rope_dim)
-                    nope_cache = torch.zeros(
-                        nope_cache_shape, dtype=dtype, device=self.device,
-                        pin_memory=True)
-                    rope_cache = torch.zeros(
-                        rope_cache_shape, dtype=dtype, device=self.device,
-                        pin_memory=True)
+                    layer_kv_cache_nope_shape = (num_blocks, block_size,
+                                                 num_kv_heads, nope_dim)
+                    layer_kv_cache_pe_shape = (num_blocks, block_size,
+                                               num_kv_heads, rope_dim)
+                    layer_kv_cache_nope = torch.zeros(
+                        layer_kv_cache_nope_shape,
+                        dtype=dtype,
+                        device=self.device)
+                    layer_kv_cache_pe = torch.zeros(
+                        layer_kv_cache_pe_shape,
+                        dtype=dtype,
+                        device=self.device)
                     kv_caches[layer_name] = (
-                        torch_npu.npu_format_cast(nope_cache, acl_format),
-                        torch_npu.npu_format_cast(rope_cache, acl_format),
+                        torch_npu.npu_format_cast(layer_kv_cache_nope,
+                                                  acl_format),
+                        torch_npu.npu_format_cast(layer_kv_cache_pe,
+                                                  acl_format),
                     )
 >>>>>>> c848786 (use tuple as kv cache instead of tensor)
                 else:
@@ -1984,8 +1988,8 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
                     kv_cache = torch.zeros(cache_shape,
                                            dtype=dtype,
                                            device=self.device)
-                    kv_cache = torch_npu.npu_format_cast(kv_cache,
-                                                         acl_format)
+                    kv_cache = torch_npu.npu_format_cast(
+                        kv_cache, acl_format)
                     kv_cache_list.append(kv_cache)
                     kv_caches[layer_name] = kv_cache_list
                 else:
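
For the MLA path, the runner above allocates two tensors per layer, one for the latent ("nope") part and one for the rotary ("pe") part, and stores the pair as a tuple. A CPU-only sketch with illustrative sizes and a hypothetical layer key; the real code runs on an Ascend device and additionally applies torch_npu.npu_format_cast, as the diff shows:

import torch

num_blocks, block_size, num_kv_heads = 16, 128, 1  # illustrative sizes
nope_dim, rope_dim = 512, 64  # stand-ins for kv_lora_rank / rope dims
dtype, device = torch.bfloat16, "cpu"

layer_kv_cache_nope = torch.zeros(
    (num_blocks, block_size, num_kv_heads, nope_dim),
    dtype=dtype, device=device)
layer_kv_cache_pe = torch.zeros(
    (num_blocks, block_size, num_kv_heads, rope_dim),
    dtype=dtype, device=device)

# One tuple per layer: (nope part, rope part); the key name is made up.
kv_caches = {"model.layers.0.attn": (layer_kv_cache_nope, layer_kv_cache_pe)}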

0 commit comments
