
Commit da0399d

[misc] Update qwen_moe flops counter (#522)

Authored by Kuangdd01
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

1 parent 31143dc · commit da0399d

File tree: 3 files changed, +60 −4 lines changed

verl/models/monkey_patch.py

Lines changed: 12 additions & 4 deletions
@@ -12,12 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
 
 from ..utils.py_functional import is_transformers_version_greater_than
 from .transformers.flash_attention_utils import flash_attention_forward
-from .transformers.qwen2_vl import qwen2_vl_base_forward, qwen2_vl_model_forward
 
 
 SUPPORTED_MODEL_TYPE = (
@@ -55,6 +53,8 @@ def apply_ulysses_patch(model_type: str) -> None:
         )
         from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLForConditionalGeneration, Qwen2VLModel
 
+        from .transformers.qwen2_vl import qwen2_vl_base_forward, qwen2_vl_model_forward
+
         # fix text-image mixed data
         Qwen2VLModel.forward = qwen2_vl_base_forward
         Qwen2_5_VLModel.forward = qwen2_vl_base_forward
@@ -63,8 +63,16 @@ def apply_ulysses_patch(model_type: str) -> None:
         Qwen2_5_VLForConditionalGeneration.forward = qwen2_vl_model_forward
     elif model_type in QWEN3_VL_MODELS:
         from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLForConditionalGeneration, Qwen3VLModel
+        from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import (
+            Qwen3VLMoeForConditionalGeneration,
+            Qwen3VLMoeModel,
+        )
+
+        from .transformers.qwen3_vl import qwen3_vl_base_forward, qwen3_vl_model_forward
 
         # fix text-image mixed data
-        Qwen3VLModel.forward = qwen2_vl_base_forward
+        Qwen3VLModel.forward = qwen3_vl_base_forward
+        Qwen3VLMoeModel.forward = qwen3_vl_base_forward
         # TODO: add linear cross entropy kernels
-        Qwen3VLForConditionalGeneration.forward = qwen2_vl_model_forward
+        Qwen3VLForConditionalGeneration.forward = qwen3_vl_model_forward
+        Qwen3VLMoeForConditionalGeneration.forward = qwen3_vl_model_forward
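
Note on how the patch takes effect: reassigning the forward attribute on the imported transformers classes changes behavior for all instances, existing and future, which is how apply_ulysses_patch swaps in the sequence-parallel forward passes. A minimal sketch of the same class-level monkey-patching pattern, using hypothetical names (Greeter, patched_forward) that are not part of this commit:

# Class-level monkey patching: rebinding a method on the class changes
# behavior for every instance. Names below are hypothetical.

class Greeter:
    def forward(self, x: int) -> int:
        return x + 1  # original behavior


def patched_forward(self, x: int) -> int:
    return x * 2  # replacement behavior


Greeter.forward = patched_forward  # every Greeter now dispatches to the patch

g = Greeter()
assert g.forward(3) == 6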

verl/utils/flops_counter.py

Lines changed: 42 additions & 0 deletions
@@ -66,9 +66,13 @@ def __init__(self, config: "LlamaConfig"):
         _ESTIMATE_FUNC = {
             "llama": self._estimate_llama_flops,
             "qwen2": self._estimate_llama_flops,
+            "qwen2_moe": self._estimate_qwen2_moe_flops,
             "qwen2_vl": self._estimate_llama_flops,
             "qwen2_5_vl": self._estimate_llama_flops,
             "qwen3": self._estimate_llama_flops,
+            "qwen3_vl": self._estimate_llama_flops,
+            "qwen3_moe": self._estimate_qwen2_moe_flops,
+            "qwen3_vl_moe": self._estimate_qwen2_moe_flops,
         }
 
         if config.model_type not in _ESTIMATE_FUNC:
@@ -115,6 +119,44 @@ def _estimate_llama_flops(self, tokens_sum: int, batch_seqlens: List[int], delta
         flops_achieved = flops_all_token * (1.0 / delta_time) / 1e12
         return flops_achieved
 
+    def _estimate_qwen2_moe_flops(self, tokens_sum: int, batch_seqlens: List[int], delta_time: float) -> float:
+        config = self.config.text_config if hasattr(self.config, "text_config") else self.config
+        hidden_size = config.hidden_size
+        vocab_size = config.vocab_size
+        num_hidden_layers = config.num_hidden_layers
+        num_key_value_heads = config.num_key_value_heads
+        num_attention_heads = config.num_attention_heads
+        moe_intermediate_size = config.moe_intermediate_size
+        moe_topk = config.num_experts_per_tok
+        num_experts = config.num_experts
+
+        head_dim = getattr(config, "head_dim", hidden_size // num_attention_heads)
+        q_size = num_attention_heads * head_dim
+        k_size = num_key_value_heads * head_dim
+        v_size = num_key_value_heads * head_dim
+
+        # non-attn per-layer params:
+        # router gate + activated moe experts
+        moe_mlp_N = hidden_size * moe_topk * moe_intermediate_size * 3 + hidden_size * num_experts
+        attn_linear_N = hidden_size * (q_size + k_size + v_size + num_attention_heads * head_dim)
+        emd_and_lm_head_N = vocab_size * hidden_size * 2
+        # non-attn all-layer params
+        dense_N = (moe_mlp_N + attn_linear_N) * num_hidden_layers + emd_and_lm_head_N
+        # non-attn all-layer & all-token fwd & bwd flops
+        dense_N_flops = 6 * dense_N * tokens_sum
+
+        # attn all-layer & all-token fwd & bwd flops
+        seqlen_square_sum = 0
+        for seqlen in batch_seqlens:
+            seqlen_square_sum += seqlen * seqlen
+
+        attn_qkv_flops = 12 * seqlen_square_sum * head_dim * num_attention_heads * num_hidden_layers
+
+        # all-layer & all-token fwd & bwd flops
+        flops_all_token = dense_N_flops + attn_qkv_flops
+        flops_achieved = flops_all_token * (1.0 / delta_time) / 1e12
+        return flops_achieved
+
     def estimate_flops(self, batch_seqlens: List[int], delta_time: float) -> Tuple[float, float]:
         """
         Estimate the FLOPS based on the number of valid tokens in the current batch and the time taken.
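
Note on the formula: the MoE estimator mirrors the dense one. Linear-layer work is approximated as 6 * params * tokens (2 FLOPs per multiply-add forward, twice that backward), but only the top-k activated expert MLPs count toward per-layer parameters, plus the router's hidden_size * num_experts gate; attention-score work still scales with the square of each sequence length. A self-contained sketch of the arithmetic with illustrative config values (stand-ins, not the numbers of any released checkpoint):

# Worked example of the MoE FLOPs estimate above; all values illustrative.
hidden_size = 2048
vocab_size = 151936
num_hidden_layers = 48
num_attention_heads = 32
num_key_value_heads = 4
head_dim = 128
moe_intermediate_size = 768
moe_topk = 8        # num_experts_per_tok: only activated experts do MLP work
num_experts = 128   # the router gate still scores every expert

q_size = num_attention_heads * head_dim
k_size = num_key_value_heads * head_dim
v_size = num_key_value_heads * head_dim

# activated expert MLPs (gate/up/down projections) + router gate
moe_mlp_N = hidden_size * moe_topk * moe_intermediate_size * 3 + hidden_size * num_experts
attn_linear_N = hidden_size * (q_size + k_size + v_size + num_attention_heads * head_dim)
emb_and_lm_head_N = vocab_size * hidden_size * 2
dense_N = (moe_mlp_N + attn_linear_N) * num_hidden_layers + emb_and_lm_head_N

batch_seqlens = [4096] * 8  # eight packed sequences in this step
tokens_sum = sum(batch_seqlens)
delta_time = 10.0           # seconds spent on the step (illustrative)

dense_N_flops = 6 * dense_N * tokens_sum  # 2 FLOPs/MAC fwd + 4 bwd
attn_qkv_flops = 12 * sum(s * s for s in batch_seqlens) * head_dim * num_attention_heads * num_hidden_layers

tflops_achieved = (dense_N_flops + attn_qkv_flops) / delta_time / 1e12
print(f"achieved: {tflops_achieved:.1f} TFLOPS")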

verl/utils/logger/logger.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,12 @@ def __init__(self, config: dict[str, Any]) -> None:
6969
with open(os.path.join(config["trainer"]["save_checkpoint_path"], "experiment_config.json"), "w") as f:
7070
json.dump(config, f, indent=2)
7171

72+
with open(os.path.join(config["trainer"]["save_checkpoint_path"], "experiment_log.jsonl"), "w") as f:
73+
pass
74+
75+
with open(os.path.join(config["trainer"]["save_checkpoint_path"], "generations.log"), "w") as f:
76+
pass
77+
7278
def log(self, data: dict[str, Any], step: int) -> None:
7379
with open(os.path.join(self.config["trainer"]["save_checkpoint_path"], "experiment_log.jsonl"), "a") as f:
7480
f.write(json.dumps({"step": step, **unflatten_dict(data)}) + "\n")
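
Note on the idiom: the new open(..., "w") blocks exist only to create the log files up front and truncate stale ones left by a previous run, since log() only ever appends. A minimal sketch of the create-then-append pattern, with an illustrative path not taken from this commit:

import json
import os

path = os.path.join("/tmp/example_ckpt", "experiment_log.jsonl")  # illustrative path
os.makedirs(os.path.dirname(path), exist_ok=True)

# "w" mode creates the file, or truncates a stale one from an earlier run;
# closing immediately leaves it empty
with open(path, "w"):
    pass

# later appends then start from a clean file
with open(path, "a") as f:
    f.write(json.dumps({"step": 1, "loss": 0.42}) + "\n")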
