Skip to content

Commit 1b7eda1

Browse files
author
Vincent Moens
committed
[Feature] TD3 compatibility with compile
ghstack-source-id: fb94307
Pull Request resolved: #2658
1 parent 87a59fb commit 1b7eda1

File tree

16 files changed

+501
-502
lines changed

16 files changed

+501
-502
lines changed

sota-implementations/cql/cql_online.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -159,11 +159,7 @@ def update(sampled_tensordict):
159 159   pbar = tqdm.tqdm(total=cfg.collector.total_frames)
160 160
161 161   init_random_frames = cfg.collector.init_random_frames
162     - num_updates = int(
163     - cfg.collector.env_per_collector
164     - * cfg.collector.frames_per_batch
165     - * cfg.optim.utd_ratio
166     - )
    162 + num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio)
167 163   prb = cfg.replay_buffer.prb
168 164   frames_per_batch = cfg.collector.frames_per_batch
169 165   evaluation_interval = cfg.logger.log_interval

sota-implementations/cql/discrete_cql_online.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -140,11 +140,7 @@ def update(sampled_tensordict):
140 140   pbar = tqdm.tqdm(total=cfg.collector.total_frames)
141 141
142 142   init_random_frames = cfg.collector.init_random_frames
143     - num_updates = int(
144     - cfg.collector.env_per_collector
145     - * cfg.collector.frames_per_batch
146     - * cfg.optim.utd_ratio
147     - )
    143 + num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio)
148 144   prb = cfg.replay_buffer.prb
149 145   eval_rollout_steps = cfg.env.max_episode_steps
150 146   eval_iter = cfg.logger.eval_iter

sota-implementations/crossq/crossq.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -179,11 +179,7 @@ def update(sampled_tensordict: TensorDict, update_actor: bool):
179 179   pbar = tqdm.tqdm(total=cfg.collector.total_frames)
180 180
181 181   init_random_frames = cfg.collector.init_random_frames
182     - num_updates = int(
183     - cfg.collector.env_per_collector
184     - * cfg.collector.frames_per_batch
185     - * cfg.optim.utd_ratio
186     - )
    182 + num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio)
187 183   prb = cfg.replay_buffer.prb
188 184   eval_iter = cfg.logger.eval_iter
189 185   frames_per_batch = cfg.collector.frames_per_batch

sota-implementations/ddpg/ddpg.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -145,11 +145,7 @@ def update(sampled_tensordict):
145 145   pbar = tqdm.tqdm(total=cfg.collector.total_frames)
146 146
147 147   init_random_frames = cfg.collector.init_random_frames
148     - num_updates = int(
149     - cfg.collector.env_per_collector
150     - * cfg.collector.frames_per_batch
151     - * cfg.optim.utd_ratio
152     - )
    148 + num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio)
153 149   prb = cfg.replay_buffer.prb
154 150   frames_per_batch = cfg.collector.frames_per_batch
155 151   eval_iter = cfg.logger.eval_iter

sota-implementations/discrete_sac/discrete_sac.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -144,11 +144,7 @@ def update(sampled_tensordict):
144 144   pbar = tqdm.tqdm(total=cfg.collector.total_frames)
145 145
146 146   init_random_frames = cfg.collector.init_random_frames
147     - num_updates = int(
148     - cfg.collector.env_per_collector
149     - * cfg.collector.frames_per_batch
150     - * cfg.optim.utd_ratio
151     - )
    147 + num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio)
152 148   prb = cfg.replay_buffer.prb
153 149   eval_rollout_steps = cfg.env.max_episode_steps
154 150   eval_iter = cfg.logger.eval_iter

sota-implementations/iql/discrete_iql.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -148,11 +148,7 @@ def update(sampled_tensordict):
148 148   pbar = tqdm.tqdm(total=cfg.collector.total_frames)
149 149
150 150   init_random_frames = cfg.collector.init_random_frames
151     - num_updates = int(
152     - cfg.collector.env_per_collector
153     - * cfg.collector.frames_per_batch
154     - * cfg.optim.utd_ratio
155     - )
    151 + num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio)
156 152   prb = cfg.replay_buffer.prb
157 153   eval_iter = cfg.logger.eval_iter
158 154   frames_per_batch = cfg.collector.frames_per_batch

sota-implementations/iql/iql_online.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -145,11 +145,7 @@ def update(sampled_tensordict):
145 145   collected_frames = 0
146 146
147 147   init_random_frames = cfg.collector.init_random_frames
148     - num_updates = int(
149     - cfg.collector.env_per_collector
150     - * cfg.collector.frames_per_batch
151     - * cfg.optim.utd_ratio
152     - )
    148 + num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio)
153 149   prb = cfg.replay_buffer.prb
154 150   eval_iter = cfg.logger.eval_iter
155 151   frames_per_batch = cfg.collector.frames_per_batch

sota-implementations/sac/config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ collector:
13 13   frames_per_batch: 1000
14 14   init_env_steps: 1000
15 15   device:
16    - env_per_collector: 1
   16 + env_per_collector: 8
17 17   reset_at_each_iter: False
18 18
19 19   # replay buffer

sota-implementations/sac/sac.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -143,11 +143,7 @@ def update(sampled_tensordict):
143 143   pbar = tqdm.tqdm(total=cfg.collector.total_frames)
144 144
145 145   init_random_frames = cfg.collector.init_random_frames
146     - num_updates = int(
147     - cfg.collector.env_per_collector
148     - * cfg.collector.frames_per_batch
149     - * cfg.optim.utd_ratio
150     - )
    146 + num_updates = int(cfg.collector.frames_per_batch * cfg.optim.utd_ratio)
151 147   prb = cfg.replay_buffer.prb
152 148   eval_iter = cfg.logger.eval_iter
153 149   frames_per_batch = cfg.collector.frames_per_batch

sota-implementations/td3/config.yaml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ collector:
14 14   frames_per_batch: 1000
15 15   reset_at_each_iter: False
16 16   device:
17    - env_per_collector: 1
   17 + env_per_collector: 8
18 18   num_workers: 1
19 19
20 20   # replay buffer
@@ -52,3 +52,8 @@ logger:
52 52   mode: online
53 53   eval_iter: 25000
54 54   video: False
   55 +
   56 + compile:
   57 + compile: False
   58 + compile_mode:
   59 + cudagraphs: False

0 commit comments

Comments
 (0)