Commit f073388

Remove transformers_utils/configs/bailing_moe.py and move the load_weights to BailingMoeModel
Signed-off-by: vito.yy <vito.yy@antgroup.com>
1 parent 965fc90 commit f073388

3 files changed: +82 -160 lines changed

vllm/model_executor/models/bailing_moe.py

Lines changed: 82 additions & 75 deletions
@@ -28,6 +28,7 @@
 
 import torch
 from torch import nn
+import torch.nn.functional as F
 
 from vllm.attention import Attention
 from vllm.config import CacheConfig, VllmConfig
@@ -54,7 +55,7 @@
 from vllm.transformers_utils.configs.bailing_moe import BailingMoeConfig
 
 from .interfaces import SupportsLoRA, SupportsPP
-from .utils import (PPMissingLayer, is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
 
@@ -377,6 +378,80 @@ def forward(
 
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
+
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.num_experts)
+
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        loaded_params: set[str] = set()
+        for name, loaded_weight in weights:
+            if self.config.norm_head and "lm_head.weight" in name:
+                loaded_weight = F.normalize(loaded_weight,
+                                            dim=0,
+                                            p=2,
+                                            eps=1e-7)
+
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                if "mlp.experts" in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if name not in params_dict:
+                    continue
+
+                if is_pp_missing_parameter(name, self):
+                    continue
+
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+
+                    if is_pp_missing_parameter(name, self):
+                        continue
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(param,
+                                  loaded_weight,
+                                  name,
+                                  shard_id=shard_id,
+                                  expert_id=expert_id)
+                    break
+                else:
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    if name not in params_dict:
+                        continue
+
+                    if is_pp_missing_parameter(name, self):
+                        continue
+
+                    param = params_dict[name]
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
 
 
 class BailingMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
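A note on the new method above: the two bare `else:` blocks attach to the `for` loops (Python's for/else), so each runs only when its loop exits without `break` — that is how a weight name falls through from the stacked-parameter table to the expert table to plain loading. The `norm_head` branch L2-normalizes each column of the `lm_head` weight at load time. A minimal self-contained sketch of that transform, with an illustrative weight shape (the real vocab/hidden sizes come from the checkpoint):

import torch
import torch.nn.functional as F

# Illustrative lm_head weight of shape (vocab_size, hidden_size) = (8, 4).
w = torch.randn(8, 4)

# Same call as the norm_head branch: unit L2 norm along dim=0, i.e. each
# column of the weight is rescaled to length 1 before being loaded.
w_normed = F.normalize(w, dim=0, p=2, eps=1e-7)

print(w_normed.norm(dim=0))  # ~ tensor([1., 1., 1., 1.])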
@@ -463,78 +538,10 @@ def sample(
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
-        ]
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
-            ckpt_gate_proj_name="gate_proj",
-            ckpt_down_proj_name="down_proj",
-            ckpt_up_proj_name="up_proj",
-            num_experts=self.config.num_experts)
-
-        params_dict = dict(self.named_parameters(remove_duplicate=False))
-        loaded_params: set[str] = set()
-        for name, loaded_weight in weights:
-            if (("v_head" in name) or ("inv_freq" in name) or
-                    (self.config.tie_word_embeddings and "lm_head" in name)):
-                continue
-            if self.config.norm_head and "lm_head.weight" in name:
-                import torch.nn.functional as F
-                loaded_weight = F.normalize(loaded_weight,
-                                            dim=0,
-                                            p=2,
-                                            eps=1e-7)
-
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-                if "mlp.experts" in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                if name not in params_dict:
-                    continue
-
-                if is_pp_missing_parameter(name, self):
-                    continue
-
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                for mapping in expert_params_mapping:
-                    param_name, weight_name, expert_id, shard_id = mapping
-                    if weight_name not in name:
-                        continue
-                    name = name.replace(weight_name, param_name)
-
-                    if is_pp_missing_parameter(name, self):
-                        continue
-                    param = params_dict[name]
-                    weight_loader = param.weight_loader
-                    weight_loader(param,
-                                  loaded_weight,
-                                  name,
-                                  shard_id=shard_id,
-                                  expert_id=expert_id)
-                    break
-                else:
-                    if name.endswith(".bias") and name not in params_dict:
-                        continue
-                    if name not in params_dict:
-                        continue
-
-                    if is_pp_missing_parameter(name, self):
-                        continue
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["lm_head."]
+                           if self.config.tie_word_embeddings else None),
+        )
+        return loader.load_weights(weights)
 
-                    param = params_dict[name]
-                    weight_loader = getattr(param, "weight_loader",
-                                            default_weight_loader)
-                    weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
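The replacement body delegates to vLLM's `AutoWeightsLoader`, which walks the module tree and hands each child its share of the checkpoint tensors, preferring a child's own `load_weights` when it defines one — which is why the per-tensor logic moves onto `BailingMoeModel`. A toy stand-in for the delegation idea (a simplified sketch, not the real class; `TinyAutoWeightsLoader` and its internals are illustrative only):

from collections.abc import Iterable

import torch
from torch import nn


class TinyAutoWeightsLoader:
    """Illustrative stand-in for vLLM's AutoWeightsLoader."""

    def __init__(self, module: nn.Module,
                 skip_prefixes: list[str] | None = None):
        self.module = module
        self.skip_prefixes = skip_prefixes or []

    def load_weights(
            self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        loaded: set[str] = set()
        own_params = dict(self.module.named_parameters(recurse=False))
        per_child: dict[str, list[tuple[str, torch.Tensor]]] = {}
        for name, tensor in weights:
            # skip_prefixes drops e.g. "lm_head." when embeddings are tied.
            if any(name.startswith(p) for p in self.skip_prefixes):
                continue
            child, _, rest = name.partition(".")
            if rest and hasattr(getattr(self.module, child, None),
                                "load_weights"):
                # Tensors named "model.xxx" are routed to the child's own
                # load_weights -- the method this commit adds to
                # BailingMoeModel.
                per_child.setdefault(child, []).append((rest, tensor))
            elif name in own_params:
                own_params[name].data.copy_(tensor)
                loaded.add(name)
        for child, child_weights in per_child.items():
            sub = getattr(self.module, child).load_weights(child_weights)
            loaded.update(f"{child}.{n}" for n in sub)
        return loaded

With tied embeddings, `skip_prefixes=["lm_head."]` simply drops the head tensors, since the head shares the embedding weight.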

vllm/transformers_utils/configs/__init__.py

Lines changed: 0 additions & 2 deletions
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from vllm.transformers_utils.configs.bailing_moe import BailingMoeConfig
 from vllm.transformers_utils.configs.chatglm import ChatGLMConfig
 from vllm.transformers_utils.configs.cohere2 import Cohere2Config
 from vllm.transformers_utils.configs.dbrx import DbrxConfig
@@ -31,7 +30,6 @@
 from vllm.transformers_utils.configs.ultravox import UltravoxConfig
 
 __all__ = [
-    "BailingMoeConfig",
     "ChatGLMConfig",
    "Cohere2Config",
    "DbrxConfig",

vllm/transformers_utils/configs/bailing_moe.py

Lines changed: 0 additions & 83 deletions
This file was deleted.
