
Commit 75b07d2

[worker] fix grad norm (#423)
1 parent a4a4128 commit 75b07d2

File tree (5 files changed: +83 / -45 lines)

verl/protocol.py
verl/utils/seqlen_balancing.py
verl/workers/actor/dp_actor.py
verl/workers/critic/dp_critic.py
verl/workers/fsdp_workers.py

verl/protocol.py

Lines changed: 7 additions & 6 deletions
@@ -312,7 +312,7 @@ def from_dict(
             current_batch = tensor.shape[:num_batch_dims]
             assert batch_size == current_batch, (
                 f"Not all the tensor in tensors have the same batch size with batch_dims={num_batch_dims}. "
-                f"Got {pivot_key} has {batch_size}, {key} has {current_batch}"
+                f"Got {pivot_key} has {batch_size}, {key} has {current_batch}."
             )
 
         for key, value in non_tensors.items():
@@ -322,18 +322,19 @@ def from_dict(
         tensor_dict = TensorDict(source=tensors, batch_size=batch_size) if tensors else None
         return cls(batch=tensor_dict, non_tensor_batch=non_tensors, meta_info=meta_info)
 
-    def to(self, device: torch.device) -> "DataProto":
-        """move the batch to device
+    def to(self, device: torch.device, non_blocking: bool = True) -> "DataProto":
+        """Move the batch to device
 
         Args:
-            device (torch.device, str): torch device
+            device (torch.device): the device to move to.
+            non_blocking (bool, optional): whether to use non-blocking mode. Defaults to True.
 
         Returns:
-            DataProto: the current DataProto
+            DataProto: the current DataProto.
 
         """
         if self.batch is not None:
-            self.batch = self.batch.to(device)
+            self.batch = self.batch.to(device, non_blocking=non_blocking)
 
         return self
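
Note on the new `non_blocking` default: `DataProto.to` simply forwards the flag to the underlying `TensorDict.to`, so host-to-device copies can overlap with host-side work when the source tensors sit in pinned memory. A minimal sketch of the semantics with plain torch tensors (not code from this commit; variable names are illustrative):

import torch

# Pinned CPU memory is what makes a non_blocking copy genuinely asynchronous.
cpu_batch = torch.randn(4, 8).pin_memory() if torch.cuda.is_available() else torch.randn(4, 8)

if torch.cuda.is_available():
    gpu_batch = cpu_batch.to(torch.device("cuda"), non_blocking=True)  # returns immediately
    # ... other host-side work can run here while the copy is in flight ...
    torch.cuda.synchronize()  # make sure the copy has finished before relying on gpu_batch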

verl/utils/seqlen_balancing.py

Lines changed: 56 additions & 14 deletions
@@ -15,7 +15,7 @@
 import copy
 import heapq
 from itertools import chain
-from typing import List, Tuple
+from typing import Dict, List, Optional, Tuple
 
 import torch
 from tensordict import TensorDict
@@ -150,7 +150,7 @@ def greedy_partition(seqlen_list: List[int], k_partitions: int, equal_size: bool
     return partitions
 
 
-def get_seqlen_balanced_partitions(seqlen_list: List[int], k_partitions: int, equal_size: bool):
+def get_seqlen_balanced_partitions(seqlen_list: List[int], k_partitions: int, equal_size: bool) -> List[List[int]]:
     """Get order of seq lengths to make partitions balanced, this is
     used in balacing sum of seqlength across dp ranks and microbatches.
@@ -161,8 +161,7 @@ def get_seqlen_balanced_partitions(seqlen_list: List[int], k_partitions: int, eq
             resulting number of partitions
         equal_size (bool):
             if True, number of items in each partitions must be equal.
-            if False, only consider balancing the sum, each partition can have
-            variable number of items
+            if False, only consider balancing the sum, each partition can have variable number of items
 
     Returns:
         partitions (List[List[int]]):
@@ -186,14 +185,28 @@ def _check_and_sort_partitions(partitions):
     return _check_and_sort_partitions(partitions)
 
 
-def log_seqlen_unbalance(seqlen_list: List[int], partitions: List[List[int]], prefix):
-    # add some metrics of seqlen sum on dp ranks
+def log_seqlen_unbalance(seqlen_list: List[int], partitions: List[List[int]], prefix: str) -> Dict[str, float]:
+    """
+    Calculate and log metrics related to sequence length imbalance before and after partitioning.
+
+    Args:
+        seqlen_list (List[int]): A list of sequence lengths for each item.
+        partitions (List[List[int]]): A list of partitions, where each inner list contains indices
+            from seqlen_list assigned to that partition.
+        prefix (str): A prefix to be added to each metric key in the returned dictionary.
+
+    Returns:
+        dict: A dictionary containing metrics related to sequence length imbalance.
+    """
+    # Get the number of partitions
     k_partition = len(partitions)
     # assert len(seqlen_list) % k_partition == 0
     batch_size = len(seqlen_list) // k_partition
     min_sum_seqlen = None
     max_sum_seqlen = None
     total_sum_seqlen = 0
+
+    # Iterate over each batch of sequence lengths
     for offset in range(0, len(seqlen_list), batch_size):
         cur_sum_seqlen = sum(seqlen_list[offset : offset + batch_size])
         if min_sum_seqlen is None or cur_sum_seqlen < min_sum_seqlen:
@@ -206,7 +219,7 @@ def log_seqlen_unbalance(seqlen_list: List[int], partitions: List[List[int]], pr
     for partition in partitions:
         cur_sum_seqlen_balanced = sum([seqlen_list[i] for i in partition])
         balanced_sum_seqlen_list.append(cur_sum_seqlen_balanced)
-    # print("balanced_sum_seqlen_list: ", balanced_sum_seqlen_list)
+
     min_sum_seqlen_balanced = min(balanced_sum_seqlen_list)
     max_sum_seqlen_balanced = max(balanced_sum_seqlen_list)
 
@@ -220,11 +233,13 @@ def log_seqlen_unbalance(seqlen_list: List[int], partitions: List[List[int]], pr
     }
 
 
-def ceildiv(a, b):
+def ceildiv(a: float, b: float) -> float:
     return -(a // -b)
 
 
-def rearrange_micro_batches(batch: TensorDict, max_token_len, dp_group=None):
+def rearrange_micro_batches(
+    batch: TensorDict, max_token_len: int, dp_group: Optional[dist.ProcessGroup] = None
+) -> Tuple[List[TensorDict], List[List[int]]]:
     """Split the batch into a list of micro_batches, where the max_token_len is smaller than max_token_len
     and the number of valid tokens in each micro batch is well balanced.
     """
@@ -253,7 +268,16 @@ def rearrange_micro_batches(batch: TensorDict, max_token_len, dp_group=None):
     return micro_batches, micro_bsz_idx
 
 
-def get_reverse_idx(idx_map):
+def get_reverse_idx(idx_map: List[int]) -> List[int]:
+    """
+    Build the inverse of an index mapping.
+
+    Args:
+        idx_map (Sequence[int]): Sequence where idx_map[i] = j.
+
+    Returns:
+        List[int]: Inverse mapping list such that output[j] = i for each i.
+    """
     reverse_idx_map = copy.deepcopy(idx_map)
 
     for i, idx in enumerate(idx_map):
@@ -263,20 +287,38 @@ def get_reverse_idx(idx_map):
 
 
 def prepare_dynamic_batch(data: DataProto, max_token_len: int) -> tuple[list[DataProto], list[list[int]]]:
+    """
+    Prepare a batch for dynamic batching.
+
+    Args:
+        data (DataProto): The input data.
+        max_token_len (int): The maximum token length for dynamic batching.
+
+    Returns:
+        Tuple[List[DataProto], List[List[int]]]: A tuple containing a list of DataProto objects
+            and a list of index lists.
+    """
     batch, batch_idx_list = rearrange_micro_batches(data.batch, max_token_len=max_token_len)
     micro_batches = []
     for i, batch_idx in enumerate(batch_idx_list):
         tensors = dict(batch[i])
-        non_tensors = {}
-        for key in data.non_tensor_batch.keys():
-            non_tensors[key] = [data.non_tensor_batch[key][idx] for idx in batch_idx]
-
+        non_tensors = {key: value[batch_idx] for key, value in data.non_tensor_batch.items()}
         micro_batches.append(DataProto.from_dict(tensors, non_tensors))
 
     return micro_batches, batch_idx_list
 
 
 def restore_dynamic_batch(data: torch.Tensor, batch_idx_list: List[List[int]]) -> torch.Tensor:
+    """
+    Restore a batch from dynamic batching.
+
+    Args:
+        data (torch.Tensor): The input data.
+        batch_idx_list (List[List[int]]): The list of index lists.
+
+    Returns:
+        torch.Tensor: The restored data.
+    """
     indices = list(chain.from_iterable(batch_idx_list))
     revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)
     return data[revert_indices]
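
The docstrings added above describe an index round trip: `rearrange_micro_batches` permutes samples so that token counts are balanced across micro-batches, and `restore_dynamic_batch` undoes that permutation via `get_reverse_idx`. A small self-contained sketch of the inverse-index logic (it mirrors the functions above rather than importing them; the numbers are made up):

from itertools import chain

import torch


def reverse_idx(idx_map):
    # Inverse permutation: if idx_map[i] == j, then out[j] == i.
    out = [0] * len(idx_map)
    for i, j in enumerate(idx_map):
        out[j] = i
    return out


# Two micro-batches drawn from a batch of 4 samples (indices chosen for balance).
batch_idx_list = [[2, 0], [1, 3]]
outputs = torch.tensor([12, 10, 11, 13])  # per-sample results, in micro-batch order

indices = list(chain.from_iterable(batch_idx_list))            # [2, 0, 1, 3]
revert = torch.tensor(reverse_idx(indices), dtype=torch.long)  # [1, 2, 0, 3]
print(outputs[revert])                                         # tensor([10, 11, 12, 13]) -> original order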

verl/workers/actor/dp_actor.py

Lines changed: 5 additions & 8 deletions
@@ -195,7 +195,7 @@ def compute_log_prob(self, data: DataProto) -> torch.Tensor:
         self.actor_module.eval()
 
         temperature = data.meta_info["temperature"]
-        select_keys = ["responses", "input_ids", "attention_mask", "position_ids"]
+        select_keys = ["input_ids", "attention_mask", "position_ids", "responses"]
         non_tensor_select_keys = ["multi_modal_inputs"]
 
         data = data.select(select_keys, non_tensor_select_keys)
@@ -225,7 +225,7 @@ def update_policy(self, data: DataProto) -> Dict[str, Any]:
         self.actor_module.train()
 
         temperature = data.meta_info["temperature"]  # temperature must be in the data.meta_info to avoid slient error
-        select_keys = ["responses", "input_ids", "attention_mask", "position_ids"]
+        select_keys = ["input_ids", "attention_mask", "position_ids", "responses", "response_mask"]
         select_keys.extend(["old_log_probs", "ref_log_probs", "advantages"])
         non_tensor_select_keys = ["multi_modal_inputs"]
 
@@ -239,10 +239,8 @@ def update_policy(self, data: DataProto) -> Dict[str, Any]:
             mini_batches = tqdm(mini_batches, desc="Train mini-batches", position=1)
 
         for mini_batch in mini_batches:
-            response_length = mini_batch.batch["responses"].size(-1)
-            response_mask = mini_batch.batch["attention_mask"][:, -response_length:]
-            total_response_tokens = torch.sum(response_mask)
-            dist.all_reduce(torch.sum(response_mask), op=dist.ReduceOp.SUM)
+            total_response_tokens = torch.sum(mini_batch.batch["response_mask"])
+            dist.all_reduce(total_response_tokens, op=dist.ReduceOp.SUM)
 
             if self.config.dynamic_batching:
                 max_input_len = mini_batch.batch["input_ids"].size(-1)
@@ -256,8 +254,7 @@ def update_policy(self, data: DataProto) -> Dict[str, Any]:
 
             for micro_batch in micro_batches:
                 model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
-                response_length = model_inputs["responses"].size(-1)
-                response_mask = model_inputs["attention_mask"][:, -response_length:]
+                response_mask = model_inputs["response_mask"]
                 old_log_probs = model_inputs["old_log_probs"]
                 advantages = model_inputs["advantages"]
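
This is the grad-norm fix the commit title refers to: the deleted code all-reduced a freshly created `torch.sum(response_mask)` and discarded the result, so `total_response_tokens` stayed rank-local and the loss normalization (and hence the reported grad norm) differed across data-parallel ranks. `dist.all_reduce` reduces its argument in place, so the replacement reduces the very tensor that is used afterwards. A minimal sketch of the corrected pattern (the helper name is mine; assumes an initialized process group when run distributed):

import torch
import torch.distributed as dist


def global_response_token_count(response_mask: torch.Tensor) -> torch.Tensor:
    total = torch.sum(response_mask)  # local count on this rank
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(total, op=dist.ReduceOp.SUM)  # in-place: total now holds the global count
    return total

# Buggy shape of the deleted lines: the reduced tensor is thrown away,
# so the value actually used stays local-only on every rank.
#   total = torch.sum(response_mask)
#   dist.all_reduce(torch.sum(response_mask), op=dist.ReduceOp.SUM)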

verl/workers/critic/dp_critic.py

Lines changed: 7 additions & 10 deletions
@@ -148,7 +148,7 @@ def _optimizer_step(self) -> torch.Tensor:
     def compute_values(self, data: DataProto) -> torch.Tensor:
         self.critic_module.eval()
 
-        select_keys = ["responses", "input_ids", "attention_mask", "position_ids"]
+        select_keys = ["input_ids", "attention_mask", "position_ids", "responses", "response_mask"]
         non_tensor_select_keys = ["multi_modal_inputs"]
 
         data = data.select(select_keys, non_tensor_select_keys)
@@ -172,14 +172,14 @@ def compute_values(self, data: DataProto) -> torch.Tensor:
         if self.config.dynamic_batching:
            values = restore_dynamic_batch(values, batch_idx_list)
 
-        response_length = data.batch["responses"].size(1)
-        values = values * data.batch["attention_mask"][:, -response_length:]  # only action tokens have values
+        values = values * data.batch["response_mask"]  # only action tokens have values
         return values
 
     def update_critic(self, data: DataProto) -> Dict[str, Any]:
         self.critic_module.train()
 
-        select_keys = ["input_ids", "responses", "attention_mask", "position_ids", "values", "returns"]
+        select_keys = ["input_ids", "attention_mask", "position_ids", "responses", "response_mask"]
+        select_keys.extend(["values", "returns"])
         non_tensor_select_keys = ["multi_modal_inputs"]
 
         # Split to make minibatch iterator for updating the actor
@@ -192,10 +192,8 @@ def update_critic(self, data: DataProto) -> Dict[str, Any]:
             mini_batches = tqdm(mini_batches, desc="Train mini-batches", position=1)
 
         for mini_batch in mini_batches:
-            response_length = mini_batch.batch["responses"].size(-1)
-            response_mask = mini_batch.batch["attention_mask"][:, -response_length:]
-            total_response_tokens = torch.sum(response_mask)
-            dist.all_reduce(torch.sum(response_mask), op=dist.ReduceOp.SUM)
+            total_response_tokens = torch.sum(mini_batch.batch["response_mask"])
+            dist.all_reduce(total_response_tokens, op=dist.ReduceOp.SUM)
 
             if self.config.dynamic_batching:
                 max_input_len = mini_batch.batch["input_ids"].size(-1)
@@ -209,8 +207,7 @@ def update_critic(self, data: DataProto) -> Dict[str, Any]:
 
             for micro_batch in micro_batches:
                 model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
-                response_length = model_inputs["responses"].size(-1)
-                response_mask = model_inputs["attention_mask"][:, -response_length:]
+                response_mask = model_inputs["response_mask"]
                 values = model_inputs["values"]
                 returns = model_inputs["returns"]
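
The critic follows the same pattern as the actor: a precomputed "response_mask" entry is selected from the batch instead of being re-sliced out of attention_mask in every loop. The two are equivalent as long as the response tokens occupy the last `response_length` positions, which is exactly what the deleted slicing assumed; a tiny illustrative check (tensors made up for the example):

import torch

# Left-padded prompt (one pad + two prompt tokens) followed by a 4-token response slot with one trailing pad.
attention_mask = torch.tensor([[0, 1, 1, 1, 1, 1, 0]])
responses = torch.zeros(1, 4, dtype=torch.long)  # only its length matters here
response_mask = torch.tensor([[1, 1, 1, 0]])     # mask carried in the batch

response_length = responses.size(-1)
derived = attention_mask[:, -response_length:]   # what the deleted code computed each time
assert torch.equal(derived, response_mask)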

verl/workers/fsdp_workers.py

Lines changed: 8 additions & 7 deletions
@@ -453,7 +453,7 @@ def _process_multi_modal_inputs(self, data: DataProto):
         max_pixels = data.meta_info["max_pixels"]
         video_fps = data.meta_info["video_fps"]
         batch_multi_modal_inputs = []
-        for multi_modal_data in data.non_tensor_batch["multi_modal_data"]:
+        for multi_modal_data in data.non_tensor_batch["multi_modal_data"]:  # process multi modal data per sample
             images, videos = [], []
             if "images" in multi_modal_data:
                 for image in multi_modal_data["images"]:
@@ -468,16 +468,17 @@
                 # otherwise the batch features will be converted to dict keys
                 # see https://github.com/hiyouga/EasyR1/pull/339
                 multi_modal_inputs = dict(self.processor.image_processor(images=images, return_tensors="pt"))
-                multi_modal_inputs = {k: v.to(torch.cuda.current_device()) for k, v in multi_modal_inputs.items()}
-                batch_multi_modal_inputs.append(multi_modal_inputs)
             elif len(videos) != 0:
                 multi_modal_inputs = dict(
                     self.processor.image_processor(images=None, videos=videos, return_tensors="pt")
                 )
-                multi_modal_inputs = {k: v.to(torch.cuda.current_device()) for k, v in multi_modal_inputs.items()}
-                batch_multi_modal_inputs.append(multi_modal_inputs)
-            else:  # text-only data
-                batch_multi_modal_inputs.append({})
+            else:
+                multi_modal_inputs = {}
+
+            multi_modal_inputs = {
+                k: v.to(torch.cuda.current_device(), non_blocking=True) for k, v in multi_modal_inputs.items()
+            }
+            batch_multi_modal_inputs.append(multi_modal_inputs)
 
         self._cache["uid"] = data.non_tensor_batch["uid"]
         self._cache["multi_modal_inputs"] = np.array(batch_multi_modal_inputs, dtype=object)

0 commit comments