@@ -28,6 +28,7 @@
 
 import torch
 from torch import nn
+import torch.nn.functional as F
 
 from vllm.attention import Attention
 from vllm.config import CacheConfig, VllmConfig
@@ -54,7 +55,7 @@
 from vllm.transformers_utils.configs.bailing_moe import BailingMoeConfig
 
 from .interfaces import SupportsLoRA, SupportsPP
-from .utils import (PPMissingLayer, is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
 
@@ -377,6 +378,80 @@ def forward(
 
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
+
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.num_experts)
+
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        loaded_params: set[str] = set()
+        for name, loaded_weight in weights:
+            if self.config.norm_head and "lm_head.weight" in name:
+                loaded_weight = F.normalize(loaded_weight,
+                                            dim=0,
+                                            p=2,
+                                            eps=1e-7)
+
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                if "mlp.experts" in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if name not in params_dict:
+                    continue
+
+                if is_pp_missing_parameter(name, self):
+                    continue
+
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+
+                    if is_pp_missing_parameter(name, self):
+                        continue
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(param,
+                                  loaded_weight,
+                                  name,
+                                  shard_id=shard_id,
+                                  expert_id=expert_id)
+                    break
+                else:
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    if name not in params_dict:
+                        continue
+
+                    if is_pp_missing_parameter(name, self):
+                        continue
+
+                    param = params_dict[name]
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
 
 
 class BailingMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
@@ -463,78 +538,10 @@ def sample(
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
-        ]
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
-            ckpt_gate_proj_name="gate_proj",
-            ckpt_down_proj_name="down_proj",
-            ckpt_up_proj_name="up_proj",
-            num_experts=self.config.num_experts)
-
-        params_dict = dict(self.named_parameters(remove_duplicate=False))
-        loaded_params: set[str] = set()
-        for name, loaded_weight in weights:
-            if (("v_head" in name) or ("inv_freq" in name) or
-                    (self.config.tie_word_embeddings and "lm_head" in name)):
-                continue
-            if self.config.norm_head and "lm_head.weight" in name:
-                import torch.nn.functional as F
-                loaded_weight = F.normalize(loaded_weight,
-                                            dim=0,
-                                            p=2,
-                                            eps=1e-7)
-
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-                if "mlp.experts" in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                if name not in params_dict:
-                    continue
-
-                if is_pp_missing_parameter(name, self):
-                    continue
-
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                for mapping in expert_params_mapping:
-                    param_name, weight_name, expert_id, shard_id = mapping
-                    if weight_name not in name:
-                        continue
-                    name = name.replace(weight_name, param_name)
-
-                    if is_pp_missing_parameter(name, self):
-                        continue
-                    param = params_dict[name]
-                    weight_loader = param.weight_loader
-                    weight_loader(param,
-                                  loaded_weight,
-                                  name,
-                                  shard_id=shard_id,
-                                  expert_id=expert_id)
-                    break
-                else:
-                    if name.endswith(".bias") and name not in params_dict:
-                        continue
-                    if name not in params_dict:
-                        continue
-
-                    if is_pp_missing_parameter(name, self):
-                        continue
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["lm_head."]
+                           if self.config.tie_word_embeddings else None),
+        )
+        return loader.load_weights(weights)
 
-                    param = params_dict[name]
-                    weight_loader = getattr(param, "weight_loader",
-                                            default_weight_loader)
-                    weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
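For context on the refactor shown above: the per-checkpoint bookkeeping (stacked `gate_up_proj` shards, the `FusedMoE` expert mapping, the `norm_head` normalization of `lm_head.weight`) moves into the inner model's `load_weights`, and `BailingMoeForCausalLM.load_weights` collapses to a delegation through `AutoWeightsLoader`, with `skip_prefixes` dropping `lm_head.` weights when `tie_word_embeddings` is set. The sketch below illustrates only that delegation pattern; `ToyAutoWeightsLoader` and its routing logic are simplified assumptions for exposition, not vLLM's actual `AutoWeightsLoader` implementation.

```python
# Minimal sketch of the delegation pattern (NOT vLLM's implementation):
# the wrapper filters skipped prefixes, then hands each weight to the
# child module that owns it, so model-specific logic (stacked shards,
# fused experts, norm_head) stays in the inner model's load_weights().
from collections.abc import Iterable
from typing import Optional

import torch
from torch import nn


class ToyAutoWeightsLoader:  # hypothetical stand-in for AutoWeightsLoader

    def __init__(self,
                 module: nn.Module,
                 skip_prefixes: Optional[list[str]] = None):
        self.module = module
        self.skip_prefixes = skip_prefixes or []

    def load_weights(
            self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        loaded: set[str] = set()
        children = dict(self.module.named_children())
        for name, tensor in weights:
            # Drop weights the caller asked to skip, e.g. "lm_head."
            # when word embeddings are tied.
            if any(name.startswith(p) for p in self.skip_prefixes):
                continue
            # Route "model.layers.0.mlp..." to the child named "model".
            child_name, _, rest = name.partition(".")
            child = children.get(child_name)
            if child is not None and hasattr(child, "load_weights"):
                # The child reports the names it consumed, relative to
                # itself; re-qualify them before returning to the caller.
                for sub in child.load_weights([(rest, tensor)]):
                    loaded.add(f"{child_name}.{sub}")
        return loaded
```

The payoff of this shape, shared by other vLLM models that already use `AutoWeightsLoader`, is that sharding logic lives next to the module that owns the parameters, while the wrapper class shrinks to constructing the loader and returning `loader.load_weights(weights)`.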