
Commit b308a7a

support pangumoe w8a8c8 and docs (#1477)
### What this PR does / why we need it?
Support Pangu MoE w8a8c8.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI passed with the newly added test.

Signed-off-by: zhuyilin <809721801@qq.com>
1 parent c59d69d commit b308a7a

8 files changed: +689 additions, −50 deletions


docs/source/user_guide/additional_config.md

Lines changed: 1 addition & 0 deletions
@@ -32,6 +32,7 @@ The following table lists the additional configuration options available in vLLM
 | `refresh` | bool | `false` | Whether to refresh global ascend config content. This value is usually used by rlhf or ut/e2e test case. |
 | `expert_map_path` | str | `None` | When using expert load balancing for the MOE model, an expert map path needs to be passed in. |
 | `chunked_prefill_for_mla` | bool | `False` | Whether to enable the fused operator-like chunked_prefill. |
+| `kv_cache_dtype` | str | `None` | When using the kv cache quantization method, kv cache dtype needs to be set, currently only int8 is supported. |
 
 The details of each config option are as follows:
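
For orientation (not part of the diff), a minimal usage sketch of the new option from Python. Only `additional_config={"kv_cache_dtype": "int8"}` comes from this change; the model path and the `quantization="ascend"` argument are illustrative assumptions.

```python
# Illustrative sketch only: enable the int8 kv cache path via additional_config.
# The model path and quantization flag below are assumptions, not from this commit.
from vllm import LLM

llm = LLM(
    model="/path/to/pangu-moe-w8a8c8",             # hypothetical quantized checkpoint
    quantization="ascend",                          # assumed Ascend quantization backend
    additional_config={"kv_cache_dtype": "int8"},   # option documented above
)
print(llm.generate("Hello, my name is")[0].outputs[0].text)
```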

vllm_ascend/attention/attention_v1.py

Lines changed: 19 additions & 4 deletions
@@ -69,6 +69,15 @@ def get_kv_cache_shape(
                 16)
         return (2, num_blocks, block_size, num_kv_heads, head_size)
 
+    @staticmethod
+    def get_bsh_kv_cache_shape(
+        num_blocks: int,
+        block_size: int,
+        num_kv_heads: int,
+        head_size: int,
+    ) -> Tuple[int, ...]:
+        return (2, num_blocks, block_size, num_kv_heads * head_size)
+
     @staticmethod
     def swap_blocks(
         src_kv_cache: List[torch.Tensor],
@@ -279,6 +288,13 @@ def forward(
                 value=value,
                 output=output,
                 layer_name=layer.layer_name)
+
+        elif hasattr(layer, 'quant_method'):
+            output = layer.quant_method.apply(layer, query, key, value,
+                                              kv_cache, attn_metadata,
+                                              self.attn_type, self.scale,
+                                              output)
+
         else:
             if attn_metadata is None:
                 return output.view(num_tokens, self.hidden_size)
@@ -308,11 +324,8 @@ def forward(
                                          value_cache=self.value_cache,
                                          slot_indices=slots)
 
-            if hasattr(layer, 'quant_method'):
-                # TODO: Add attr (num_prefills, prefill_metadata, decode_metadata) to AscendMetadata
-                pass
             # V0-Style scheduler situation.
-            elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
+            if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
                 assert attn_metadata is not None
                 assert attn_metadata.attn_mask is not None
                 mask = attn_metadata.attn_mask
@@ -414,6 +427,8 @@ def forward(
                     out=output)
 
         # to make in-place change to the output tensor
+        if hasattr(layer, 'quant_method'):
+            output = output.view(num_tokens, self.num_heads, self.head_size)
         ori_output[:, :, :] = output[:num_tokens, :, :]
         return output.view(num_tokens, self.hidden_size)
 
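
As a side note (not from the commit), the new BSH-style cache shape simply fuses the head and head-size dimensions of the existing layout. A quick comparison with made-up numbers:

```python
# Made-up configuration, for illustration only.
num_blocks, block_size, num_kv_heads, head_size = 1024, 128, 8, 128

# Existing layout from get_kv_cache_shape (key and value stacked on dim 0).
bnsd_shape = (2, num_blocks, block_size, num_kv_heads, head_size)
# New layout from get_bsh_kv_cache_shape: heads and head_size fused into one dim.
bsh_shape = (2, num_blocks, block_size, num_kv_heads * head_size)

# Both layouts store the same number of elements per cached token.
assert bsh_shape[-1] == bnsd_shape[-2] * bnsd_shape[-1]
```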

vllm_ascend/models/pangu_moe.py

Lines changed: 48 additions & 1 deletion
@@ -505,7 +505,7 @@ def forward(
         # native FusedMoE. here we need to design a better FusedMoE
         # (maybe using AscendFusedMoE) to enable these different
         # communication schema.
-        final_hidden_states = self.experts.quant_method(
+        final_hidden_states = self.experts.quant_method.apply(
             layer=self.experts,
             x=hidden_states,
             router_logits=router_logits,
@@ -937,6 +937,8 @@ def sample(
         return next_tokens
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        tp_size = get_tp_group().world_size
+        tp_rank = get_tp_group().rank_in_group
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -972,6 +974,51 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             if "module" in name:
                 continue
 
+            if name.endswith('kv_cache_offset'):
+                continue
+
+            if name.endswith("k_proj.kv_cache_scale"):
+                remapped_kv_scale_name = name.replace(
+                    "k_proj.kv_cache_scale", "attn.key_antiquant_scale")
+                if remapped_kv_scale_name not in params_dict:
+                    logger.warning_once(
+                        "Found kv scale in the checkpoint "
+                        f"(e.g. {name}), but not found the expected "
+                        f"name in the model "
+                        f"(e.g. {remapped_kv_scale_name}). "
+                        "kv-scale is not loaded.")
+                    continue
+                else:
+                    name = remapped_kv_scale_name
+                param = params_dict[name]
+                loaded_weight = torch.tensor_split(loaded_weight,
+                                                   tp_size,
+                                                   dim=0)[tp_rank]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+            if name.endswith("v_proj.kv_cache_scale"):
+                remapped_kv_scale_name = name.replace(
+                    "v_proj.kv_cache_scale", "attn.value_antiquant_scale")
+                if remapped_kv_scale_name not in params_dict:
+                    logger.warning_once(
+                        "Found kv scale in the checkpoint "
+                        f"(e.g. {name}), but not found the expected "
+                        f"name in the model "
+                        f"(e.g. {remapped_kv_scale_name}). "
+                        "kv-scale is not loaded.")
+                    continue
+                else:
+                    name = remapped_kv_scale_name
+                param = params_dict[name]
+                loaded_weight = torch.tensor_split(loaded_weight,
+                                                   tp_size,
+                                                   dim=0)[tp_rank]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 # Skip non-stacked layers and experts (experts handled below).
                 if weight_name not in name:
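
To make the kv-scale handling easier to follow, here is a condensed, illustrative restatement of the two remapping branches above as a single table-driven helper. This is a sketch, not code from the commit; `params_dict`, `tp_size`, `tp_rank`, and `default_weight_loader` are assumed to exist as they do inside `load_weights`.

```python
# Sketch only: equivalent of the k_proj/v_proj kv_cache_scale branches above.
import torch

_KV_SCALE_REMAP = {
    "k_proj.kv_cache_scale": "attn.key_antiquant_scale",
    "v_proj.kv_cache_scale": "attn.value_antiquant_scale",
}


def load_kv_cache_scale(name, loaded_weight, params_dict, tp_size, tp_rank,
                        default_weight_loader):
    """Return True if `name` was a kv-cache scale and has been handled."""
    for ckpt_suffix, model_suffix in _KV_SCALE_REMAP.items():
        if not name.endswith(ckpt_suffix):
            continue
        remapped = name.replace(ckpt_suffix, model_suffix)
        if remapped not in params_dict:
            # Scale exists in the checkpoint but the model has no matching
            # parameter; skip it (the diff logs a warning at this point).
            return True
        param = params_dict[remapped]
        # Each tensor-parallel rank loads only its slice of the scales.
        shard = torch.tensor_split(loaded_weight, tp_size, dim=0)[tp_rank]
        weight_loader = getattr(param, "weight_loader", default_weight_loader)
        weight_loader(param, shard)
        return True
    return False
```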

vllm_ascend/platform.py

Lines changed: 4 additions & 0 deletions
@@ -124,6 +124,10 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         model_config = vllm_config.model_config
         parallel_config = vllm_config.parallel_config
         cache_config = vllm_config.cache_config
+        kv_cache_dtype = vllm_config.additional_config.get(
+            "kv_cache_dtype", None)
+        if kv_cache_dtype is not None:
+            vllm_config.cache_config.cache_dtype = kv_cache_dtype
 
         if parallel_config:
             # Default value for expert tensor parallel size
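
Stripped of the surrounding config objects, the override above reduces to the small rule sketched below. The helper name is invented for illustration, and `"auto"` is assumed as the default `cache_dtype`.

```python
# Illustrative only: the effective cache_dtype resolution performed above.
def resolve_cache_dtype(additional_config, default="auto"):
    kv_cache_dtype = (additional_config or {}).get("kv_cache_dtype", None)
    return kv_cache_dtype if kv_cache_dtype is not None else default


assert resolve_cache_dtype({"kv_cache_dtype": "int8"}) == "int8"
assert resolve_cache_dtype({}) == "auto"
assert resolve_cache_dtype(None) == "auto"
```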

vllm_ascend/quantization/quant_config.py

Lines changed: 8 additions & 26 deletions
@@ -98,6 +98,9 @@ def get_quant_method(self, layer: torch.nn.Module,
                 'fa_quant_type' in self.quant_description.keys() and \
                 self.quant_description['fa_quant_type'] is not None:
             return AscendKVCacheMethod(self, prefix)
+        elif isinstance(layer, Attention) and self.quant_description.get(
+                'kv_quant_type') == 'C8':
+            return AscendKVCacheMethod(self, prefix)
         elif isinstance(layer, FusedMoE):
             if self.is_layer_skipped_ascend(prefix,
                                             self.packed_modules_mapping):
@@ -235,32 +238,11 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         if hasattr(self.quant_method, "process_weights_after_loading"):
             self.quant_method.process_weights_after_loading(layer)
 
-    def apply(self,
-              layer: torch.nn.Module,
-              query: torch.Tensor,
-              key: torch.Tensor,
-              value: torch.Tensor,
-              k_cache: List[torch.Tensor],
-              v_cache: List[torch.Tensor],
-              scale: torch.Tensor,
-              block_tables: torch.Tensor,
-              isPrefill: bool,
-              attn_metadata,
-              output,
-              seq_lens_tensor_cpu: Optional[int] = None) -> torch.Tensor:
-        return self.quant_method.apply(layer,
-                                       query,
-                                       key,
-                                       value,
-                                       k_cache,
-                                       v_cache,
-                                       scale,
-                                       block_tables,
-                                       isPrefill,
-                                       attn_metadata.attn_mask,
-                                       attn_metadata.slot_mapping,
-                                       output,
-                                       seq_lens_tensor_cpu=seq_lens_tensor_cpu)
+    def apply(self, layer: torch.nn.Module, query: torch.Tensor,
+              key: torch.Tensor, value: torch.Tensor, kv_cache, attn_metadata,
+              attn_type, scale, output) -> torch.Tensor:
+        return self.quant_method.apply(layer, query, key, value, kv_cache,
+                                       attn_metadata, attn_type, scale, output)
 
 
 class AscendFusedMoEMethod(FusedMoEMethodBase):

vllm_ascend/quantization/quantizer.py

Lines changed: 13 additions & 1 deletion
@@ -24,7 +24,8 @@
 
 from .func_wrapper import (wrapper_load_model, wrapper_rmsnorm_forward_oot,
                            wrapper_rmsnorm_init)
-from .w8a8 import AscendW8A8LinearMethod
+from .w8a8 import (AscendC8KVCacheMethod, AscendW8A8FusedMoEMethod,
+                   AscendW8A8LinearMethod)
 from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod,
                            AscendW8A8DynamicLinearMethod)
 
@@ -250,6 +251,8 @@ def get_quantizer(cls,
         # Attention
         if '.attn' in prefix and 'fa_quant_type' in quant_description.keys():
             quant_type = quant_description['fa_quant_type']
+        if '.attn' in prefix and 'kv_quant_type' in quant_description.keys():
+            quant_type = quant_description['kv_quant_type']
         # Linear
         else:
             quant_type = cls.get_linear_quant_type(quant_description, prefix,
@@ -269,6 +272,14 @@ class W8A8Quantizer(VLLMAscendQuantizer):
     def build_linear_method():
         return AscendW8A8LinearMethod()
 
+    @staticmethod
+    def build_moe_method():
+        return AscendW8A8FusedMoEMethod()
+
+    @staticmethod
+    def build_attention_method():
+        return AscendC8KVCacheMethod()
+
 
 class W8A8DYNAMICQuantizer(VLLMAscendQuantizer):
 
@@ -284,4 +295,5 @@ def build_moe_method():
 SUPPORT_ASCEND_QUANTIZER_TYPE = {
     "W8A8": W8A8Quantizer,
     "W8A8_DYNAMIC": W8A8DYNAMICQuantizer,
+    "C8": W8A8Quantizer,
 }
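
To summarize how the pieces fit together, the routing added here can be restated as the small standalone sketch below. Names are copied from the diff, but the helper function and the string values in the mapping are invented for illustration: an attention-layer prefix plus `kv_quant_type: "C8"` in the checkpoint's quant description selects the kv-cache method, and `"C8"` resolves to the same `W8A8Quantizer`.

```python
# Illustrative restatement of the dispatch logic above; not library code.
SUPPORT_ASCEND_QUANTIZER_TYPE = {
    "W8A8": "W8A8Quantizer",
    "W8A8_DYNAMIC": "W8A8DYNAMICQuantizer",
    "C8": "W8A8Quantizer",
}


def resolve_attn_quant_type(prefix: str, quant_description: dict):
    """Mirror of the attention branch in get_quantizer (simplified)."""
    if '.attn' in prefix and 'fa_quant_type' in quant_description:
        return quant_description['fa_quant_type']
    if '.attn' in prefix and 'kv_quant_type' in quant_description:
        return quant_description['kv_quant_type']
    return None


quant_type = resolve_attn_quant_type("model.layers.0.self_attn.attn",
                                     {"kv_quant_type": "C8"})
assert quant_type == "C8"
assert SUPPORT_ASCEND_QUANTIZER_TYPE[quant_type] == "W8A8Quantizer"
```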
