Skip to content

Commit a4c1ee3

Browse files
author
Vincent Moens
committed
[Quality] Better TD construction in codebase
ghstack-source-id: 9e280d9
Pull Request resolved: #2565
1 parent 9f8f77c commit a4c1ee3

File tree

28 files changed: +83 -97 lines changed

examples/rlhf/utils.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -85,7 +85,7 @@ def log(self, model):
8585

8686
class TrainLogger:
8787
def __init__(self, size: int, log_interval: int, logger: Logger):
88-
self.data = TensorDict({}, [size])
88+
self.data = TensorDict(batch_size=[size])
8989
self.counter = 0
9090
self.log_interval = log_interval
9191
self.logger = logger

sota-implementations/a2c/a2c_atari.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -144,7 +144,7 @@ def main(cfg: "DictConfig"): # noqa: F821
144144
}
145145
)
146146

147-
losses = TensorDict({}, batch_size=[num_mini_batches])
147+
losses = TensorDict(batch_size=[num_mini_batches])
148148
training_start = time.time()
149149

150150
# Compute GAE

sota-implementations/a2c/a2c_mujoco.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -128,7 +128,7 @@ def main(cfg: "DictConfig"): # noqa: F821
128128
}
129129
)
130130

131-
losses = TensorDict({}, batch_size=[num_mini_batches])
131+
losses = TensorDict(batch_size=[num_mini_batches])
132132
training_start = time.time()
133133

134134
# Compute GAE

sota-implementations/cql/cql_online.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -127,7 +127,7 @@ def main(cfg: "DictConfig"): # noqa: F821
127127
# optimization steps
128128
training_start = time.time()
129129
if collected_frames >= init_random_frames:
130-
log_loss_td = TensorDict({}, [num_updates])
130+
log_loss_td = TensorDict(batch_size=[num_updates])
131131
for j in range(num_updates):
132132
# sample from replay buffer
133133
sampled_tensordict = replay_buffer.sample()

sota-implementations/impala/impala_multi_node_ray.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -184,7 +184,7 @@ def main(cfg: "DictConfig"): # noqa: F821
184184
logger.log_scalar(key, value, collected_frames)
185185
continue
186186

187-
losses = TensorDict({}, batch_size=[sgd_updates])
187+
losses = TensorDict(batch_size=[sgd_updates])
188188
training_start = time.time()
189189
for j in range(sgd_updates):
190190

sota-implementations/impala/impala_multi_node_submitit.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -176,7 +176,7 @@ def main(cfg: "DictConfig"): # noqa: F821
176176
logger.log_scalar(key, value, collected_frames)
177177
continue
178178

179-
losses = TensorDict({}, batch_size=[sgd_updates])
179+
losses = TensorDict(batch_size=[sgd_updates])
180180
training_start = time.time()
181181
for j in range(sgd_updates):
182182

sota-implementations/impala/impala_single_node.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -154,7 +154,7 @@ def main(cfg: "DictConfig"): # noqa: F821
154154
logger.log_scalar(key, value, collected_frames)
155155
continue
156156

157-
losses = TensorDict({}, batch_size=[sgd_updates])
157+
losses = TensorDict(batch_size=[sgd_updates])
158158
training_start = time.time()
159159
for j in range(sgd_updates):
160160

sota-implementations/ppo/ppo_atari.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -138,7 +138,7 @@ def main(cfg: "DictConfig"): # noqa: F821
138138
cfg_logger_num_test_episodes = cfg.logger.num_test_episodes
139139
cfg_optim_max_grad_norm = cfg.optim.max_grad_norm
140140
cfg.loss.clip_epsilon = cfg_loss_clip_epsilon
141-
losses = TensorDict({}, batch_size=[cfg_loss_ppo_epochs, num_mini_batches])
141+
losses = TensorDict(batch_size=[cfg_loss_ppo_epochs, num_mini_batches])
142142

143143
for i, data in enumerate(collector):
144144

sota-implementations/ppo/ppo_mujoco.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -125,7 +125,7 @@ def main(cfg: "DictConfig"): # noqa: F821
125125
cfg_loss_clip_epsilon = cfg.loss.clip_epsilon
126126
cfg_logger_test_interval = cfg.logger.test_interval
127127
cfg_logger_num_test_episodes = cfg.logger.num_test_episodes
128-
losses = TensorDict({}, batch_size=[cfg_loss_ppo_epochs, num_mini_batches])
128+
losses = TensorDict(batch_size=[cfg_loss_ppo_epochs, num_mini_batches])
129129

130130
for i, data in enumerate(collector):
131131

sota-implementations/sac/sac.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -126,7 +126,7 @@ def main(cfg: "DictConfig"): # noqa: F821
126126
# Optimization steps
127127
training_start = time.time()
128128
if collected_frames >= init_random_frames:
129-
losses = TensorDict({}, batch_size=[num_updates])
129+
losses = TensorDict(batch_size=[num_updates])
130130
for i in range(num_updates):
131131
# Sample from replay buffer
132132
sampled_tensordict = replay_buffer.sample()

0 commit comments

Comments
 (0)