Skip to content

Commit 95b1bfe

Browse files
author
Vincent Moens
authored
[Refactor] Use wait instead of is_set to get results in ParallelEnv (#1562)
1 parent 29f42ea commit 95b1bfe

File tree

1 file changed

+8
-17
lines changed

1 file changed

+8
-17
lines changed

torchrl/envs/batched_envs.py

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -793,14 +793,10 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
793793
for i in range(self.num_workers):
794794
self.parent_channels[i].send(("step", None))
795795

796-
completed = set()
797-
while len(completed) < self.num_workers:
798-
for i, event in enumerate(self._events):
799-
if i in completed:
800-
continue
801-
if event.is_set():
802-
completed.add(i)
803-
event.clear()
796+
for i in range(self.num_workers):
797+
event = self._events[i]
798+
event.wait()
799+
event.clear()
804800

805801
# We must pass a clone of the tensordict, as the values of this tensordict
806802
# will be modified in-place at further steps
@@ -863,15 +859,10 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
863859
channel.send(out)
864860
workers.append(i)
865861

866-
completed = set()
867-
while len(completed) < len(workers):
868-
for i in workers:
869-
event = self._events[i]
870-
if i in completed:
871-
continue
872-
if event.is_set():
873-
completed.add(i)
874-
event.clear()
862+
for i in workers:
863+
event = self._events[i]
864+
event.wait()
865+
event.clear()
875866

876867
if self._single_task:
877868
# select + clone creates 2 tds, but we can create one only

0 commit comments

Comments
 (0)