
Commit dc6172e

update attention nz and mla nz (improve TPOP performance by 6 ms) (#909)

authored by ttanzhiqiang

### What this PR does / why we need it?

Update the attention NZ and MLA NZ modules to improve TPOP performance by 6 ms:

- Convert `W_UV` and `W_UK_T` to NPU format in `mla_v1.py`.
- Convert `layer.weight` to NPU format in `w8a8.py`.

Signed-off-by: ttanzhiqiang <389825161@qq.com>
1 parent 7153d88 commit dc6172e

File tree: 2 files changed (+5, -2 lines)


vllm_ascend/attention/mla_v1.py (4 additions, 2 deletions)

```diff
@@ -476,9 +476,11 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
             [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
 
         # Convert from (L, N, V) to (N, L, V)
-        self.W_UV = W_UV.transpose(0, 1)
+        self.W_UV = W_UV.transpose(0, 1).contiguous()
         # Convert from (L, N, P) to (N, P, L)
-        self.W_UK_T = W_UK.permute(1, 2, 0)
+        self.W_UK_T = W_UK.permute(1, 2, 0).contiguous()
+        self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29)
+        self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29)
 
     def _forward_prefill(
         self,
```
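For context, format id 29 is ACL_FORMAT_FRACTAL_NZ, the blocked weight layout Ascend's cube unit consumes. Below is a minimal sketch of the pattern this hunk applies, assuming an Ascend device with `torch` and `torch_npu` installed; the sizes are placeholders, not the model's real MLA dimensions:

```python
import torch
import torch_npu

ACL_FORMAT_FRACTAL_NZ = 29  # the format id passed to npu_format_cast in the diff

# Placeholder sizes standing in for the MLA weight dimensions (L, N, P, V).
L, N, P, V = 512, 16, 128, 128
W_UV = torch.randn(L, N, V, dtype=torch.float16).npu()
W_UK = torch.randn(L, N, P, dtype=torch.float16).npu()

# transpose()/permute() return non-contiguous views; the .contiguous() calls
# added in the diff materialize them before the format cast.
W_UV_nz = torch_npu.npu_format_cast(
    W_UV.transpose(0, 1).contiguous(), ACL_FORMAT_FRACTAL_NZ)   # (N, L, V)
W_UK_T_nz = torch_npu.npu_format_cast(
    W_UK.permute(1, 2, 0).contiguous(), ACL_FORMAT_FRACTAL_NZ)  # (N, P, L)
```

Casting once during weight post-processing lets the decode-path matmuls read the NZ layout directly instead of converting on every step, which is presumably where the TPOP saving comes from.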

vllm_ascend/quantization/w8a8.py (1 addition, 0 deletions)

```diff
@@ -110,5 +110,6 @@ def process_weights_after_loading(self, layer):
             requires_grad=False).to(layer.aclnn_input_scale.dtype)
         if self.transpose_weight:
             layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
+        layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)
         layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
         layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
```
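A small verification sketch for the w8a8 change; it assumes `torch_npu.get_npu_format` is available to report a tensor's internal ACL format id:

```python
import torch
import torch_npu

# Mirror the post-load steps from the hunk on a stand-in int8 weight.
weight = torch.randint(-128, 128, (512, 1024), dtype=torch.int8).npu()
weight = weight.transpose(0, 1).contiguous()    # the transpose_weight branch
weight = torch_npu.npu_format_cast(weight, 29)  # 29 == ACL_FORMAT_FRACTAL_NZ

# get_npu_format (assumed API) returns the internal ACL format id.
assert torch_npu.get_npu_format(weight) == 29
```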
