Commit 5501d4a

Author: Vincent Moens
[Benchmark] Benchmark Gym vs TorchRL (#1602)
1 parent 3018810 commit 5501d4a

File tree: 1 file changed, +337 -0 lines

@@ -0,0 +1,337 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""This script runs several environments from the Gym library with the explicit scope of testing throughput with the various TorchRL components.

We test:
- Gym async envs embedded in TorchRL's GymEnv wrapper,
- ParallelEnv with regular GymEnv instances,
- a single data collector over a ParallelEnv,
- multiprocessed data collectors with parallel envs.

The tests are executed with various numbers of CPUs and on different devices.
"""
import time

import myosuite  # noqa: F401
import tqdm
from torchrl._utils import timeit
from torchrl.collectors import (
    MultiaSyncDataCollector,
    MultiSyncDataCollector,
    RandomPolicy,
    SyncDataCollector,
)
from torchrl.envs import EnvCreator, GymEnv, ParallelEnv
from torchrl.envs.libs.gym import gym_backend as gym_bc, set_gym_backend

if __name__ == "__main__":
    for envname in [
        "HalfCheetah-v4",
        "CartPole-v1",
        "myoHandReachRandom-v0",
        "ALE/Breakout-v5",
        "CartPole-v1",
    ]:
        # the number of collectors won't affect the resources; it only impacts
        # how the envs are split across sub-sub-processes
        for num_workers, num_collectors in zip((8, 16, 32, 64), (2, 4, 8, 8)):
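            # e.g. 32 workers with 8 collectors means each multiprocessed
            # collector below drives num_workers // num_collectors = 4 envs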
            with open(
                f"atari_{envname}_{num_workers}.txt".replace("/", "-"), "w+"
            ) as log:
                if "myo" in envname:
                    gym_backend = "gym"
                else:
                    gym_backend = "gymnasium"

                total_frames = num_workers * 10_000

                # pure gym
                def make(envname=envname, gym_backend=gym_backend):
                    with set_gym_backend(gym_backend):
                        return gym_bc().make(envname)

                with set_gym_backend(gym_backend):
                    env = gym_bc().vector.AsyncVectorEnv(
                        [make for _ in range(num_workers)]
                    )
                env.reset()
                global_step = 0
                times = []
                start = time.time()
                print("Timer started.")
                for _ in tqdm.tqdm(range(total_frames // num_workers)):
                    env.step(env.action_space.sample())
                    global_step += num_workers
                env.close()
                log.write(
                    f"pure gym: {num_workers * 10_000 / (time.time() - start): 4.4f} fps\n"
                )
                log.flush()

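                # TorchRL ParallelEnv: one worker process per sub-env; each
                # rollout(100, break_when_any_done=False) call below yields
                # 100 * num_workers frames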
                # regular parallel env
                for device in (
                    "cuda:0",
                    "cpu",
                ):

                    def make(envname=envname, gym_backend=gym_backend, device=device):
                        with set_gym_backend(gym_backend):
                            return GymEnv(envname, device=device)

                    env_make = EnvCreator(make)
                    penv = ParallelEnv(num_workers, env_make)
                    # warmup
                    penv.rollout(2)
                    pbar = tqdm.tqdm(total=num_workers * 10_000)
                    t0 = time.time()
                    for _ in range(100):
                        data = penv.rollout(100, break_when_any_done=False)
                        pbar.update(100 * num_workers)
                    log.write(
                        f"penv {device}: {num_workers * 10_000 / (time.time() - t0): 4.4f} fps\n"
                    )
                    log.flush()
                    penv.close()
                    timeit.print()
                    del penv

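                # single SyncDataCollector running on a TorchRL ParallelEnv;
                # the first num_collectors batches are treated as warmup and
                # the timer only starts afterwards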
                for device in ("cuda:0", "cpu"):

                    def make(envname=envname, gym_backend=gym_backend, device=device):
                        with set_gym_backend(gym_backend):
                            return GymEnv(envname, device=device)

                    env_make = EnvCreator(make)
                    # penv = SerialEnv(num_workers, env_make)
                    penv = ParallelEnv(num_workers, env_make)
                    collector = SyncDataCollector(
                        penv,
                        RandomPolicy(penv.action_spec),
                        frames_per_batch=1024,
                        total_frames=num_workers * 10_000,
                    )
                    pbar = tqdm.tqdm(total=num_workers * 10_000)
                    total_frames = 0
                    for i, data in enumerate(collector):
                        if i == num_collectors:
                            t0 = time.time()
                        if i >= num_collectors:
                            total_frames += data.numel()
                            pbar.update(data.numel())
                            pbar.set_description(
                                f"single collector + torchrl penv: {total_frames / (time.time() - t0): 4.4f} fps"
                            )
                    log.write(
                        f"single collector + torchrl penv {device}: {total_frames / (time.time() - t0): 4.4f} fps\n"
                    )
                    log.flush()
                    collector.shutdown()
                    del collector

                for device in (
                    "cuda:0",
                    "cpu",
                ):
                    # gym parallel env
                    def make_env(
                        envname=envname,
                        num_workers=num_workers,
                        gym_backend=gym_backend,
                        device=device,
                    ):
                        with set_gym_backend(gym_backend):
                            penv = GymEnv(envname, num_envs=num_workers, device=device)
                        return penv

                    penv = make_env()
                    # warmup
                    penv.rollout(2)
                    pbar = tqdm.tqdm(total=num_workers * 10_000)
                    t0 = time.time()
                    for _ in range(100):
                        data = penv.rollout(100, break_when_any_done=False)
                        pbar.update(100 * num_workers)
                    log.write(
                        f"gym penv {device}: {num_workers * 10_000 / (time.time() - t0): 4.4f} fps\n"
                    )
                    log.flush()
                    penv.close()
                    del penv

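                # MultiaSyncDataCollector yields a batch as soon as any of its
                # num_collectors workers has one ready; each worker runs its own
                # ParallelEnv over num_workers // num_collectors sub-envs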
                for device in (
                    "cuda:0",
                    "cpu",
                ):
                    # async collector
                    # + torchrl parallel env
                    def make_env(
                        envname=envname, gym_backend=gym_backend, device=device
                    ):
                        with set_gym_backend(gym_backend):
                            return GymEnv(envname, device=device)

                    penv = ParallelEnv(
                        num_workers // num_collectors, EnvCreator(make_env)
                    )
                    collector = MultiaSyncDataCollector(
                        [penv] * num_collectors,
                        policy=RandomPolicy(penv.action_spec),
                        frames_per_batch=1024,
                        total_frames=num_workers * 10_000,
                        device=device,
                    )
                    pbar = tqdm.tqdm(total=num_workers * 10_000)
                    total_frames = 0
                    for i, data in enumerate(collector):
                        if i == num_collectors:
                            t0 = time.time()
                        if i >= num_collectors:
                            total_frames += data.numel()
                            pbar.update(data.numel())
                            pbar.set_description(
                                f"collector + torchrl penv: {total_frames / (time.time() - t0): 4.4f} fps"
                            )
                    log.write(
                        f"async collector + torchrl penv {device}: {total_frames / (time.time() - t0): 4.4f} fps\n"
                    )
                    log.flush()
                    collector.shutdown()
                    del collector

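                # same async collectors, but each worker now builds a batched
                # GymEnv (gym/gymnasium vectorized env) through an EnvCreator
                # lambda rather than a TorchRL ParallelEnv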
                for device in (
                    "cuda:0",
                    "cpu",
                ):
                    # async collector
                    # + gym async env
                    def make_env(
                        envname=envname,
                        num_workers=num_workers,
                        gym_backend=gym_backend,
                        device=device,
                    ):
                        with set_gym_backend(gym_backend):
                            penv = GymEnv(envname, num_envs=num_workers, device=device)
                        return penv

                    penv = EnvCreator(
                        lambda num_workers=num_workers // num_collectors: make_env(
                            num_workers
                        )
                    )
                    collector = MultiaSyncDataCollector(
                        [penv] * num_collectors,
                        policy=RandomPolicy(penv().action_spec),
                        frames_per_batch=1024,
                        total_frames=num_workers * 10_000,
                        num_sub_threads=num_workers // num_collectors,
                        device=device,
                    )
                    pbar = tqdm.tqdm(total=num_workers * 10_000)
                    total_frames = 0
                    for i, data in enumerate(collector):
                        if i == num_collectors:
                            t0 = time.time()
                        if i >= num_collectors:
                            total_frames += data.numel()
                            pbar.update(data.numel())
                            pbar.set_description(
                                f"{i} collector + gym penv: {total_frames / (time.time() - t0): 4.4f} fps"
                            )
                    log.write(
                        f"async collector + gym penv {device}: {total_frames / (time.time() - t0): 4.4f} fps\n"
                    )
                    log.flush()
                    collector.shutdown()
                    del collector

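                # MultiSyncDataCollector waits for all num_collectors workers to
                # deliver a batch before yielding, unlike the async variant above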
                for device in (
                    "cuda:0",
                    "cpu",
                ):
                    # sync collector
                    # + torchrl parallel env
                    def make_env(
                        envname=envname, gym_backend=gym_backend, device=device
                    ):
                        with set_gym_backend(gym_backend):
                            return GymEnv(envname, device=device)

                    penv = ParallelEnv(
                        num_workers // num_collectors, EnvCreator(make_env)
                    )
                    collector = MultiSyncDataCollector(
                        [penv] * num_collectors,
                        policy=RandomPolicy(penv.action_spec),
                        frames_per_batch=1024,
                        total_frames=num_workers * 10_000,
                        device=device,
                    )
                    pbar = tqdm.tqdm(total=num_workers * 10_000)
                    total_frames = 0
                    for i, data in enumerate(collector):
                        if i == num_collectors:
                            t0 = time.time()
                        if i >= num_collectors:
                            total_frames += data.numel()
                            pbar.update(data.numel())
                            pbar.set_description(
                                f"collector + torchrl penv: {total_frames / (time.time() - t0): 4.4f} fps"
                            )
                    log.write(
                        f"sync collector + torchrl penv {device}: {total_frames / (time.time() - t0): 4.4f} fps\n"
                    )
                    log.flush()
                    collector.shutdown()
                    del collector

                for device in (
                    "cuda:0",
                    "cpu",
                ):
                    # sync collector
                    # + gym async env
                    def make_env(
                        envname=envname,
                        num_workers=num_workers,
                        gym_backend=gym_backend,
                        device=device,
                    ):
                        with set_gym_backend(gym_backend):
                            penv = GymEnv(envname, num_envs=num_workers, device=device)
                        return penv

                    penv = EnvCreator(
                        lambda num_workers=num_workers // num_collectors: make_env(
                            num_workers
                        )
                    )
                    collector = MultiSyncDataCollector(
                        [penv] * num_collectors,
                        policy=RandomPolicy(penv().action_spec),
                        frames_per_batch=1024,
                        total_frames=num_workers * 10_000,
                        num_sub_threads=num_workers // num_collectors,
                        device=device,
                    )
                    pbar = tqdm.tqdm(total=num_workers * 10_000)
                    total_frames = 0
                    for i, data in enumerate(collector):
                        if i == num_collectors:
                            t0 = time.time()
                        if i >= num_collectors:
                            total_frames += data.numel()
                            pbar.update(data.numel())
                            pbar.set_description(
                                f"{i} collector + gym penv: {total_frames / (time.time() - t0): 4.4f} fps"
                            )
                    log.write(
                        f"sync collector + gym penv {device}: {total_frames / (time.time() - t0): 4.4f} fps\n"
                    )
                    log.flush()
                    collector.shutdown()
                    del collector
    exit()
