|
 25 |  25 | # # vllm-project/vllm/vllm/model_executor/models/deepseek_v2.py
 26 |  26 | # """Inference-only DeepseekV2/DeepseekV3 model."""
 27 |  27 |
 28 |     | -from typing import Any, Callable, Dict, List, Optional, Tuple, Union
     |  28 | +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
 29 |  29 |
 30 |  30 | import torch
 31 |  31 | import torch_npu

 55 |  55 | from vllm.model_executor.layers.sampler import get_sampler
 56 |  56 | from vllm.model_executor.layers.vocab_parallel_embedding import (
 57 |  57 |     ParallelLMHead, VocabParallelEmbedding)
     |  58 | +from vllm.model_executor.model_loader.weight_utils import (
     |  59 | +    default_weight_loader, maybe_remap_kv_scale_name)
 58 |  60 | from vllm.model_executor.models.deepseek_v2 import \
 59 |  61 |     DeepseekV2ForCausalLM  # noqa: E501
 60 |  62 | from vllm.model_executor.models.deepseek_v2 import \
 61 |  63 |     yarn_get_mscale  # noqa: E501
 62 |     | -from vllm.model_executor.models.deepseek_v2 import (DeepseekV2Attention,
 63 |     | -                                                     DeepseekV2DecoderLayer,
 64 |     | -                                                     DeepseekV2MLAAttention)
     |  64 | +from vllm.model_executor.models.deepseek_v2 import (
     |  65 | +    DeepseekV2Attention, DeepseekV2DecoderLayer, DeepseekV2MLAAttention,
     |  66 | +    get_spec_layer_idx_from_weight_name)
 65 |  67 | from vllm.model_executor.models.utils import (
 66 |     | -    PPMissingLayer, make_empty_intermediate_tensors_factory, make_layers,
 67 |     | -    maybe_prefix)
     |  68 | +    PPMissingLayer, is_pp_missing_parameter,
     |  69 | +    make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
 68 |  70 | from vllm.sequence import IntermediateTensors
 69 |  71 |
 70 |  72 | from vllm_ascend.ascend_config import get_ascend_config

 73 |  75 | from vllm_ascend.quantization.quant_config import AscendLinearMethod
 74 |  76 | from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
 75 |  77 | from vllm_ascend.utils import (dispose_tensor, npu_stream_switch,
 76 |     | -                               npu_wait_tensor)
     |  78 | +                               npu_wait_tensor, vllm_version_is)
 77 |  79 |
 78 |  80 |
 79 |  81 | class CustomDeepseekV2SiluAndMul(SiluAndMul):
|
@@ -867,6 +869,107 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
867 | 869 |         self.make_empty_intermediate_tensors = (
868 | 870 |             self.model.make_empty_intermediate_tensors)
869 | 871 |
    | 872 | +    # NOTE: This `load_weights` is mainly copied from
    | 873 | +    # https://github.com/vllm-project/vllm/commit/07b8fae219b1fff51ef115c38c44b51395be5bb5
    | 874 | +    # to fix CI, and it is different from the implementation in main
    | 875 | +    # TODO: support eplb style load_weights
    | 876 | +    def load_weights(self, weights: Iterable[tuple[str,
    | 877 | +                                                    torch.Tensor]]) -> set[str]:
    | 878 | +        """"""
    | 879 | +        stacked_params_mapping = [
    | 880 | +            # (param_name, shard_name, shard_id)
    | 881 | +            ("gate_up_proj", "gate_proj", 0),
    | 882 | +            ("gate_up_proj", "up_proj", 1),
    | 883 | +        ]
    | 884 | +
    | 885 | +        # Params for weights, fp8 weight scales, fp8 activation scales
    | 886 | +        # (param_name, weight_name, expert_id, shard_id)
    | 887 | +        expert_params_mapping = AscendFusedMoE.make_expert_params_mapping(
    | 888 | +            ckpt_gate_proj_name="gate_proj",
    | 889 | +            ckpt_down_proj_name="down_proj",
    | 890 | +            ckpt_up_proj_name="up_proj",
    | 891 | +            num_experts=self.config.n_routed_experts)
    | 892 | +
    | 893 | +        params_dict = dict(self.named_parameters())
    | 894 | +        loaded_params: set[str] = set()
    | 895 | +        for name, loaded_weight in weights:
    | 896 | +            if "rotary_emb.inv_freq" in name:
    | 897 | +                continue
    | 898 | +
    | 899 | +            spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
    | 900 | +            if spec_layer is not None:
    | 901 | +                continue  # skip spec decode layers for main model
    | 902 | +
    | 903 | +            for (param_name, weight_name, shard_id) in stacked_params_mapping:
    | 904 | +                # Skip non-stacked layers and experts (experts handled below).
    | 905 | +                if weight_name not in name:
    | 906 | +                    continue
    | 907 | +                # We have mlp.experts[0].gate_proj in the checkpoint.
    | 908 | +                # Since we handle the experts below in expert_params_mapping,
    | 909 | +                # we need to skip here BEFORE we update the name, otherwise
    | 910 | +                # name will be updated to mlp.experts[0].gate_up_proj, which
    | 911 | +                # will then be updated below in expert_params_mapping
    | 912 | +                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
    | 913 | +                if (("mlp.experts." in name) and name not in params_dict):
    | 914 | +                    continue
    | 915 | +                name = name.replace(weight_name, param_name)
    | 916 | +                # Skip loading extra bias for GPTQ models.
    | 917 | +                if name.endswith(".bias") and name not in params_dict:
    | 918 | +                    continue
    | 919 | +
    | 920 | +                if is_pp_missing_parameter(name, self):
    | 921 | +                    continue
    | 922 | +
    | 923 | +                param = params_dict[name]
    | 924 | +                weight_loader = param.weight_loader
    | 925 | +                weight_loader(param, loaded_weight, shard_id)
    | 926 | +                break
    | 927 | +            else:
    | 928 | +                for mapping in expert_params_mapping:
    | 929 | +                    param_name, weight_name, expert_id, shard_id = mapping
    | 930 | +                    if weight_name not in name:
    | 931 | +                        continue
    | 932 | +                    name = name.replace(weight_name, param_name)
    | 933 | +
    | 934 | +                    if is_pp_missing_parameter(name, self):
    | 935 | +                        continue
    | 936 | +
    | 937 | +                    param = params_dict[name]
    | 938 | +                    weight_loader = param.weight_loader
    | 939 | +                    if vllm_version_is("0.9.1"):
    | 940 | +                        weight_loader(param,
    | 941 | +                                      loaded_weight,
    | 942 | +                                      name,
    | 943 | +                                      shard_id=shard_id,
    | 944 | +                                      expert_id=expert_id)
    | 945 | +                    else:
    | 946 | +                        weight_loader(param,
    | 947 | +                                      loaded_weight,
    | 948 | +                                      name,
    | 949 | +                                      shard_id=shard_id,
    | 950 | +                                      expert_id=expert_id,
    | 951 | +                                      return_success=False)
    | 952 | +                    break
    | 953 | +                else:
    | 954 | +                    # Skip loading extra bias for GPTQ models.
    | 955 | +                    if name.endswith(".bias") and name not in params_dict:
    | 956 | +                        continue
    | 957 | +
    | 958 | +                    # Remapping the name of FP8 kv-scale.
    | 959 | +                    name = maybe_remap_kv_scale_name(name, params_dict)
    | 960 | +                    if name is None:
    | 961 | +                        continue
    | 962 | +
    | 963 | +                    if is_pp_missing_parameter(name, self):
    | 964 | +                        continue
    | 965 | +
    | 966 | +                    param = params_dict[name]
    | 967 | +                    weight_loader = getattr(param, "weight_loader",
    | 968 | +                                            default_weight_loader)
    | 969 | +                    weight_loader(param, loaded_weight)
    | 970 | +            loaded_params.add(name)
    | 971 | +        return loaded_params
    | 972 | +
870 | 973 |     def forward(
871 | 974 |         self,
872 | 975 |         input_ids: torch.Tensor,
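
For reviewers less familiar with this loader pattern, below is a minimal standalone sketch of the control flow the new `load_weights` relies on: a checkpoint shard name such as `gate_proj` or `up_proj` is rewritten onto the fused `gate_up_proj` parameter and handed to that parameter's `weight_loader` with the matching shard id, while any name that matches no stacked entry falls through the loop's `else` clause to the expert and default paths. The `dispatch` helper and the example weight names are hypothetical and exist only for illustration; they are not part of this change.

```python
# Standalone sketch (not the vLLM API) of the stacked-parameter remapping
# and for/else fall-through used by load_weights above.

stacked_params_mapping = [
    # (fused param name, checkpoint shard name, shard id)
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]


def dispatch(name: str, params_dict: dict) -> None:
    """Show which path a checkpoint tensor name would take."""
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name not in name:
            continue
        # Expert shards are matched later via expert_params_mapping, so they
        # must be skipped here, before the name is rewritten.
        if "mlp.experts." in name and name not in params_dict:
            continue
        name = name.replace(weight_name, param_name)
        print(f"stacked: load shard {shard_id} into {name}")
        break
    else:
        # Only reached when no stacked entry matched (no `break` fired):
        # load_weights then tries expert_params_mapping and finally the
        # default_weight_loader path.
        print(f"fall through: {name}")


# Hypothetical fused parameter and checkpoint names, for illustration only.
params = {"model.layers.0.mlp.gate_up_proj.weight": object()}
dispatch("model.layers.0.mlp.gate_proj.weight", params)  # stacked, shard 0
dispatch("model.layers.0.mlp.up_proj.weight", params)    # stacked, shard 1
dispatch("model.layers.0.mlp.down_proj.weight", params)  # fall through
```

The real method nests the same for/else structure once more: expert weights `break` out of the inner loop, and only names that match neither mapping reach `default_weight_loader`.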
|
|