Skip to content

Commit 87e862a

Browse files
authored
[trainer] support auto resume (#425)
1 parent 5ab4bd3 commit 87e862a

File tree

8 files changed

+104
-45
lines changed

8 files changed

+104
-45
lines changed

examples/config.yaml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,8 @@ worker:
6969
tensor_parallel_size: 2
7070
disable_tqdm: false
7171
val_override_config:
72-
temperature: 1.0
72+
temperature: 0.6
73+
top_p: 0.95
7374
n: 1
7475

7576
ref:
@@ -102,3 +103,4 @@ trainer:
102103
save_model_only: false
103104
save_checkpoint_path: null
104105
load_checkpoint_path: null
106+
find_last_checkpoint: true

tests/test_checkpoint.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
import json
17+
import os
18+
import shutil
19+
import uuid
20+
21+
import pytest
22+
23+
from verl.utils.checkpoint import CHECKPOINT_TRACKER, find_latest_ckpt, remove_obsolete_ckpt
24+
25+
26+
@pytest.fixture
27+
def save_checkpoint_path():
28+
ckpt_dir = os.path.join("checkpoints", str(uuid.uuid4()))
29+
os.makedirs(ckpt_dir, exist_ok=True)
30+
yield ckpt_dir
31+
shutil.rmtree(ckpt_dir, ignore_errors=True)
32+
33+
34+
def test_find_latest_ckpt(save_checkpoint_path):
35+
with open(os.path.join(save_checkpoint_path, CHECKPOINT_TRACKER), "w") as f:
36+
json.dump({"last_global_step": 10}, f, ensure_ascii=False, indent=2)
37+
38+
assert find_latest_ckpt(save_checkpoint_path) is None
39+
os.makedirs(os.path.join(save_checkpoint_path, "global_step_10"), exist_ok=True)
40+
assert find_latest_ckpt(save_checkpoint_path) == os.path.join(save_checkpoint_path, "global_step_10")
41+
42+
43+
def test_remove_obsolete_ckpt(save_checkpoint_path):
44+
for step in range(5, 30, 5):
45+
os.makedirs(os.path.join(save_checkpoint_path, f"global_step_{step}"), exist_ok=True)
46+
47+
remove_obsolete_ckpt(save_checkpoint_path, global_step=30, best_global_step=10, save_limit=3)
48+
for step in range(5, 30, 5):
49+
is_exist = step in [10, 25]
50+
assert os.path.exists(os.path.join(save_checkpoint_path, f"global_step_{step}")) == is_exist

verl/trainer/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,8 @@ class TrainerConfig:
141141
"""save checkpoint path, if not specified, use `checkpoints/project_name/experiment_name`"""
142142
load_checkpoint_path: Optional[str] = None
143143
"""load checkpoint path"""
144+
find_last_checkpoint: bool = True
145+
"""automatically find the last checkpoint in the save checkpoint path to resume training"""
144146

145147
def post_init(self):
146148
if self.save_checkpoint_path is None:

verl/trainer/metrics.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,27 +26,22 @@ def reduce_metrics(metrics: Dict[str, List[Any]]) -> Dict[str, Any]:
2626

2727
def compute_length_metrics(batch: DataProto) -> Dict[str, Any]:
2828
max_response_length = batch.batch["responses"].size(-1)
29+
max_prompt_length = batch.batch["attention_mask"].size(-1) - max_response_length
2930

30-
prompt_mask = batch.batch["attention_mask"][:, :-max_response_length].bool()
31-
response_mask = batch.batch["attention_mask"][:, -max_response_length:].bool()
32-
33-
max_prompt_length = prompt_mask.size(-1)
34-
prompt_length = prompt_mask.sum(-1).float()
35-
response_length = response_mask.sum(-1).float()
31+
prompt_length = batch.batch["attention_mask"][:, :-max_response_length].sum(-1).float()
32+
response_length = batch.batch["attention_mask"][:, -max_response_length:].sum(-1).float()
3633

3734
return {
3835
# response length
3936
"response_length/mean": torch.mean(response_length).detach().item(),
4037
"response_length/max": torch.max(response_length).detach().item(),
4138
"response_length/min": torch.min(response_length).detach().item(),
42-
"response_length/clip_ratio": torch.mean(torch.eq(response_length, max_response_length).float())
43-
.detach()
44-
.item(),
39+
"response_length/clip_ratio": torch.eq(response_length, max_response_length).float().mean().detach().item(),
4540
# prompt length
4641
"prompt_length/mean": torch.mean(prompt_length).detach().item(),
4742
"prompt_length/max": torch.max(prompt_length).detach().item(),
4843
"prompt_length/min": torch.min(prompt_length).detach().item(),
49-
"prompt_length/clip_ratio": torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(),
44+
"prompt_length/clip_ratio": torch.eq(prompt_length, max_prompt_length).float().mean().detach().item(),
5045
}
5146

5247

verl/trainer/ray_trainer.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
from ..single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
3838
from ..single_controller.ray.base import create_colocated_worker_cls
3939
from ..utils import torch_functional as VF
40-
from ..utils.checkpoint import CHECKPOINT_TRACKER, remove_obsolete_ckpt
40+
from ..utils.checkpoint import CHECKPOINT_TRACKER, find_latest_ckpt, remove_obsolete_ckpt
4141
from ..utils.logger import Tracker
4242
from ..utils.py_functional import convert_dict_to_str, timer
4343
from ..utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance
@@ -342,21 +342,28 @@ def _save_checkpoint(self) -> None:
342342
json.dump(checkpointer_tracker_info, f, ensure_ascii=False, indent=2)
343343

344344
def _load_checkpoint(self) -> None:
345-
if self.config.trainer.load_checkpoint_path is None:
345+
if self.config.trainer.load_checkpoint_path is not None:
346+
load_checkpoint_path = self.config.trainer.load_checkpoint_path
347+
elif self.config.trainer.find_last_checkpoint:
348+
load_checkpoint_path = find_latest_ckpt(self.config.trainer.save_checkpoint_path)
349+
else:
350+
load_checkpoint_path = None
351+
352+
if load_checkpoint_path is None:
346353
return
347354

348-
if "global_step_" not in self.config.trainer.load_checkpoint_path.strip(os.path.sep).split(os.path.sep)[-1]:
355+
if "global_step_" not in load_checkpoint_path.strip(os.path.sep).split(os.path.sep)[-1]:
349356
raise ValueError("`load_checkpoint_path` should end with `global_step_*`.")
350357

351-
print(f"Load from checkpoint: {self.config.trainer.load_checkpoint_path}.")
352-
self.global_step = int(self.config.trainer.load_checkpoint_path.strip(os.path.sep).split("global_step_")[-1])
353-
actor_path = os.path.join(self.config.trainer.load_checkpoint_path, "actor")
358+
print(f"Load from checkpoint: {load_checkpoint_path}.")
359+
self.global_step = int(load_checkpoint_path.strip(os.path.sep).split("global_step_")[-1])
360+
actor_path = os.path.join(load_checkpoint_path, "actor")
354361
self.actor_rollout_ref_wg.load_checkpoint(actor_path)
355362
if self.use_critic:
356-
critic_path = os.path.join(self.config.trainer.load_checkpoint_path, "critic")
363+
critic_path = os.path.join(load_checkpoint_path, "critic")
357364
self.critic_wg.load_checkpoint(critic_path)
358365

359-
dataloader_path = os.path.join(self.config.trainer.load_checkpoint_path, "dataloader.pt")
366+
dataloader_path = os.path.join(load_checkpoint_path, "dataloader.pt")
360367
if os.path.exists(dataloader_path):
361368
dataloader_state_dict = torch.load(dataloader_path, weights_only=False)
362369
self.train_dataloader.load_state_dict(dataloader_state_dict)

verl/utils/checkpoint/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from .checkpoint_manager import CHECKPOINT_TRACKER, remove_obsolete_ckpt
15+
from .checkpoint_manager import CHECKPOINT_TRACKER, find_latest_ckpt, remove_obsolete_ckpt
1616

1717

18-
__all__ = ["CHECKPOINT_TRACKER", "remove_obsolete_ckpt"]
18+
__all__ = ["CHECKPOINT_TRACKER", "find_latest_ckpt", "remove_obsolete_ckpt"]

verl/utils/checkpoint/checkpoint_manager.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import json
1516
import os
1617
import random
1718
import re
@@ -107,39 +108,38 @@ def load_rng_state(rng_state: Dict[str, Any]):
107108
random.setstate(rng_state["random"])
108109

109110

110-
def find_latest_ckpt_path(path: Optional[str] = None, directory_format: str = "global_step_{}") -> Optional[str]:
111-
if path is None:
112-
return None
111+
def get_checkpoint_tracker_filename(root_path: str) -> str:
112+
"""
113+
Tracker file records the latest checkpoint during training to restart from.
114+
"""
115+
return os.path.join(root_path, CHECKPOINT_TRACKER)
113116

117+
118+
def find_latest_ckpt(path: str, directory_format: str = "global_step_{}") -> Optional[str]:
119+
"""
120+
Find the latest checkpoint in the save path.
121+
"""
114122
tracker_file = get_checkpoint_tracker_filename(path)
115123
if not os.path.exists(tracker_file):
116-
print("Checkpoint tracker file does not exist: %s", tracker_file)
117124
return None
118125

119126
with open(tracker_file, "rb") as f:
120-
iteration = int(f.read().decode())
127+
checkpointer_tracker_info = json.load(f)
121128

122-
ckpt_path = os.path.join(path, directory_format.format(iteration))
129+
ckpt_path = os.path.join(path, directory_format.format(checkpointer_tracker_info["last_global_step"]))
123130
if not os.path.exists(ckpt_path):
124-
print("Checkpoint does not exist: %s", ckpt_path)
131+
print(f"Checkpoint does not exist: {ckpt_path}")
125132
return None
126133

127-
print("Found checkpoint: %s", ckpt_path)
134+
print(f"Found latest checkpoint: {ckpt_path}, will resume from it. Turn off `find_last_checkpoint` to disable it.")
128135
return ckpt_path
129136

130137

131-
def get_checkpoint_tracker_filename(root_path: str) -> str:
132-
"""
133-
Tracker file records the latest checkpoint during training to restart from.
134-
"""
135-
return os.path.join(root_path, CHECKPOINT_TRACKER)
136-
137-
138138
def remove_obsolete_ckpt(
139139
path: str, global_step: int, best_global_step: int, save_limit: int = -1, directory_format: str = "global_step_{}"
140140
):
141141
"""
142-
Remove the obsolete checkpoints that exceed the save_limit.
142+
Remove the obsolete checkpoints that exceed the save limit.
143143
"""
144144
if save_limit <= 0 or not os.path.exists(path):
145145
return

verl/workers/reward/function.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -84,20 +84,21 @@ def compute_reward(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, List[
8484
reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
8585
reward_metrics = defaultdict(list)
8686
response_ids = data.batch["responses"]
87-
response_length = data.batch["response_mask"].sum(dim=-1)
87+
response_length = torch.sum(data.batch["response_mask"], dim=-1)
8888
for i in range(len(data)):
89-
valid_response_ids = response_ids[i][: response_length[i]]
89+
cur_response_length = int(response_length[i].item()) # avoid tensor indexing error
90+
valid_response_ids = response_ids[i][:cur_response_length]
9091
response_str = self.tokenizer.decode(
9192
valid_response_ids, skip_special_tokens=self.config.skip_special_tokens
9293
)
9394
score = self.reward_fn(
9495
{
9596
"response": response_str,
96-
"response_length": response_length[i],
97+
"response_length": cur_response_length,
9798
"ground_truth": data.non_tensor_batch["ground_truth"][i],
9899
}
99100
)
100-
reward_tensor[i, response_length[i] - 1] = score["overall"]
101+
reward_tensor[i, cur_response_length - 1] = score["overall"]
101102
for key, value in score.items():
102103
reward_metrics[key].append(value)
103104

@@ -110,16 +111,17 @@ class BatchFunctionRewardManager(FunctionRewardManager):
110111
def compute_reward(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, List[float]]]:
111112
reward_inputs = []
112113
response_ids = data.batch["responses"]
113-
response_length = data.batch["response_mask"].sum(dim=-1)
114+
response_length = torch.sum(data.batch["response_mask"], dim=-1)
114115
for i in range(len(data)):
115-
valid_response_ids = response_ids[i][: response_length[i]]
116+
cur_response_length = int(response_length[i].item()) # avoid tensor indexing error
117+
valid_response_ids = response_ids[i][:cur_response_length]
116118
response_str = self.tokenizer.decode(
117119
valid_response_ids, skip_special_tokens=self.config.skip_special_tokens
118120
)
119121
reward_inputs.append(
120122
{
121123
"response": response_str,
122-
"response_length": response_length[i],
124+
"response_length": cur_response_length,
123125
"ground_truth": data.non_tensor_batch["ground_truth"][i],
124126
}
125127
)
@@ -128,7 +130,8 @@ def compute_reward(self, data: DataProto) -> Tuple[torch.Tensor, Dict[str, List[
128130
reward_tensor = torch.zeros_like(data.batch["responses"], dtype=torch.float32)
129131
reward_metrics = defaultdict(list)
130132
for i, score in enumerate(scores):
131-
reward_tensor[i, response_length[i] - 1] = score["overall"]
133+
cur_response_length = int(response_length[i].item()) # avoid tensor indexing error
134+
reward_tensor[i, cur_response_length - 1] = score["overall"]
132135
for key, value in score.items():
133136
reward_metrics[key].append(value)
134137

0 commit comments

Comments
 (0)