Description
使用默认脚本拉起 8 卡训练任务，跑完 3 个 step 后报错如下。看上去是 seq packing 时出错，想问下如何解决？
Traceback (most recent call last):
File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/Skywork-OR1/verl/trainer/main_ppo.py", line 147, in <module>
main()
File "/usr/local/lib/python3.10/dist-packages/hydra/main.py", line 94, in decorated_main
_run_hydra(
File "/usr/local/lib/python3.10/dist-packages/hydra/_internal/utils.py", line 394, in _run_hydra
_run_app(
File "/usr/local/lib/python3.10/dist-packages/hydra/_internal/utils.py", line 457, in _run_app
run_and_report(
File "/usr/local/lib/python3.10/dist-packages/hydra/_internal/utils.py", line 223, in run_and_report
raise ex
File "/usr/local/lib/python3.10/dist-packages/hydra/_internal/utils.py", line 220, in run_and_report
return func()
File "/usr/local/lib/python3.10/dist-packages/hydra/_internal/utils.py", line 458, in <lambda>
lambda: hydra.run(
File "/usr/local/lib/python3.10/dist-packages/hydra/_internal/hydra.py", line 132, in run
_ = ret.return_value
File "/usr/local/lib/python3.10/dist-packages/hydra/core/utils.py", line 260, in return_value
raise self._return_value
File "/usr/local/lib/python3.10/dist-packages/hydra/core/utils.py", line 186, in run_job
ret.return_value = task_function(task_cfg)
File "/Skywork-OR1/verl/trainer/main_ppo.py", line 26, in main
run_ppo(config, code_compute_score)
File "/Skywork-OR1/verl/trainer/main_ppo.py", line 46, in run_ppo
ray.get(main_task.remote(config, compute_score))
File "/usr/local/lib/python3.10/dist-packages/ray/_private/auto_init_hook.py", line 21, in auto_init_wrapper
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
return func(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/ray/_private/worker.py", line 2822, in get
values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
File "/usr/local/lib/python3.10/dist-packages/ray/_private/worker.py", line 930, in get_objects
raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(AssertionError): ray::main_task() (pid=42437, ip=172.17.0.2)
File "/Skywork-OR1/verl/trainer/main_ppo.py", line 143, in main_task
trainer.fit()
File "/Skywork-OR1/verl/trainer/ppo/ray_trainer.py", line 1101, in fit
old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
File "/Skywork-OR1/verl/single_controller/ray/base.py", line 41, in func
if blocking: output = ray.get(output)
ray.exceptions.RayTaskError(AssertionError): ray::WorkerDict.actor_rollout_compute_log_prob() (pid=43189, ip=172.17.0.2, actor_id=f1eed443cb5fb73601366e3601000000, repr=<verl.single_controller.ray.base.WorkerDict object at 0x7fc6c8a81060>)
File "/Skywork-OR1/verl/single_controller/ray/base.py", line 398, in func
return getattr(self.worker_dict[key], name)(*args, **kwargs)
File "/Skywork-OR1/verl/single_controller/base/decorator.py", line 404, in inner
return func(*args, **kwargs)
File "/Skywork-OR1/verl/workers/fsdp_workers.py", line 499, in compute_log_prob
output = self.actor.compute_log_prob(data=data)
File "/Skywork-OR1/verl/workers/actor/dp_actor.py", line 184, in compute_log_prob
micro_batches, indices = rearrange_micro_batches(batch=batch, max_token_len=max_token_len)
File "/Skywork-OR1/verl/utils/seqlen_balancing.py", line 242, in rearrange_micro_batches
assert num_micro_batches <= len(seq_len_effective)
AssertionError