Skip to content

Commit d56730a

Browse files
committed
[Minor] Cleanup
1 parent 7adccce commit d56730a

File tree

5 files changed

+5
-22
lines changed

5 files changed

+5
-22
lines changed

sota-check/README.md

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,7 @@ export MUJOCO_GL=egl
2525

2626
conda create -n rl-sota-bench python=3.10 -y
2727
conda install anaconda::libglu -y
28-
pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cu121
28+
pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121
2929
pip3 install "gymnasium[accept-rom-license,atari,mujoco]" vmas tqdm wandb pygame moviepy imageio submitit hydra-core transformers
3030

3131
cd /path/to/tensordict

test/test_env.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2549,7 +2549,8 @@ def _step(self, tensordict):
25492549
"reward": action.sum().unsqueeze(0),
25502550
**self.full_done_spec.zero(),
25512551
"observation": obs,
2552-
}
2552+
},
2553+
batch_size=[],
25532554
)
25542555

25552556
torch.manual_seed(0)

torchrl/envs/batched_envs.py

Lines changed: 0 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -433,9 +433,6 @@ def _check_for_empty_spec(specs: CompositeSpec):
433433
def map_device(key, value, device_map=device_map):
434434
return value.to(device_map[key])
435435

436-
# self._env_tensordict.named_apply(
437-
# map_device, nested_keys=True, filter_empty=True
438-
# )
439436
self._env_tensordict.named_apply(
440437
map_device,
441438
nested_keys=True,
@@ -809,11 +806,6 @@ def select_and_clone(name, tensor):
809806
if name in selected_output_keys:
810807
return tensor.clone()
811808

812-
# out = self.shared_tensordict_parent.named_apply(
813-
# select_and_clone,
814-
# nested_keys=True,
815-
# filter_empty=True,
816-
# )
817809
out = self.shared_tensordict_parent.named_apply(
818810
select_and_clone,
819811
nested_keys=True,
@@ -1208,14 +1200,12 @@ def step_and_maybe_reset(
12081200
if x.device != device
12091201
else x.clone(),
12101202
device=device,
1211-
# filter_empty=True,
12121203
)
12131204
tensordict_ = tensordict_._fast_apply(
12141205
lambda x: x.to(device, non_blocking=self.non_blocking)
12151206
if x.device != device
12161207
else x.clone(),
12171208
device=device,
1218-
# filter_empty=True,
12191209
)
12201210
else:
12211211
next_td = next_td.clone().clear_device_()
@@ -1271,7 +1261,6 @@ def select_and_clone(name, tensor):
12711261
out = next_td.named_apply(
12721262
select_and_clone,
12731263
nested_keys=True,
1274-
# filter_empty=True,
12751264
)
12761265
if out.device != device:
12771266
if device is None:
@@ -1357,7 +1346,6 @@ def select_and_clone(name, tensor):
13571346
out = self.shared_tensordict_parent.named_apply(
13581347
select_and_clone,
13591348
nested_keys=True,
1360-
# filter_empty=True,
13611349
)
13621350
del out["next"]
13631351

@@ -1495,7 +1483,6 @@ def _run_worker_pipe_shared_mem(
14951483
def look_for_cuda(tensor, has_cuda=has_cuda):
14961484
has_cuda[0] = has_cuda[0] or tensor.is_cuda
14971485

1498-
# shared_tensordict.apply(look_for_cuda, filter_empty=True)
14991486
shared_tensordict.apply(look_for_cuda)
15001487
has_cuda = has_cuda[0]
15011488
else:
@@ -1685,9 +1672,5 @@ def look_for_cuda(tensor, has_cuda=has_cuda):
16851672
child_pipe.send(("_".join([cmd, "done"]), None))
16861673

16871674

1688-
def _filter_empty(tensordict):
1689-
return tensordict.select(*tensordict.keys(True, True))
1690-
1691-
16921675
# Create an alias for possible imports
16931676
_BatchedEnv = BatchedEnvBase

torchrl/objectives/common.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -252,7 +252,6 @@ def _compare_and_expand(param):
252252
return param._apply_nest(
253253
_compare_and_expand,
254254
batch_size=[expand_dim, *param.shape],
255-
filter_empty=False,
256255
call_on_nested=True,
257256
)
258257
if not isinstance(param, nn.Parameter):

torchrl/objectives/ppo.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -455,7 +455,7 @@ def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor:
455455
entropy = dist.entropy()
456456
except NotImplementedError:
457457
x = dist.rsample((self.samples_mc_entropy,))
458-
entropy = -dist.log_prob(x)
458+
entropy = -dist.log_prob(x).mean(0)
459459
return entropy.unsqueeze(-1)
460460

461461
def _log_weight(
@@ -1036,7 +1036,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
10361036
td_out.set("loss_entropy", -self.entropy_coef * entropy.mean())
10371037

10381038
if self.critic_coef:
1039-
loss_critic = self.loss_critic(tensordict)
1039+
loss_critic = self.loss_critic(tensordict_copy)
10401040
td_out.set("loss_critic", loss_critic.mean())
10411041

10421042
return td_out

0 commit comments

Comments
 (0)