@@ -667,16 +667,24 @@ def __init__(
             eps=config.rms_norm_eps)
         if config.attention_type == 0:
             self.layernorm_attention_alpha = getattr(
-                config, 'layernorm_linear_attention_alpha', 1)
+                config, 'layernorm_linear_attention_alpha',
+                getattr(config, 'linear_attn_alpha_factor', 1))
             self.layernorm_attention_beta = getattr(
-                config, 'layernorm_linear_attention_beta', 1)
+                config, 'layernorm_linear_attention_beta',
+                getattr(config, 'linear_attn_beta_factor', 1))
         else:
             self.layernorm_attention_alpha = getattr(
-                config, 'layernorm_full_attention_alpha', 1)
+                config, 'layernorm_full_attention_alpha',
+                getattr(config, 'full_attn_alpha_factor', 1))
             self.layernorm_attention_beta = getattr(
-                config, 'layernorm_full_attention_beta', 1)
-        self.layernorm_mlp_alpha = getattr(config, 'layernorm_mlp_alpha', 1)
-        self.layernorm_mlp_beta = getattr(config, 'layernorm_mlp_beta', 1)
+                config, 'layernorm_full_attention_beta',
+                getattr(config, 'full_attn_beta_factor', 1))
+        self.layernorm_mlp_alpha = getattr(
+            config, 'layernorm_mlp_alpha',
+            getattr(config, 'mlp_alpha_factor', 1))
+        self.layernorm_mlp_beta = getattr(
+            config, 'layernorm_mlp_beta', getattr(config, 'mlp_beta_factor',
+                                                  1))
         self.postnorm = getattr(config, 'postnorm', False)
         self.shared_moe = False
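
As a side note, a minimal runnable sketch (not part of the patch; the resolve_alpha helper and the SimpleNamespace configs are illustrative stand-ins) of how the nested getattr fallback resolves: the legacy layernorm_* key wins when present, otherwise the newer HF-style *_factor key is used, and 1 remains the final default.

# Illustrative only: resolution order of the nested getattr chain above.
from types import SimpleNamespace

legacy = SimpleNamespace(layernorm_linear_attention_alpha=2.0,
                         linear_attn_alpha_factor=3.0)
hf_style = SimpleNamespace(linear_attn_alpha_factor=3.0)
empty = SimpleNamespace()

def resolve_alpha(config):
    # Same pattern as the patch: legacy key first, then HF key, then 1.
    return getattr(config, 'layernorm_linear_attention_alpha',
                   getattr(config, 'linear_attn_alpha_factor', 1))

assert resolve_alpha(legacy) == 2.0    # legacy key takes precedence
assert resolve_alpha(hf_style) == 3.0  # HF-style key used as fallback
assert resolve_alpha(empty) == 1       # default when neither is set
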
@@ -794,6 +802,18 @@ def __init__(
         self.decoder_attention_types = getattr(
             config, "attn_type_list", False) or getattr(
                 config, "decoder_attention_types", False)
+        # The HF format uses "layer_types" instead of "attn_type_list"
+        # where "linear_attention" is 0 and "full_attention" is 1
+        if not self.decoder_attention_types and hasattr(config, "layer_types"):
+            self.decoder_attention_types = []
+            for layer_type in config.layer_types:
+                if layer_type == "linear_attention":
+                    self.decoder_attention_types.append(0)
+                elif layer_type == "full_attention":
+                    self.decoder_attention_types.append(1)
+                else:
+                    raise ValueError(f"Unsupported layer type: {layer_type}")
+        # Default to full attention
         if not self.decoder_attention_types:
             self.decoder_attention_types = [1] * config.num_hidden_layers
         self.num_layers = config.num_hidden_layers
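
A standalone sketch of that conversion with a hypothetical config value (the helper name layer_types_to_attn_types is illustrative, not from the patch): the HF-format string list is mapped to the integer encoding used elsewhere in the file, 0 for linear attention and 1 for full attention, and unknown strings raise.

# Illustrative only: mirrors the layer_types -> attention-type mapping above.
def layer_types_to_attn_types(layer_types: list[str]) -> list[int]:
    mapping = {"linear_attention": 0, "full_attention": 1}
    attn_types = []
    for layer_type in layer_types:
        if layer_type not in mapping:
            raise ValueError(f"Unsupported layer type: {layer_type}")
        attn_types.append(mapping[layer_type])
    return attn_types

# Hypothetical HF-style value for a four-layer model.
assert layer_types_to_attn_types([
    "linear_attention", "linear_attention", "full_attention",
    "linear_attention"
]) == [0, 0, 1, 0]
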
@@ -1022,8 +1042,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         else:
             self.lm_head = PPMissingLayer()
         self.lm_head.float()
-        flash_layer_count = sum(1 for attn_type in self.config.attn_type_list
-                                if attn_type == 1)
+        flash_layer_count = sum(
+            1 for attn_type in self.model.decoder_attention_types
+            if attn_type == 1)
         self.kv_cache = [torch.tensor([]) for _ in range(flash_layer_count)]
         return
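
For context, a hedged sketch of the counting logic (with a made-up layer layout): the number of KV-cache placeholders now follows the model's normalized decoder_attention_types rather than the raw attn_type_list config field, one per full-attention layer.

# Illustrative only: counts full-attention (type 1) layers and allocates
# one empty placeholder tensor per such layer, as the patched code does.
import torch

decoder_attention_types = [0, 0, 1, 0, 1]  # hypothetical layer layout
flash_layer_count = sum(1 for attn_type in decoder_attention_types
                        if attn_type == 1)
kv_cache = [torch.tensor([]) for _ in range(flash_layer_count)]
assert len(kv_cache) == 2
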
@@ -1085,9 +1106,10 @@ def which_layer(name: str) -> int:
             return None
 
         def is_linear_attn_layer(layer_idx: int) -> bool:
-            if layer_idx is None or not hasattr(self.config, "attn_type_list"):
+            if layer_idx is None or layer_idx >= len(
+                    self.model.decoder_attention_types):
                 return False
-            return self.config.attn_type_list[layer_idx] == 0
+            return self.model.decoder_attention_types[layer_idx] == 0
 
         def is_moe_weight(name: str) -> bool:
             return "block_sparse_moe" in name and not name.endswith(".bias")
@@ -1275,7 +1297,7 @@ def load_basic_weight(name: str, loaded_weight: torch.Tensor,
         for name, loaded_weight in weights:
             weight_at_layer = which_layer(name)
             if weight_at_layer and weight_at_layer >= len(
-                    self.config.attn_type_list):
+                    self.model.decoder_attention_types):
                 continue
 
             if is_layer_norm_weight(name):