@@ -66,9 +66,13 @@ def __init__(self, config: "LlamaConfig"):
         _ESTIMATE_FUNC = {
             "llama": self._estimate_llama_flops,
             "qwen2": self._estimate_llama_flops,
+            "qwen2_moe": self._estimate_qwen2_moe_flops,
             "qwen2_vl": self._estimate_llama_flops,
             "qwen2_5_vl": self._estimate_llama_flops,
             "qwen3": self._estimate_llama_flops,
+            "qwen3_vl": self._estimate_llama_flops,
+            "qwen3_moe": self._estimate_qwen2_moe_flops,
+            "qwen3_vl_moe": self._estimate_qwen2_moe_flops,
         }

         if config.model_type not in _ESTIMATE_FUNC:
@@ -115,6 +119,44 @@ def _estimate_llama_flops(self, tokens_sum: int, batch_seqlens: List[int], delta_time: float) -> float:
         flops_achieved = flops_all_token * (1.0 / delta_time) / 1e12
         return flops_achieved

+    def _estimate_qwen2_moe_flops(self, tokens_sum: int, batch_seqlens: List[int], delta_time: float) -> float:
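+        # multimodal variants (the *_vl_moe model types above) keep the language-model
+        # hyperparameters under config.text_config, hence the fallback below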
+        config = self.config.text_config if hasattr(self.config, "text_config") else self.config
+        hidden_size = config.hidden_size
+        vocab_size = config.vocab_size
+        num_hidden_layers = config.num_hidden_layers
+        num_key_value_heads = config.num_key_value_heads
+        num_attention_heads = config.num_attention_heads
+        moe_intermediate_size = config.moe_intermediate_size
+        moe_topk = config.num_experts_per_tok
+        num_experts = config.num_experts
+
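+        # GQA: query projections use num_attention_heads, key/value projections use num_key_value_heads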
+        head_dim = getattr(config, "head_dim", hidden_size // num_attention_heads)
+        q_size = num_attention_heads * head_dim
+        k_size = num_key_value_heads * head_dim
+        v_size = num_key_value_heads * head_dim
+
+        # non-attn per-layer params
+        # router gate + activated moe experts
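+        # (router: hidden_size * num_experts params; each of the moe_topk activated experts has
+        # gate/up/down projections totalling 3 * hidden_size * moe_intermediate_size params)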
+        moe_mlp_N = hidden_size * moe_topk * moe_intermediate_size * 3 + hidden_size * num_experts
+        attn_linear_N = hidden_size * (q_size + k_size + v_size + num_attention_heads * head_dim)
+        emd_and_lm_head_N = vocab_size * hidden_size * 2
+        # non-attn all_layer params
+        dense_N = (moe_mlp_N + attn_linear_N) * num_hidden_layers + emd_and_lm_head_N
+        # non-attn all_layer & all_token fwd & bwd flops
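+        # 6 = 2 FLOPs per param per token (multiply-add) in fwd, plus ~2x that in bwd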
+        dense_N_flops = 6 * dense_N * tokens_sum
+
+        # attn all_layer & all_token fwd & bwd flops
+        seqlen_square_sum = 0
+        for seqlen in batch_seqlens:
+            seqlen_square_sum += seqlen * seqlen
+
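+        # per layer, the QK^T and attention-weighted-value matmuls cost ~4 * seqlen^2 * head_dim * num_heads
+        # FLOPs in fwd (quadratic in each sequence's length); the factor 12 = 4 * 3 also covers bwd (~2x fwd)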
+        attn_qkv_flops = 12 * seqlen_square_sum * head_dim * num_attention_heads * num_hidden_layers
+
+        # all_layer & all_token fwd & bwd flops
+        flops_all_token = dense_N_flops + attn_qkv_flops
+        flops_achieved = flops_all_token * (1.0 / delta_time) / 1e12
+        return flops_achieved
+
     def estimate_flops(self, batch_seqlens: List[int], delta_time: float) -> Tuple[float, float]:
         """
         Estimate the FLOPS based on the number of valid tokens in the current batch and the time taken.
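As a quick sanity check of the formula above, here is a minimal standalone sketch with made-up hyperparameters (illustrative values only, not the real Qwen MoE config); it reproduces the activated-parameter count and the resulting TFLOPs estimate outside of the class:

    # hypothetical config values, chosen only to exercise the formula
    hidden_size, vocab_size, num_hidden_layers = 2048, 150000, 24
    num_attention_heads, num_key_value_heads, head_dim = 16, 4, 128
    num_experts, moe_topk, moe_intermediate_size = 64, 8, 1408
    batch_seqlens, delta_time = [4096, 4096], 1.0
    tokens_sum = sum(batch_seqlens)

    moe_mlp_N = hidden_size * moe_topk * moe_intermediate_size * 3 + hidden_size * num_experts
    attn_linear_N = hidden_size * (2 * num_attention_heads + 2 * num_key_value_heads) * head_dim
    dense_N = (moe_mlp_N + attn_linear_N) * num_hidden_layers + vocab_size * hidden_size * 2
    dense_N_flops = 6 * dense_N * tokens_sum
    attn_qkv_flops = 12 * sum(s * s for s in batch_seqlens) * head_dim * num_attention_heads * num_hidden_layers
    print(f"~{dense_N / 1e9:.2f}B activated params, "
          f"~{(dense_N_flops + attn_qkv_flops) / delta_time / 1e12:.1f} TFLOPs achieved")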