Skip to content

Commit ccc31b5

Browse files
author
Vincent Moens
committed
[Quality] Longer warmup for cudagraph within sota implementations
ghstack-source-id: 140ba6e Pull-Request-resolved: #2945
1 parent 7deff86 commit ccc31b5

File tree

11 files changed

+11
-11
lines changed

11 files changed

+11
-11
lines changed

sota-implementations/a2c/a2c_atari.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def update(batch, max_grad_norm=cfg.optim.max_grad_norm):
183183
storing_device=device,
184184
policy_device=device,
185185
compile_policy={"mode": compile_mode} if cfg.compile.compile else False,
186-
cudagraph_policy=cfg.compile.cudagraphs,
186+
cudagraph_policy={"warmup": 10} if cfg.compile.cudagraphs else False,
187187
)
188188

189189
# Main loop

sota-implementations/a2c/a2c_mujoco.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def update(batch):
167167
max_frames_per_traj=-1,
168168
trust_policy=True,
169169
compile_policy={"mode": compile_mode} if compile_mode is not None else False,
170-
cudagraph_policy=cfg.compile.cudagraphs,
170+
cudagraph_policy={"warmup": 10} if cfg.compile.cudagraphs else False,
171171
)
172172

173173
test_env.eval()

sota-implementations/discrete_sac/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def make_collector(
139139
device=device,
140140
storing_device="cpu",
141141
compile_policy=False if not compile else {"mode": compile_mode},
142-
cudagraph_policy=cudagraphs,
142+
cudagraph_policy={"warmup": 10} if cudagraphs else False,
143143
)
144144
collector.set_seed(cfg.env.seed)
145145
return collector

sota-implementations/dqn/dqn_atari.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def update(sampled_tensordict):
178178
compile_policy={"mode": compile_mode, "fullgraph": True}
179179
if compile_mode is not None
180180
else False,
181-
cudagraph_policy=cfg.compile.cudagraphs,
181+
cudagraph_policy={"warmup": 10} if cfg.compile.cudagraphs else False,
182182
)
183183

184184
# Main loop

sota-implementations/dqn/dqn_cartpole.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def update(sampled_tensordict):
136136
compile_policy={"mode": compile_mode, "fullgraph": True}
137137
if compile_mode is not None
138138
else False,
139-
cudagraph_policy=cfg.compile.cudagraphs,
139+
cudagraph_policy={"warmup": 10} if cfg.compile.cudagraphs else False,
140140
)
141141

142142
# Main loop

sota-implementations/gail/gail.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def main(cfg: DictConfig): # noqa: F821
138138
device=device,
139139
max_frames_per_traj=-1,
140140
compile_policy={"mode": compile_mode} if compile_mode is not None else False,
141-
cudagraph_policy=cfg.compile.cudagraphs,
141+
cudagraph_policy={"warmup": 10} if cfg.compile.cudagraphs else False,
142142
)
143143

144144
# Create replay buffer

sota-implementations/iql/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def make_collector(cfg, train_env, actor_model_explore, compile_mode):
138138
total_frames=cfg.collector.total_frames,
139139
device=device,
140140
compile_policy={"mode": compile_mode} if compile_mode else False,
141-
cudagraph_policy=cfg.compile.cudagraphs,
141+
cudagraph_policy={"warmup": 10} if cfg.compile.cudagraphs else False,
142142
)
143143
collector.set_seed(cfg.env.seed)
144144
return collector

sota-implementations/ppo/ppo_atari.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def main(cfg: DictConfig): # noqa: F821
8080
device=device,
8181
max_frames_per_traj=-1,
8282
compile_policy={"mode": compile_mode, "warmup": 1} if compile_mode else False,
83-
cudagraph_policy=cfg.compile.cudagraphs,
83+
cudagraph_policy={"warmup": 10} if cfg.compile.cudagraphs else False,
8484
)
8585

8686
# Create data buffer

sota-implementations/ppo/ppo_mujoco.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def main(cfg: DictConfig): # noqa: F821
7373
device=device,
7474
max_frames_per_traj=-1,
7575
compile_policy={"mode": compile_mode, "warmup": 1} if compile_mode else False,
76-
cudagraph_policy=cfg.compile.cudagraphs,
76+
cudagraph_policy={"warmup": 10} if cfg.compile.cudagraphs else False,
7777
)
7878

7979
# Create data buffer

sota-implementations/sac/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def make_collector(cfg, train_env, actor_model_explore, compile_mode):
125125
total_frames=cfg.collector.total_frames,
126126
device=device,
127127
compile_policy={"mode": compile_mode} if compile_mode else False,
128-
cudagraph_policy=cfg.compile.cudagraphs,
128+
cudagraph_policy={"warmup": 10} if cfg.compile.cudagraphs else False,
129129
)
130130
collector.set_seed(cfg.env.seed)
131131
return collector

sota-implementations/td3/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def make_collector(cfg, train_env, actor_model_explore, compile_mode, device):
125125
reset_at_each_iter=cfg.collector.reset_at_each_iter,
126126
device=collector_device,
127127
compile_policy={"mode": compile_mode} if compile_mode else False,
128-
cudagraph_policy=cfg.compile.cudagraphs,
128+
cudagraph_policy={"warmup": 10} if cfg.compile.cudagraphs else False,
129129
)
130130
collector.set_seed(cfg.env.seed)
131131
return collector

0 commit comments

Comments
 (0)