
Commit 922f316

[Model] Support HF format of minimax (#20211)
Signed-off-by: mgoin <mgoin64@gmail.com>
1 parent 5923ab9

File tree

3 files changed (+36, -11 lines):

tests/models/registry.py
vllm/model_executor/models/minimax_text_01.py
vllm/model_executor/models/registry.py


tests/models/registry.py

Lines changed: 2 additions & 0 deletions
@@ -218,6 +218,8 @@ def check_available_online(
                                           trust_remote_code=True),
     "MiniCPM3ForCausalLM": _HfExamplesInfo("openbmb/MiniCPM3-4B",
                                            trust_remote_code=True),
+    "MiniMaxForCausalLM": _HfExamplesInfo("MiniMaxAI/MiniMax-Text-01-hf",
+                                          min_transformers_version="4.53"),
     "MiniMaxText01ForCausalLM": _HfExamplesInfo("MiniMaxAI/MiniMax-Text-01",
                                                 trust_remote_code=True,
                                                 revision="a59aa9cbc53b9fb8742ca4e9e1531b9802b6fdc3"),  # noqa: E501

vllm/model_executor/models/minimax_text_01.py

Lines changed: 33 additions & 11 deletions
@@ -667,16 +667,24 @@ def __init__(
                                          eps=config.rms_norm_eps)
         if config.attention_type == 0:
             self.layernorm_attention_alpha = getattr(
-                config, 'layernorm_linear_attention_alpha', 1)
+                config, 'layernorm_linear_attention_alpha',
+                getattr(config, 'linear_attn_alpha_factor', 1))
             self.layernorm_attention_beta = getattr(
-                config, 'layernorm_linear_attention_beta', 1)
+                config, 'layernorm_linear_attention_beta',
+                getattr(config, 'linear_attn_beta_factor', 1))
         else:
             self.layernorm_attention_alpha = getattr(
-                config, 'layernorm_full_attention_alpha', 1)
+                config, 'layernorm_full_attention_alpha',
+                getattr(config, 'full_attn_alpha_factor', 1))
             self.layernorm_attention_beta = getattr(
-                config, 'layernorm_full_attention_beta', 1)
-        self.layernorm_mlp_alpha = getattr(config, 'layernorm_mlp_alpha', 1)
-        self.layernorm_mlp_beta = getattr(config, 'layernorm_mlp_beta', 1)
+                config, 'layernorm_full_attention_beta',
+                getattr(config, 'full_attn_beta_factor', 1))
+        self.layernorm_mlp_alpha = getattr(
+            config, 'layernorm_mlp_alpha',
+            getattr(config, 'mlp_alpha_factor', 1))
+        self.layernorm_mlp_beta = getattr(
+            config, 'layernorm_mlp_beta', getattr(config, 'mlp_beta_factor',
+                                                  1))
         self.postnorm = getattr(config, 'postnorm', False)
         self.shared_moe = False
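
Note on the hunk above: the layer-norm scale factors are now read from the original MiniMax config keys first (e.g. 'layernorm_linear_attention_alpha') with a fallback to the HF-format names (e.g. 'linear_attn_alpha_factor'), defaulting to 1 when neither is present. A minimal sketch of that nested-getattr resolution, using SimpleNamespace objects as illustrative stand-ins for the real HF config:

# Sketch only: the attribute names come from the diff above; SimpleNamespace
# stands in for the actual HF config object.
from types import SimpleNamespace

def resolve_alpha(config):
    # Legacy key first, then the HF-format key, then the default of 1.
    return getattr(config, 'layernorm_linear_attention_alpha',
                   getattr(config, 'linear_attn_alpha_factor', 1))

legacy_cfg = SimpleNamespace(layernorm_linear_attention_alpha=3.0)
hf_cfg = SimpleNamespace(linear_attn_alpha_factor=3.0)
bare_cfg = SimpleNamespace()

print(resolve_alpha(legacy_cfg))  # 3.0 (original checkpoint config)
print(resolve_alpha(hf_cfg))      # 3.0 (HF-format config)
print(resolve_alpha(bare_cfg))    # 1 (neither key present)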

@@ -794,6 +802,18 @@ def __init__(
         self.decoder_attention_types = getattr(
             config, "attn_type_list", False) or getattr(
                 config, "decoder_attention_types", False)
+        # The HF format uses "layer_types" instead of "attn_type_list"
+        # where "linear_attention" is 0 and "full_attention" is 1
+        if not self.decoder_attention_types and hasattr(config, "layer_types"):
+            self.decoder_attention_types = []
+            for layer_type in config.layer_types:
+                if layer_type == "linear_attention":
+                    self.decoder_attention_types.append(0)
+                elif layer_type == "full_attention":
+                    self.decoder_attention_types.append(1)
+                else:
+                    raise ValueError(f"Unsupported layer type: {layer_type}")
+        # Default to full attention
         if not self.decoder_attention_types:
             self.decoder_attention_types = [1] * config.num_hidden_layers
         self.num_layers = config.num_hidden_layers
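
In other words, an HF-format config describes the per-layer attention pattern with strings, and the loop above normalizes it to the 0/1 convention used throughout this file. A small illustration (the layer pattern below is made up, not the real MiniMax-Text-01 configuration):

# Illustrative only: a made-up layer pattern in the HF "layer_types" style.
layer_types = ["linear_attention", "linear_attention", "full_attention"]

mapping = {"linear_attention": 0, "full_attention": 1}
decoder_attention_types = [mapping[t] for t in layer_types]

print(decoder_attention_types)  # [0, 0, 1]: two linear-attention layers, one full-attention layer
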
@@ -1022,8 +1042,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         else:
             self.lm_head = PPMissingLayer()
         self.lm_head.float()
-        flash_layer_count = sum(1 for attn_type in self.config.attn_type_list
-                                if attn_type == 1)
+        flash_layer_count = sum(
+            1 for attn_type in self.model.decoder_attention_types
+            if attn_type == 1)
         self.kv_cache = [torch.tensor([]) for _ in range(flash_layer_count)]
         return
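
Since decoder_attention_types is now filled from either attn_type_list or the HF layer_types (see the mapping above), counting the full-attention layers off that normalized list works for both checkpoint formats, whereas config.attn_type_list does not exist in an HF-format config. A tiny sketch of the count:

# Illustrative pattern; 1 marks full attention, 0 marks linear attention.
decoder_attention_types = [0, 0, 1, 0, 1]
flash_layer_count = sum(1 for attn_type in decoder_attention_types
                        if attn_type == 1)
print(flash_layer_count)  # 2 -> one KV-cache placeholder per full-attention layer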

@@ -1085,9 +1106,10 @@ def which_layer(name: str) -> int:
             return None

         def is_linear_attn_layer(layer_idx: int) -> bool:
-            if layer_idx is None or not hasattr(self.config, "attn_type_list"):
+            if layer_idx is None or layer_idx >= len(
+                    self.model.decoder_attention_types):
                 return False
-            return self.config.attn_type_list[layer_idx] == 0
+            return self.model.decoder_attention_types[layer_idx] == 0

         def is_moe_weight(name: str) -> bool:
             return "block_sparse_moe" in name and not name.endswith(".bias")
@@ -1275,7 +1297,7 @@ def load_basic_weight(name: str, loaded_weight: torch.Tensor,
         for name, loaded_weight in weights:
             weight_at_layer = which_layer(name)
             if weight_at_layer and weight_at_layer >= len(
-                    self.config.attn_type_list):
+                    self.model.decoder_attention_types):
                 continue

             if is_layer_norm_weight(name):

vllm/model_executor/models/registry.py

Lines changed: 1 addition & 0 deletions
@@ -34,6 +34,7 @@
     "AquilaModel": ("llama", "LlamaForCausalLM"),
     "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
     "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
+    "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
     "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
     "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"),
     # baichuan-7b, upper case 'C' in the class name
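
With this alias in place, a checkpoint whose config.json declares the MiniMaxForCausalLM architecture resolves to the existing MiniMaxText01ForCausalLM implementation in minimax_text_01.py. A hedged usage sketch (MiniMax-Text-01 is a very large MoE model, so the parallelism value below is an illustrative placeholder and assumes suitable hardware plus transformers >= 4.53, per the test-registry entry above):

from vllm import LLM

# Sketch only: size tensor_parallel_size to your hardware; the repo id is the
# HF-format checkpoint referenced in the test registry.
llm = LLM(model="MiniMaxAI/MiniMax-Text-01-hf", tensor_parallel_size=8)
outputs = llm.generate("Hello, my name is")
print(outputs[0].outputs[0].text)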
