Skip to content

Commit cc9fa73

Browse files
authored
[BugFix] Upgrade tensordict deps (#953)
1 parent 61f1cc4 commit cc9fa73

File tree

5 files changed

+27
-11
lines changed

5 files changed

+27
-11
lines changed

torchrl/collectors/collectors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -653,7 +653,7 @@ def rollout(self) -> TensorDictBase:
653653
# unlock the output tensordict to allow for new keys to be written
654654
# these will be missed during the sync but at least we won't get an error during the update
655655
is_shared = self._tensordict_out.is_shared()
656-
self._tensordict_out.unlock()
656+
self._tensordict_out.unlock_()
657657
self._tensordict_out[..., j] = self._tensordict
658658
if is_shared:
659659
self._tensordict_out.share_memory_()

torchrl/envs/common.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -379,15 +379,15 @@ def step(self, tensordict: TensorDictBase) -> TensorDictBase:
379379
# sanity check
380380
self._assert_tensordict_shape(tensordict)
381381

382-
tensordict.lock() # make sure _step does not modify the tensordict
382+
tensordict.lock_() # make sure _step does not modify the tensordict
383383
tensordict_out = self._step(tensordict)
384384
if tensordict_out is tensordict:
385385
raise RuntimeError(
386386
"EnvBase._step should return outplace changes to the input "
387387
"tensordict. Consider emptying the TensorDict first (e.g. tensordict.empty() or "
388388
"tensordict.select()) inside _step before writing new tensors onto this new instance."
389389
)
390-
tensordict.unlock()
390+
tensordict.unlock_()
391391

392392
obs_keys = self.observation_spec.keys(nested_keys=False)
393393
# we deliberately do not update the input values, but we want to keep track of

torchrl/envs/libs/vmas.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ def _reset(
238238
device=self.device,
239239
)
240240

241-
if infos is not None:
241+
if agent_info is not None:
242242
agent_td.set("info", agent_info)
243243
if dones is not None:
244244
agent_td.set("done", dones[i])
@@ -277,7 +277,7 @@ def _step(
277277
device=self.device,
278278
)
279279

280-
if infos is not None:
280+
if agent_info is not None:
281281
agent_td.set("info", agent_info)
282282
agent_tds.append(agent_td)
283283

torchrl/envs/transforms/transforms.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,16 @@
3636
from torchvision.transforms.functional import center_crop
3737

3838
try:
39-
from torchvision.transforms.functional import resize
39+
from torchvision.transforms.functional import InterpolationMode, resize
40+
41+
def interpolation_fn(interpolation): # noqa: D103
42+
return InterpolationMode(interpolation)
43+
4044
except ImportError:
45+
46+
def interpolation_fn(interpolation): # noqa: D103
47+
return interpolation
48+
4149
from torchvision.transforms.functional_tensor import resize
4250

4351
_has_tv = True
@@ -65,6 +73,14 @@ def new_fun(self, observation_spec):
6573

6674

6775
def _apply_to_composite_inv(function):
76+
# Changes the input_spec following a transform function.
77+
# The usage is: if an env expects a certain input (e.g. a double tensor)
78+
# but the input has to be transformed (e.g. it is float), this function will
79+
# modify the spec to get a spec that from the outside matches what is given
80+
# (ie a float).
81+
# Now since EnvBase.step ignores new inputs (ie the root level of the
82+
# tensor is not updated) an out_key that does not match the in_key has
83+
# no effect on the spec.
6884
def new_fun(self, input_spec):
6985
if isinstance(input_spec, CompositeSpec):
7086
d = input_spec._specs
@@ -996,7 +1012,7 @@ def __init__(
9961012
super().__init__(in_keys=in_keys, out_keys=out_keys)
9971013
self.w = int(w)
9981014
self.h = int(h)
999-
self.interpolation = interpolation
1015+
self.interpolation = interpolation_fn(interpolation)
10001016

10011017
def _apply_transform(self, observation: torch.Tensor) -> torch.Tensor:
10021018
# flatten if necessary

torchrl/objectives/value/advantages.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from typing import List, Optional, Tuple, Union
88

99
import torch
10-
from tensordict.nn import dispatch_kwargs
10+
from tensordict.nn import dispatch
1111
from tensordict.tensordict import TensorDictBase
1212
from torch import nn, Tensor
1313

@@ -95,7 +95,7 @@ def is_functional(self):
9595
)
9696

9797
@_self_set_grad_enabled
98-
@dispatch_kwargs
98+
@dispatch
9999
def forward(
100100
self,
101101
tensordict: TensorDictBase,
@@ -269,7 +269,7 @@ def is_functional(self):
269269
)
270270

271271
@_self_set_grad_enabled
272-
@dispatch_kwargs
272+
@dispatch
273273
def forward(
274274
self,
275275
tensordict: TensorDictBase,
@@ -461,7 +461,7 @@ def is_functional(self):
461461
)
462462

463463
@_self_set_grad_enabled
464-
@dispatch_kwargs
464+
@dispatch
465465
def forward(
466466
self,
467467
tensordict: TensorDictBase,

0 commit comments

Comments
 (0)