
Commit d59e7fa

MengqingCao and Yikun authored
[CI] Pin transformers<4.53.0 and fix EPLB load_weights to make CI passed (#1482)
### What this PR does / why we need it?
- Fix the vLLM EPLB break (vllm-project/vllm@e9fd658) by recovering load_weights back to the [v0.9.1 version](vllm-project/vllm@07b8fae) temporarily.
- Fix the transformers>=4.53.0 image processor break. Related: #1470
- Mirror the torch_npu requirements to pyproject.toml.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI passed.

---------

Signed-off-by: MengqingCao <cmq0113@163.com>
Signed-off-by: Yikun Jiang <yikunkero@gmail.com>
Co-authored-by: Yikun Jiang <yikunkero@gmail.com>
1 parent 3687676 commit d59e7fa

File tree

4 files changed: +226 −15 lines changed

pyproject.toml
requirements.txt
vllm_ascend/models/deepseek_dbo.py
vllm_ascend/models/deepseek_v2.py

pyproject.toml

Lines changed: 3 additions & 1 deletion
@@ -12,12 +12,14 @@ requires = [
     "scipy",
     "setuptools>=64",
     "setuptools-scm>=8",
-    "torch-npu==2.5.1.post1.dev20250528",
+    "torch-npu==2.5.1.post1.dev20250619",
     "torch>=2.5.1",
     "torchvision<0.21.0",
     "wheel",
     "msgpack",
     "quart",
     "numba",
+    # Remove after https://github.com/vllm-project/vllm-ascend/issues/1470
+    "transformers<4.53.0",
 ]
 build-backend = "setuptools.build_meta"
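
For reference, the pin can be checked locally. A minimal sketch, not part of this commit, assuming the packaging library is available alongside the standard importlib.metadata:

# Minimal sketch (not part of this commit): assert that the installed
# transformers satisfies the temporary <4.53.0 pin introduced here.
from importlib.metadata import version

from packaging.version import Version

installed = Version(version("transformers"))
assert installed < Version("4.53.0"), (
    f"transformers {installed} installed; <4.53.0 is required until "
    "https://github.com/vllm-project/vllm-ascend/issues/1470 is resolved")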

requirements.txt

Lines changed: 3 additions & 0 deletions
@@ -25,3 +25,6 @@ numba
 --pre
 --extra-index-url https://mirrors.huaweicloud.com/ascend/repos/pypi
 torch-npu==2.5.1.post1.dev20250619
+
+# Remove after https://github.com/vllm-project/vllm-ascend/issues/1470
+transformers<4.53.0

vllm_ascend/models/deepseek_dbo.py

Lines changed: 110 additions & 7 deletions
@@ -25,7 +25,7 @@
 # # vllm-project/vllm/vllm/model_executor/models/deepseek_v2.py
 # """Inference-only DeepseekV2/DeepseekV3 model."""

-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, Iterable, List, Optional, Union

 import torch
 import torch.distributed as dist
@@ -49,16 +49,18 @@
 from vllm.model_executor.layers.sampler import get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.models.deepseek_v2 import \
     DeepseekV2ForCausalLM  # noqa: E501
 from vllm.model_executor.models.deepseek_v2 import \
     yarn_get_mscale  # noqa: E501
-from vllm.model_executor.models.deepseek_v2 import (DeepseekV2Attention,
-                                                    DeepseekV2DecoderLayer,
-                                                    DeepseekV2MLAAttention)
+from vllm.model_executor.models.deepseek_v2 import (
+    DeepseekV2Attention, DeepseekV2DecoderLayer, DeepseekV2MLAAttention,
+    get_spec_layer_idx_from_weight_name)
 from vllm.model_executor.models.utils import (
-    PPMissingLayer, make_empty_intermediate_tensors_factory, make_layers,
-    maybe_prefix)
+    PPMissingLayer, is_pp_missing_parameter,
+    make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
 from vllm.sequence import IntermediateTensors

 import vllm_ascend.envs as envs_ascend
@@ -76,7 +78,7 @@
     make_multistream_metadata_ds)
 from vllm_ascend.multistream.ms_split import compute_split_seq_index
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
-from vllm_ascend.utils import dispose_tensor
+from vllm_ascend.utils import dispose_tensor, vllm_version_is

 VLLM_ASCEND_ENABLE_DBO: bool = envs_ascend.VLLM_ASCEND_ENABLE_DBO

@@ -963,6 +965,107 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)

+    # NOTE: This `load_weights` is mainly copied from
+    # https://github.com/vllm-project/vllm/commit/07b8fae219b1fff51ef115c38c44b51395be5bb5
+    # to fix CI, and it is different from the implementation in main
+    # TODO: support eplb style load_weights
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        """"""
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        expert_params_mapping = AscendFusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.n_routed_experts)
+
+        params_dict = dict(self.named_parameters())
+        loaded_params: set[str] = set()
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+
+            spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
+            if spec_layer is not None:
+                continue  # skip spec decode layers for main model
+
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                # Skip non-stacked layers and experts (experts handled below).
+                if weight_name not in name:
+                    continue
+                # We have mlp.experts[0].gate_proj in the checkpoint.
+                # Since we handle the experts below in expert_params_mapping,
+                # we need to skip here BEFORE we update the name, otherwise
+                # name will be updated to mlp.experts[0].gate_up_proj, which
+                # will then be updated below in expert_params_mapping
+                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+                if (("mlp.experts." in name) and name not in params_dict):
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+
+                if is_pp_missing_parameter(name, self):
+                    continue
+
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+
+                    if is_pp_missing_parameter(name, self):
+                        continue
+
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    if vllm_version_is("0.9.1"):
+                        weight_loader(param,
+                                      loaded_weight,
+                                      name,
+                                      shard_id=shard_id,
+                                      expert_id=expert_id)
+                    else:
+                        weight_loader(param,
+                                      loaded_weight,
+                                      name,
+                                      shard_id=shard_id,
+                                      expert_id=expert_id,
+                                      return_success=False)
+                    break
+                else:
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+
+                    # Remapping the name of FP8 kv-scale.
+                    name = maybe_remap_kv_scale_name(name, params_dict)
+                    if name is None:
+                        continue
+
+                    if is_pp_missing_parameter(name, self):
+                        continue
+
+                    param = params_dict[name]
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
     def forward(
         self,
         input_ids: torch.Tensor,
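
To clarify the stacked-parameter handling in the copied load_weights above, here is a standalone sketch (illustrative names, not code from this repository) of how checkpoint names containing gate_proj or up_proj are folded onto the fused gate_up_proj parameter with a shard id, while expert weights are deliberately left to the expert_params_mapping pass:

# Standalone sketch of the stacked-parameter remapping used in load_weights
# above; params_dict is a toy stand-in for the model's named_parameters()
# and the weight names are illustrative.
stacked_params_mapping = [
    # (param_name, shard_name, shard_id)
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]
params_dict = {"model.layers.0.mlp.gate_up_proj.weight": object()}

def remap_stacked(name: str):
    """Return (fused_name, shard_id), or None if another pass handles it."""
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name not in name:
            continue
        # Expert weights are skipped here, before the rename, so the
        # expert_params_mapping branch can still match the original name.
        if "mlp.experts." in name and name not in params_dict:
            return None
        return name.replace(weight_name, param_name), shard_id
    return None

print(remap_stacked("model.layers.0.mlp.gate_proj.weight"))
# -> ('model.layers.0.mlp.gate_up_proj.weight', 0)
print(remap_stacked("model.layers.1.mlp.experts.0.up_proj.weight"))
# -> None (falls through to the expert mapping branch)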

vllm_ascend/models/deepseek_v2.py

Lines changed: 110 additions & 7 deletions
@@ -25,7 +25,7 @@
 # # vllm-project/vllm/vllm/model_executor/models/deepseek_v2.py
 # """Inference-only DeepseekV2/DeepseekV3 model."""

-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

 import torch
 import torch_npu
@@ -55,16 +55,18 @@
 from vllm.model_executor.layers.sampler import get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.models.deepseek_v2 import \
     DeepseekV2ForCausalLM  # noqa: E501
 from vllm.model_executor.models.deepseek_v2 import \
     yarn_get_mscale  # noqa: E501
-from vllm.model_executor.models.deepseek_v2 import (DeepseekV2Attention,
-                                                    DeepseekV2DecoderLayer,
-                                                    DeepseekV2MLAAttention)
+from vllm.model_executor.models.deepseek_v2 import (
+    DeepseekV2Attention, DeepseekV2DecoderLayer, DeepseekV2MLAAttention,
+    get_spec_layer_idx_from_weight_name)
 from vllm.model_executor.models.utils import (
-    PPMissingLayer, make_empty_intermediate_tensors_factory, make_layers,
-    maybe_prefix)
+    PPMissingLayer, is_pp_missing_parameter,
+    make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
 from vllm.sequence import IntermediateTensors

 from vllm_ascend.ascend_config import get_ascend_config
@@ -73,7 +75,7 @@
 from vllm_ascend.quantization.quant_config import AscendLinearMethod
 from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
 from vllm_ascend.utils import (dispose_tensor, npu_stream_switch,
-                               npu_wait_tensor)
+                               npu_wait_tensor, vllm_version_is)


 class CustomDeepseekV2SiluAndMul(SiluAndMul):
@@ -867,6 +869,107 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)

+    # NOTE: This `load_weights` is mainly copied from
+    # https://github.com/vllm-project/vllm/commit/07b8fae219b1fff51ef115c38c44b51395be5bb5
+    # to fix CI, and it is different from the implementation in main
+    # TODO: support eplb style load_weights
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        """"""
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        expert_params_mapping = AscendFusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.n_routed_experts)
+
+        params_dict = dict(self.named_parameters())
+        loaded_params: set[str] = set()
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+
+            spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
+            if spec_layer is not None:
+                continue  # skip spec decode layers for main model
+
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                # Skip non-stacked layers and experts (experts handled below).
+                if weight_name not in name:
+                    continue
+                # We have mlp.experts[0].gate_proj in the checkpoint.
+                # Since we handle the experts below in expert_params_mapping,
+                # we need to skip here BEFORE we update the name, otherwise
+                # name will be updated to mlp.experts[0].gate_up_proj, which
+                # will then be updated below in expert_params_mapping
+                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+                if (("mlp.experts." in name) and name not in params_dict):
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+
+                if is_pp_missing_parameter(name, self):
+                    continue
+
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+
+                    if is_pp_missing_parameter(name, self):
+                        continue
+
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    if vllm_version_is("0.9.1"):
+                        weight_loader(param,
+                                      loaded_weight,
+                                      name,
+                                      shard_id=shard_id,
+                                      expert_id=expert_id)
+                    else:
+                        weight_loader(param,
+                                      loaded_weight,
+                                      name,
+                                      shard_id=shard_id,
+                                      expert_id=expert_id,
+                                      return_success=False)
+                    break
+                else:
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+
+                    # Remapping the name of FP8 kv-scale.
+                    name = maybe_remap_kv_scale_name(name, params_dict)
+                    if name is None:
+                        continue
+
+                    if is_pp_missing_parameter(name, self):
+                        continue
+
+                    param = params_dict[name]
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
     def forward(
         self,
         input_ids: torch.Tensor,
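
The same method leans on Python's for/else chaining to route each weight through three handlers: stacked projections, fused-MoE experts, and finally the default loader. A minimal sketch of that control flow, with illustrative match conditions rather than the real mappings:

# Minimal sketch of the nested for/else dispatch that structures load_weights
# above; the match conditions are illustrative only.
def classify(name: str) -> str:
    for shard_name in ("gate_proj", "up_proj"):
        if shard_name in name and "mlp.experts." not in name:
            handled = "stacked"  # real code loads the shard here, then breaks
            break
    else:  # stacked loop finished without a break
        for marker in ("mlp.experts.",):
            if marker in name:
                handled = "expert"  # real code calls the fused-MoE weight_loader
                break
        else:  # expert loop finished without a break either
            handled = "default"  # plain weight_loader / kv-scale remap path
    return handled

print(classify("model.layers.0.mlp.gate_proj.weight"))         # stacked
print(classify("model.layers.2.mlp.experts.3.up_proj.weight"))  # expert
print(classify("model.embed_tokens.weight"))                    # default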
