
7B 8k training fails with: assert num_micro_batches <= len(seq_len_effective) #38


Open
HongyeZhou opened this issue May 15, 2025 · 6 comments


@HongyeZhou

I launched an 8-GPU training job with the default script. After 3 steps it fails with the error below. It looks like something goes wrong during sequence packing; how can I fix this?

Traceback (most recent call last):
File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/Skywork-OR1/verl/trainer/main_ppo.py", line 147, in
main()
File "/usr/local/lib/python3.10/dist-packages/hydra/main.py", line 94, in decorated_main
_run_hydra(
File "/usr/local/lib/python3.10/dist-packages/hydra/_internal/utils.py", line 394, in _run_hydra
_run_app(
File "/usr/local/lib/python3.10/dist-packages/hydra/_internal/utils.py", line 457, in _run_app
run_and_report(
File "/usr/local/lib/python3.10/dist-packages/hydra/_internal/utils.py", line 223, in run_and_report
raise ex
File "/usr/local/lib/python3.10/dist-packages/hydra/_internal/utils.py", line 220, in run_and_report
return func()
File "/usr/local/lib/python3.10/dist-packages/hydra/_internal/utils.py", line 458, in
lambda: hydra.run(
File "/usr/local/lib/python3.10/dist-packages/hydra/_internal/hydra.py", line 132, in run
_ = ret.return_value
File "/usr/local/lib/python3.10/dist-packages/hydra/core/utils.py", line 260, in return_value
raise self._return_value
File "/usr/local/lib/python3.10/dist-packages/hydra/core/utils.py", line 186, in run_job
ret.return_value = task_function(task_cfg)
File "/Skywork-OR1/verl/trainer/main_ppo.py", line 26, in main
run_ppo(config, code_compute_score)
File "/Skywork-OR1/verl/trainer/main_ppo.py", line 46, in run_ppo
ray.get(main_task.remote(config, compute_score))
File "/usr/local/lib/python3.10/dist-packages/ray/_private/auto_init_hook.py", line 21, in auto_init_wrapper
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
return func(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/ray/_private/worker.py", line 2822, in get
values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
File "/usr/local/lib/python3.10/dist-packages/ray/_private/worker.py", line 930, in get_objects
raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(AssertionError): ray::main_task() (pid=42437, ip=172.17.0.2)
File "/Skywork-OR1/verl/trainer/main_ppo.py", line 143, in main_task
trainer.fit()
File "/Skywork-OR1/verl/trainer/ppo/ray_trainer.py", line 1101, in fit
old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
File "/Skywork-OR1/verl/single_controller/ray/base.py", line 41, in func
if blocking: output = ray.get(output)
ray.exceptions.RayTaskError(AssertionError): ray::WorkerDict.actor_rollout_compute_log_prob() (pid=43189, ip=172.17.0.2, actor_id=f1eed443cb5fb73601366e3601000000, repr=<verl.single_controller.ray.base.WorkerDict object at 0x7fc6c8a81060>)
File "/Skywork-OR1/verl/single_controller/ray/base.py", line 398, in func
return getattr(self.worker_dict[key], name)(*args, **kwargs)
File "/Skywork-OR1/verl/single_controller/base/decorator.py", line 404, in inner
return func(*args, **kwargs)
File "/Skywork-OR1/verl/workers/fsdp_workers.py", line 499, in compute_log_prob
output = self.actor.compute_log_prob(data=data)
File "/Skywork-OR1/verl/workers/actor/dp_actor.py", line 184, in compute_log_prob
micro_batches, indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len)
File "/Skywork-OR1/verl/utils/seqlen_balancing.py", line 242, in rearrange_micro_batches
assert num_micro_batches <= len(seq_len_effective)
AssertionError
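For context, here is a minimal sketch (hypothetical function name, modeled on the failing assertion in verl/utils/seqlen_balancing.py) of how this check can trip: under dynamic batch size, the number of micro-batches is derived from total tokens over max_token_len, and each micro-batch must get at least one sequence, so a single sequence longer than max_token_len requests more micro-batches than there are sequences:

```python
import math

# Hypothetical sketch of the check behind the failing assertion:
# dynamic batching splits a batch into micro-batches of at most
# max_token_len tokens, and each micro-batch needs at least one sequence.
def partition_is_feasible(seq_len_effective, max_token_len):
    total_tokens = sum(seq_len_effective)
    num_micro_batches = math.ceil(total_tokens / max_token_len)
    # the failing check: assert num_micro_batches <= len(seq_len_effective)
    return num_micro_batches <= len(seq_len_effective)

# 4 sequences totalling 14k tokens fit into 2 micro-batches of <= 8192 tokens
print(partition_is_feasible([4000, 4000, 3000, 3000], 8192))  # True

# a single 9000-token sequence demands ceil(9000/8192) = 2 micro-batches,
# but only 1 sequence exists, so the assertion would fire
print(partition_is_feasible([9000], 8192))  # False
```

If this is the failure mode here, keeping ppo_max_token_len_per_gpu at least max_prompt_length + max_response_length would avoid it.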

@hejujie
Collaborator

hejujie commented May 15, 2025

I don't think I've seen this error before; could you share your config?

@HongyeZhou
Author

@hejujie Sure. I'm using the default code from the main branch with no modifications. Here is my launch script:
python3 -m verl.trainer.main_ppo \
  algorithm.adv_estimator=grpo \
  data.train_files=$train_files \
  data.val_files=$test_files \
  data.train_batch_size=$ROLLOUT_BATCH_SIZE \
  data.val_batch_size=13000 \
  data.max_prompt_length=$MAX_PROMPT_LENGTH \
  data.max_response_length=$RES_LENGTH \
  actor_rollout_ref.model.path=$MODEL_PATH \
  actor_rollout_ref.model.use_remove_padding=True \
  actor_rollout_ref.actor.use_dynamic_bsz=True \
  actor_rollout_ref.model.enable_gradient_checkpointing=True \
  actor_rollout_ref.actor.optim.lr=1e-6 \
  actor_rollout_ref.actor.entropy_coeff=$ENTROPY_COEFF \
  actor_rollout_ref.actor.ppo_mini_batch_size=$PPO_MINI_BATCH \
  actor_rollout_ref.actor.ppo_max_token_len_per_gpu=$MAX_TOKEN_LEN \
  actor_rollout_ref.actor.ulysses_sequence_parallel_size=$SP \
  actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \
  actor_rollout_ref.actor.fsdp_config.param_offload=True \
  actor_rollout_ref.actor.fsdp_config.grad_offload=True \
  actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
  actor_rollout_ref.actor.adaptive_entropy.enabled=$USE_ADAPTIVE_ENT \
  actor_rollout_ref.actor.adaptive_entropy.target_entropy=${TGT_ENTROPY} \
  actor_rollout_ref.actor.adaptive_entropy.max_ent_coef=${MAX_ENT_COEF} \
  actor_rollout_ref.actor.adaptive_entropy.min_ent_coef=${MIN_ENT_COEF} \
  actor_rollout_ref.actor.adaptive_entropy.delta_ent_coef=${DELTA_ENT_COEF} \
  actor_rollout_ref.ref.fsdp_config.param_offload=True \
  actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \
  actor_rollout_ref.rollout.tensor_model_parallel_size=$TP \
  actor_rollout_ref.rollout.name=vllm \
  actor_rollout_ref.rollout.temperature=$TRAIN_TEMPERATURE \
  actor_rollout_ref.rollout.val_temperature=0.6 \
  actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \
  actor_rollout_ref.rollout.n=$GROUP_SIZE \
  actor_rollout_ref.rollout.n_val=$N_VAL_SAMPLES \
  actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \
  reward_model.reward_manager=yr \
  trainer.critic_warmup=0 \
  trainer.rejection_sample=True \
  trainer.rejection_sample_multiplier=1 \
  trainer.logger=['console'] \
  trainer.project_name=$PROJECT_NAME \
  trainer.experiment_name=$EXP_NAME \
  trainer.val_before_train=False \
  trainer.n_gpus_per_node=8 \
  trainer.nnodes=$WORLD_SIZE \
  trainer.save_freq=20 \
  trainer.test_freq=20 \
  trainer.stats_path=$SAVE_STATS_DIR \
  trainer.stats_save_freq=1 \
  trainer.default_local_dir=$SAVE_DIR \
  trainer.default_hdfs_dir=null \
  trainer.total_epochs=30 "${@:1}"

@hejujie
Collaborator

hejujie commented May 17, 2025


I don't see anything obviously wrong; we probably need more information / some debug prints. You could check whether the following helps:

  1. Have MAX_TOKEN_LEN, ROLLOUT_BATCH_SIZE, or PPO_MINI_BATCH been changed from the defaults?
  2. Consider turning off dynamic batch size; it has little impact on speed. See https://verl.readthedocs.io/en/latest/perf/perf_tuning.html#tuning-for-dynamic-batch-size for how to adjust the config.
  3. I'm not sure whether this is related to rejection sampling; if 1 and 2 don't help, try disabling rejection sampling.
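For suggestion 2, a sketch of the override that would disable dynamic batch size in the launch script above (only a flag already present in the script is flipped; `OTHER_OVERRIDES` is a hypothetical placeholder for the rest of the flags):

```shell
# Disable dynamic batching; the fixed micro-batch sizes already set in the
# script (ppo_micro_batch_size_per_gpu=1, log_prob_micro_batch_size_per_gpu=1)
# then take effect instead.
python3 -m verl.trainer.main_ppo \
  actor_rollout_ref.actor.use_dynamic_bsz=False \
  "${OTHER_OVERRIDES[@]}"  # hypothetical: the remaining overrides unchanged
```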

@HongyeZhou
Author

@hejujie OK, thanks, I'll give it a try.

@HongyeZhou
Author

@hejujie Hi, I disabled dynamic batch size in the training script. The earlier error no longer appears after step 3, but a new one is raised instead. Any ideas on how to resolve it?

  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/Skywork-OR1/verl/trainer/main_ppo.py", line 147, in <module>
    main()
  File "/usr/local/lib/python3.10/dist-packages/hydra/main.py", line 94, in decorated_main
    _run_hydra(
  File "/usr/local/lib/python3.10/dist-packages/hydra/_internal/utils.py", line 394, in _run_hydra
    _run_app(
  File "/usr/local/lib/python3.10/dist-packages/hydra/_internal/utils.py", line 457, in _run_app
    run_and_report(
  File "/usr/local/lib/python3.10/dist-packages/hydra/_internal/utils.py", line 223, in run_and_report
    raise ex
  File "/usr/local/lib/python3.10/dist-packages/hydra/_internal/utils.py", line 220, in run_and_report
    return func()
  File "/usr/local/lib/python3.10/dist-packages/hydra/_internal/utils.py", line 458, in <lambda>
    lambda: hydra.run(
  File "/usr/local/lib/python3.10/dist-packages/hydra/_internal/hydra.py", line 132, in run
    _ = ret.return_value
  File "/usr/local/lib/python3.10/dist-packages/hydra/core/utils.py", line 260, in return_value
    raise self._return_value
  File "/usr/local/lib/python3.10/dist-packages/hydra/core/utils.py", line 186, in run_job
    ret.return_value = task_function(task_cfg)
  File "/Skywork-OR1/verl/trainer/main_ppo.py", line 26, in main
    run_ppo(config, code_compute_score)
  File "/Skywork-OR1/verl/trainer/main_ppo.py", line 46, in run_ppo
    ray.get(main_task.remote(config, compute_score))
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/auto_init_hook.py", line 21, in auto_init_wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/worker.py", line 2822, in get
    values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/worker.py", line 930, in get_objects
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(RuntimeError): ray::main_task() (pid=11772, ip=172.17.0.2)
  File "/Skywork-OR1/verl/trainer/main_ppo.py", line 143, in main_task
    trainer.fit()
  File "/Skywork-OR1/verl/trainer/ppo/ray_trainer.py", line 1151, in fit
    actor_output = self.actor_rollout_wg.update_actor(batch)
  File "/Skywork-OR1/verl/single_controller/ray/base.py", line 41, in func
    if blocking:            output = ray.get(output)
ray.exceptions.RayTaskError(RuntimeError): ray::WorkerDict.actor_rollout_update_actor() (pid=12797, ip=172.17.0.2, actor_id=39e23822a2fa854c056bfb8401000000, repr=<verl.single_controller.ray.base.WorkerDict object at 0x7f771c704fa0>)
  File "/Skywork-OR1/verl/single_controller/ray/base.py", line 398, in func
    return getattr(self.worker_dict[key], name)(*args, **kwargs)
  File "/Skywork-OR1/verl/single_controller/base/decorator.py", line 404, in inner
    return func(*args, **kwargs)
  File "/Skywork-OR1/verl/workers/fsdp_workers.py", line 418, in update_actor
    metrics = self.actor.update_policy(data=data)
  File "/Skywork-OR1/verl/workers/actor/dp_actor.py", line 308, in update_policy
    loss.backward()
  File "/usr/local/lib/python3.10/dist-packages/torch/_tensor.py", line 521, in backward
    torch.autograd.backward(
  File "/usr/local/lib/python3.10/dist-packages/torch/autograd/__init__.py", line 289, in backward
    _engine_run_backward(
  File "/usr/local/lib/python3.10/dist-packages/torch/autograd/graph.py", line 768, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 306, in apply
    return user_fn(self, *args)
  File "/usr/local/lib/python3.10/dist-packages/liger_kernel/ops/utils.py", line 40, in wrapper
    return fn(ctx, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/liger_kernel/ops/rms_norm.py", line 354, in backward
    dX, dW = rms_norm_backward(
  File "/usr/local/lib/python3.10/dist-packages/liger_kernel/ops/rms_norm.py", line 280, in rms_norm_backward
    _rms_norm_backward_kernel[grid](
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 345, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 691, in run
    kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
  File "/usr/local/lib/python3.10/dist-packages/triton/compiler/compiler.py", line 381, in __getattribute__
    self._init_handles()
  File "/usr/local/lib/python3.10/dist-packages/triton/compiler/compiler.py", line 372, in _init_handles
    max_shared = driver.active.utils.get_device_properties(device)["max_shared_mem"]
RuntimeError: Triton Error [CUDA]: unknown error

@HongyeZhou
Author

I also noticed that during training the grader logs some errors, though training continues. Is this expected?
(main_task pid=11772) ERROR:2025-05-19 11:33:17,982:Error during comparison
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/math_verify/grader.py", line 809, in compare_single_extraction_wrapper
    return compare_single_extraction(g, t)
  File "/usr/local/lib/python3.10/dist-packages/math_verify/utils.py", line 51, in wrapper
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/math_verify/grader.py", line 789, in compare_single_extraction
    return sympy_expr_eq(
  File "/usr/local/lib/python3.10/dist-packages/math_verify/grader.py", line 667, in sympy_expr_eq
    return sympy_compare_relational(gold, pred, float_rounding, numeric_precision)
  File "/usr/local/lib/python3.10/dist-packages/math_verify/grader.py", line 344, in sympy_compare_relational
    if sympy_solve_and_compare(gold, pred, float_rounding, numeric_precision):
  File "/usr/local/lib/python3.10/dist-packages/math_verify/grader.py", line 274, in sympy_solve_and_compare
    solved_gold = list(ordered(solve(gold, gold.free_symbols)))
  File "/usr/local/lib/python3.10/dist-packages/sympy/solvers/solvers.py", line 948, in solve
    fi = fi.rewrite(Add, evaluate=False, deep=False)
  File "/usr/local/lib/python3.10/dist-packages/sympy/core/basic.py", line 1981, in rewrite
    return self._rewrite(pattern, rule, method, **hints)
  File "/usr/local/lib/python3.10/dist-packages/sympy/core/basic.py", line 1993, in _rewrite
    rewritten = meth(*args, **hints)
  File "/usr/local/lib/python3.10/dist-packages/sympy/core/relational.py", line 661, in _eval_rewrite_as_Add
    args = Add.make_args(L) + Add.make_args(-R)
TypeError: bad operand type for unary -: 'FiniteSet'
(main_task pid=11772) ERROR:2025-05-19 11:33:18,011:Error during comparison
(identical traceback as above)
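These grader errors come from math_verify handing sympy a relational whose right-hand side is a FiniteSet; sympy sets are not Expr objects and do not support unary minus, which can be reproduced in isolation (math_verify catches the exception and just scores that comparison as a failure, so training continues):

```python
from sympy import Add, FiniteSet

# sympy Sets define no __neg__, so negating one raises the same
# TypeError seen in the grader logs above
try:
    Add.make_args(-FiniteSet(1, 2))
except TypeError as exc:
    print(exc)  # bad operand type for unary -: 'FiniteSet'
```

So the errors are benign in the sense that training is unaffected, though the affected samples are graded as incorrect.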
