Commit 7058853

fix lint

Signed-off-by: lidenghui <lidenghui1110@gmail.com>

1 parent c848786

4 files changed: +24 -21 lines

vllm_ascend/attention/attention_v1.py

Lines changed: 2 additions & 2 deletions

@@ -274,8 +274,8 @@ def forward(
         shape = [batch_size * seq_len, num_heads, head_size]
         """
         num_tokens = query.shape[0]
-        use_kv_cache_int8 = len(kv_cache
-                                ) > 0 and kv_cache[0].dtype == torch.int8
+        use_kv_cache_int8 = len(
+            kv_cache) > 0 and kv_cache[0].dtype == torch.int8
         if output is None:
             output = torch.empty(num_tokens,
                                  self.num_heads,
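
This is a pure yapf reflow: the break now falls right after len( instead of before the closing parenthesis, and the predicate is unchanged. A minimal runnable sketch of that check, using a hypothetical single-tensor cache tuple (the shape is illustrative, not vllm-ascend's actual KV layout):

    import torch

    # Hypothetical int8-quantized KV cache for one layer; shape is made up.
    kv_cache = (torch.zeros(4, 128, 8, 64, dtype=torch.int8), )

    # Same predicate as the diff: non-empty cache whose first tensor is int8.
    use_kv_cache_int8 = len(
        kv_cache) > 0 and kv_cache[0].dtype == torch.int8
    print(use_kv_cache_int8)  # True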

vllm_ascend/attention/mla_v1.py

Lines changed: 1 addition & 2 deletions

@@ -1020,8 +1020,7 @@ def _forward_decode(
             num_kv_heads=self.num_kv_heads,
             num_heads=self.num_heads,
             scale_value=self.scale,
-            block_table=attn_metadata.decode.
-            block_table,  # type:ignore
+            block_table=attn_metadata.decode.block_table,  # type:ignore
             context_lens=attn_metadata.decode.seq_lens,  # type:ignore
             mla_vheadsize=self.kv_lora_rank,
             out=attn_output)
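
Joining the attribute chain is cosmetic: inside a call's parentheses Python lets a dotted chain break after the dot, so both spellings resolve identically, and the joined form keeps the # type:ignore comment on one logical line. A tiny sketch with stand-in objects (names are illustrative, not the real metadata classes):

    from types import SimpleNamespace

    # Stand-ins for the real decode metadata; fields are illustrative only.
    attn_metadata = SimpleNamespace(
        decode=SimpleNamespace(block_table=[[0, 1]], seq_lens=[2]))

    # Both forms resolve to the same attribute.
    split = (attn_metadata.decode.
             block_table)
    joined = attn_metadata.decode.block_table
    assert split is joined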

vllm_ascend/ops/attention.py

Lines changed: 3 additions & 3 deletions

@@ -138,7 +138,8 @@ def vanilla_chunked_prefill(
 def vanilla_chunked_prefill_mla(
         output: torch.Tensor,  # (num_tokens, num_heads, v_head_dim)
         query: torch.Tensor,  # (num_tokens, num_heads, nope_dim + rope_dim)
-        kv_c_and_k_pe_cache: tuple[torch.Tensor],  # (num_blocks, block_size, latent_kv/rope_dim)
+        kv_c_and_k_pe_cache: tuple[
+            torch.Tensor],  # (num_blocks, block_size, latent_kv/rope_dim)
         block_tables: torch.Tensor,  # (batch_size, max_num_blocks_per_seq)
         query_lens: torch.Tensor,  # (batch_size)
         context_lens: torch.Tensor,  # (batch_size)

@@ -156,10 +157,9 @@ def vanilla_chunked_prefill_mla(
     num_heads = query.size(1)
     cache_kv_c = kv_c_and_k_pe_cache[0].squeeze()
     cache_k_pe = kv_c_and_k_pe_cache[1].squeeze()
-
+
     # cached_kv_c: [batch_size, max_context_len, latent_kv]
     # cached_k_pe: [batch_size, max_context_len, rope_dim]
-    batch_size = query_lens.size(0)
     block_size, latent_kv_dim = cache_kv_c.size(1), cache_kv_c.size(-1)
     max_num_blocks_per_seq = block_tables.size(1)
     cache_kv_c = cache_kv_c[block_tables].view(
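
Besides the signature reflow, the second hunk drops the batch_size = query_lens.size(0) binding (presumably flagged as unused by the linter). The diff cuts off at the block-table gather on its last line; a minimal sketch of that indexing pattern under assumed toy shapes (names reused from the signature above, sizes made up):

    import torch

    # Toy cache: (num_blocks, block_size, latent_kv_dim).
    num_blocks, block_size, latent_kv_dim = 8, 4, 16
    cache_kv_c = torch.randn(num_blocks, block_size, latent_kv_dim)

    # Per-sequence block tables: (batch_size, max_num_blocks_per_seq).
    block_tables = torch.tensor([[0, 1], [2, 3]])
    batch_size, max_num_blocks_per_seq = block_tables.shape

    # Advanced indexing gathers each sequence's blocks, then view flattens
    # the block dimension into a per-sequence context axis.
    gathered = cache_kv_c[block_tables].view(
        batch_size, max_num_blocks_per_seq * block_size, latent_kv_dim)
    print(gathered.shape)  # torch.Size([2, 8, 16])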

vllm_ascend/worker/model_runner_v1.py

Lines changed: 18 additions & 14 deletions

@@ -2138,19 +2138,23 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
                     nope_dim = self.model_config.hf_text_config.kv_lora_rank
                     assert head_dim == rope_dim + nope_dim, \
                         f"head_dim({head_dim}) != rope_dim({rope_dim}) + nope_dim({nope_dim})"
-                    nope_cache_shape = (num_blocks, block_size,
-                                        num_kv_heads, nope_dim)
-                    rope_cache_shape = (num_blocks, block_size,
-                                        num_kv_heads, rope_dim)
-                    nope_cache = torch.zeros(
-                        nope_cache_shape, dtype=dtype, device=self.device,
-                        pin_memory=True)
-                    rope_cache = torch.zeros(
-                        rope_cache_shape, dtype=dtype, device=self.device,
-                        pin_memory=True)
+                    layer_kv_cache_nope_shape = (num_blocks, block_size,
+                                                 num_kv_heads, nope_dim)
+                    layer_kv_cache_pe_shape = (num_blocks, block_size,
+                                               num_kv_heads, rope_dim)
+                    layer_kv_cache_nope = torch.zeros(
+                        layer_kv_cache_nope_shape,
+                        dtype=dtype,
+                        device=self.device)
+                    layer_kv_cache_pe = torch.zeros(
+                        layer_kv_cache_pe_shape,
+                        dtype=dtype,
+                        device=self.device)
                     kv_caches[layer_name] = (
-                        torch_npu.npu_format_cast(nope_cache, acl_format),
-                        torch_npu.npu_format_cast(rope_cache, acl_format),
+                        torch_npu.npu_format_cast(layer_kv_cache_nope,
+                                                  acl_format),
+                        torch_npu.npu_format_cast(layer_kv_cache_pe,
+                                                  acl_format),
                     )
                 else:
                     num_caches = kv_cache_shape[0]

@@ -2160,8 +2164,8 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
                         kv_cache = torch.zeros(cache_shape,
                                                dtype=dtype,
                                                device=self.device)
-                        kv_cache = torch_npu.npu_format_cast(kv_cache,
-                                                             acl_format)
+                        kv_cache = torch_npu.npu_format_cast(
+                            kv_cache, acl_format)
                         kv_cache_list.append(kv_cache)
                     kv_caches[layer_name] = kv_cache_list
                 else:
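
Beyond the reflow, the first hunk renames nope_cache/rope_cache to layer_kv_cache_nope/layer_kv_cache_pe and drops pin_memory=True from the device allocations. A hedged, CPU-runnable sketch of the two-part MLA cache layout with made-up sizes (the layer key is hypothetical, and torch_npu.npu_format_cast is NPU-only, so it is skipped here):

    import torch

    # Illustrative sizes; nope_dim/rope_dim mirror a DeepSeek-style MLA
    # split (kv_lora_rank and rope head dim), not values from a real config.
    num_blocks, block_size, num_kv_heads = 16, 128, 1
    nope_dim, rope_dim = 512, 64
    dtype = torch.float16

    layer_kv_cache_nope = torch.zeros(
        (num_blocks, block_size, num_kv_heads, nope_dim), dtype=dtype)
    layer_kv_cache_pe = torch.zeros(
        (num_blocks, block_size, num_kv_heads, rope_dim), dtype=dtype)

    # Each MLA layer maps to a (nope, pe) pair; the key below is made up.
    kv_caches = {
        "model.layers.0.self_attn.attn":
        (layer_kv_cache_nope, layer_kv_cache_pe),
    }
    assert kv_caches["model.layers.0.self_attn.attn"][0].shape[-1] == nope_dim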
