Skip to content

Commit 7a98d6e

Browse files
colin2328facebook-github-bot
authored andcommitted
fix grpo_example (#532)
Summary: Pull Request resolved: #532 fix grpo_actor OSS example change sampling replay buffer to wait (10s) for scorer to populate it. add timeout to trajertory_queue.get endpoint, to allow graceful shutdown of pending callable when stop() is called Reviewed By: dcci Differential Revision: D78298877 fbshipit-source-id: 1278f869704f73b4fc29cc4ddbc551f863e05231
1 parent 1bd4eda commit 7a98d6e

File tree

1 file changed

+20
-5
lines changed

1 file changed

+20
-5
lines changed

examples/grpo_actor.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ class ReplayBuffer(Actor):
107107
def __init__(self):
108108
"""Initialize an empty buffer."""
109109
self.storage: List[Tuple[int, TrajectorySlice]] = [] # (version, slice)
110+
self.storage_event = asyncio.Event()
110111

111112
@endpoint
112113
async def put(self, slice: TrajectorySlice) -> None:
@@ -116,12 +117,18 @@ async def put(self, slice: TrajectorySlice) -> None:
116117
slice: The trajectory slice to add
117118
"""
118119
self.storage.append((slice.policy_version, slice))
120+
self.storage_event.set()
121+
122+
async def _wait_for_storage(self):
123+
if not self.storage:
124+
await self.storage_event.wait()
119125

120126
@endpoint
121127
async def sample_from(self, k: int) -> List[TrajectorySlice]:
122128
"""Sample k trajectory slices using weighted sampling.
123129
124130
Items from newer policy versions have higher probability of being selected.
131+
If the buffer is empty, waits for it to be populated with a timeout.
125132
126133
Args:
127134
k: Number of slices to sample
@@ -130,10 +137,12 @@ async def sample_from(self, k: int) -> List[TrajectorySlice]:
130137
List of sampled trajectory slices
131138
132139
Raises:
133-
RuntimeError: If buffer is empty
140+
RuntimeError: If buffer is empty after timeout
134141
"""
135-
if not self.storage:
136-
raise RuntimeError("ReplayBuffer is empty")
142+
try:
143+
await asyncio.wait_for(self._wait_for_storage(), timeout=10.0)
144+
except asyncio.TimeoutError:
145+
raise RuntimeError("Timeout waiting for ReplayBuffer to be populated")
137146

138147
# Extract policy versions and add 1 to ensure all weights are positive
139148
policy_versions = [version + 1 for version, _ in self.storage]
@@ -200,8 +209,14 @@ async def run(self) -> None:
200209
self.running = True
201210
try:
202211
while self.running:
203-
slice_ = await self.trajectory_queue.get.call_one()
204-
await self._score_slice(slice_)
212+
try:
213+
slice_ = await asyncio.wait_for(
214+
self.trajectory_queue.get.call_one(),
215+
timeout=1.0,
216+
)
217+
await self._score_slice(slice_)
218+
except asyncio.TimeoutError:
219+
continue
205220
except Exception as e:
206221
print(f"Scorer event loop error: {e}")
207222
finally:

0 commit comments

Comments
 (0)