Commit b486c66

amend
1 parent 8151cf4 commit b486c66

4 files changed, 11 insertions(+), 5 deletions(-)

sota-implementations/grpo/grpo-async.py

Lines changed: 3 additions & 0 deletions
@@ -177,6 +177,9 @@ def train(
     start_time = time.time()
 
     for step in range(total_steps):
+        if not collector.is_running():
+            torchrl_logger.info("Collector stopped, stopping training")
+            break
         pbar.update(1)
         pbar.set_description(f"Step {step}, writes: {replay_buffer.write_count}")
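
Note: a minimal sketch of the async lifecycle this guard assumes. `make_collector`, `replay_buffer`, and the optimizer step are hypothetical placeholders; `start()`, `is_running()`, and `async_shutdown()` are the methods touched by this commit.

    collector = make_collector()  # hypothetical factory, not part of this commit
    collector.start()             # fills the replay buffer from a background thread
    try:
        for step in range(total_steps):
            if not collector.is_running():
                break             # the guard added above: stop once the worker dies
            batch = replay_buffer.sample()
            # ... optimizer step on `batch` (placeholder) ...
    finally:
        collector.async_shutdown()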

torchrl/collectors/collectors.py

Lines changed: 4 additions & 1 deletion
@@ -1358,7 +1358,7 @@ def start(self):
         """
         if self.replay_buffer is None:
             raise RuntimeError("Replay buffer must be defined for execution.")
-        if not hasattr(self, "_thread") or not self._thread.is_alive():
+        if not self.is_running():
             self._stop = False
             self._thread = threading.Thread(target=self._run_iterator)
             self._thread.daemon = (
@@ -1371,6 +1371,9 @@ def _run_iterator(self):
         if self._stop:
             return
 
+    def is_running(self):
+        return hasattr(self, "_thread") and self._thread.is_alive()
+
     def async_shutdown(
         self, timeout: float | None = None, close_env: bool = True
     ) -> None:
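
Note: a self-contained sketch of the pattern being factored out, assuming nothing beyond the standard library. The inline hasattr/is_alive test in start() becomes a single reusable is_running() helper; the Worker class is illustrative, not TorchRL code.

    import threading
    import time

    class Worker:
        def start(self):
            if not self.is_running():  # idempotent: never spawns a second thread
                self._stop = False
                self._thread = threading.Thread(target=self._run)
                self._thread.daemon = True
                self._thread.start()

        def is_running(self):
            # Safe to call before start(): hasattr guards the missing attribute.
            return hasattr(self, "_thread") and self._thread.is_alive()

        def _run(self):
            while not self._stop:
                time.sleep(0.1)

        def stop(self):
            if self.is_running():
                self._stop = True
                self._thread.join()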

torchrl/collectors/llm/ray_collector.py

Lines changed: 3 additions & 0 deletions
@@ -170,6 +170,9 @@ def start(self):
         pending_task = self._collector.start.remote()
         return ray.get(pending_task)
 
+    def is_running(self):
+        return ray.get(self._collector.is_running.remote())
+
     def shutdown(self):
         """Shuts down the collector."""
         pending_task = self._collector.shutdown.remote()
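
Note: the wrapper exposes the same is_running() interface by forwarding the call to the remote actor and blocking on the result. A standalone sketch of that proxy pattern, using Ray's standard actor API (RemoteCollector and CollectorProxy are made-up names, not TorchRL code):

    import ray

    ray.init()

    @ray.remote
    class RemoteCollector:
        def __init__(self):
            self._running = False

        def start(self):
            self._running = True

        def is_running(self):
            return self._running

    class CollectorProxy:
        def __init__(self):
            self._collector = RemoteCollector.remote()

        def is_running(self):
            # Synchronous facade: block until the actor answers.
            return ray.get(self._collector.is_running.remote())

    proxy = CollectorProxy()
    print(proxy.is_running())  # False until start() runs on the actor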

torchrl/objectives/llm/grpo.py

Lines changed: 1 addition & 4 deletions
@@ -261,10 +261,6 @@ def forward(self, tensordict: TensorDictBase) -> GRPOLossOutput:
             raise ValueError(
                 f"advantage and log_weight must have the same number of dimensions, got {advantage.ndim=} and {log_weight.ndim=}"
             )
-        print(f"log_weight: {log_weight.shape}")
-        print(f"advantage: {advantage.shape}")
-        print(f"mask: {mask.shape}")
-        print(f"data: {tensordict}")
         gain1 = log_weight.exp() * advantage
 
         log_weight_clip = log_weight.clamp(*self._clip_bounds)
@@ -503,6 +499,7 @@ def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
         torchrl_logger.info(f"Computing advantage for {prompt=}")
         # Cat is the most robust way to combine the trajs
         tds = torch.cat(list(self.queues[prompt]), -1)
+        del self.queues[prompt]
         # Collect rewards
         reward = tds.get(self.rewards_key, as_nested_tensor=True)
         reward_mean = reward.values().mean()
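
Note: the one functional change in this hunk drops the per-prompt queue once its trajectories have been concatenated, presumably so consumed data is neither kept alive nor combined again on a later pass. A toy sketch of that consume-then-release pattern; combine stands in for the torch.cat call in the real code:

    queues = {"prompt_a": [1, 2, 3]}

    def pop_and_combine(queues, prompt, combine=sum):
        combined = combine(queues[prompt])
        del queues[prompt]  # release the consumed trajectories
        return combined

    print(pop_and_combine(queues, "prompt_a"))  # 6; the "prompt_a" entry is gone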
