Skip to content

Commit b247526

Browse files
author
Vincent Moens
committed
[Deprecation] Enact deprecations
ghstack-source-id: 690a9f6 Pull Request resolved: #2917
1 parent 4162db6 commit b247526

File tree

17 files changed

+151
-280
lines changed

17 files changed

+151
-280
lines changed

test/mocking_classes.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1045,6 +1045,7 @@ def forward(self, observation, action):
10451045
class CountingEnvCountPolicy(TensorDictModuleBase):
10461046
def __init__(self, action_spec: TensorSpec, action_key: NestedKey = "action"):
10471047
super().__init__()
1048+
assert not isinstance(action_spec, Composite)
10481049
self.action_spec = action_spec
10491050
self.action_key = action_key
10501051
self.in_keys = []
@@ -1411,11 +1412,13 @@ def __init__(
14111412
},
14121413
shape=self.batch_size,
14131414
)
1414-
self.action_spec = Composite(
1415+
action_spec = self.full_action_spec[self.action_key]
1416+
assert not isinstance(action_spec, Composite)
1417+
self.full_action_spec = Composite(
14151418
{
14161419
"data": Composite(
14171420
{
1418-
"action": self.action_spec.unsqueeze(-1).expand(
1421+
"action": action_spec.unsqueeze(-1).expand(
14191422
*self.batch_size, self.nested_dim, 1
14201423
)
14211424
},

test/test_actors.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,10 @@ def test_nested_keys(self, nested_action, batch_size, nested_dim=5):
218218
nest_obs_action=nested_action, batch_size=batch_size, nested_dim=nested_dim
219219
)
220220
action_spec = env._input_spec["full_action_spec"]
221-
leaf_action_spec = env.action_spec
221+
if nested_action:
222+
leaf_action_spec = env.full_action_spec[env.action_keys[0]]
223+
else:
224+
leaf_action_spec = env.action_spec
222225

223226
space_str, spec = _process_action_space_spec(None, action_spec)
224227
assert spec == action_spec

test/test_collector.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2311,7 +2311,9 @@ def test_multi_collector_nested_env_consistency(self, seed=1):
23112311
torch.manual_seed(seed)
23122312
env_fn = lambda: TransformedEnv(NestedCountingEnv(), InitTracker())
23132313
env = NestedCountingEnv()
2314-
policy = CountingEnvCountPolicy(env.action_spec, env.action_key)
2314+
policy = CountingEnvCountPolicy(
2315+
env.full_action_spec[env.action_key], env.action_key
2316+
)
23152317

23162318
ccollector = MultiaSyncDataCollector(
23172319
create_env_fn=[env_fn],
@@ -2377,7 +2379,9 @@ def test_collector_nested_env_combinations(
23772379
nest_obs_action=nested_obs_action,
23782380
)
23792381
torch.manual_seed(seed)
2380-
policy = CountingEnvCountPolicy(env.action_spec, env.action_key)
2382+
policy = CountingEnvCountPolicy(
2383+
env.full_action_spec[env.action_key], env.action_key
2384+
)
23812385
ccollector = SyncDataCollector(
23822386
create_env_fn=env,
23832387
policy=policy,
@@ -2404,7 +2408,9 @@ def test_nested_env_dims(self, batch_size, nested_dim=5, frames_per_batch=20):
24042408
env = NestedCountingEnv(batch_size=batch_size, nested_dim=nested_dim)
24052409
env_fn = lambda: NestedCountingEnv(batch_size=batch_size, nested_dim=nested_dim)
24062410
torch.manual_seed(0)
2407-
policy = CountingEnvCountPolicy(env.action_spec, env.action_key)
2411+
policy = CountingEnvCountPolicy(
2412+
env.full_action_spec[env.action_key], env.action_key
2413+
)
24082414
policy(env.reset())
24092415
ccollector = SyncDataCollector(
24102416
create_env_fn=env_fn,

test/test_env.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@
8787
pytest.mark.filterwarnings(
8888
"ignore:Got multiple backends for torchrl.data.replay_buffers.storages"
8989
),
90+
pytest.mark.filterwarnings("ignore:unclosed file"),
9091
]
9192

9293
gym_version = None
@@ -1241,6 +1242,7 @@ def env_make():
12411242
td_serial = env_serial.rollout(max_steps=50)
12421243
finally:
12431244
env_serial.close(raise_if_closed=False)
1245+
gc.collect()
12441246

12451247
try:
12461248
env_parallel = maybe_fork_ParallelEnv(
@@ -1256,6 +1258,7 @@ def env_make():
12561258
assert_allclose_td(td_serial, td_parallel)
12571259
finally:
12581260
env_parallel.close(raise_if_closed=False)
1261+
gc.collect()
12591262

12601263
@pytest.mark.skipif(not _has_dmc, reason="no dm_control")
12611264
def test_multitask(self, maybe_fork_ParallelEnv):
@@ -2809,18 +2812,23 @@ def test_nested_env(self, envclass):
28092812
else:
28102813
raise NotImplementedError
28112814
reset = env.reset()
2812-
with pytest.warns(
2813-
DeprecationWarning, match="non-trivial"
2814-
) if envclass == "NestedCountingEnv" else contextlib.nullcontext():
2815+
if envclass == "NestedCountingEnv":
2816+
assert isinstance(env.reward_spec, Composite)
2817+
else:
28152818
assert not isinstance(env.reward_spec, Composite)
28162819
for done_key in env.done_keys:
28172820
assert (
28182821
env.full_done_spec[done_key]
28192822
== env.output_spec[("full_done_spec", *_unravel_key_to_tuple(done_key))]
28202823
)
2821-
with pytest.warns(
2822-
DeprecationWarning, match="non-trivial"
2823-
) if envclass == "NestedCountingEnv" else contextlib.nullcontext():
2824+
if envclass == "NestedCountingEnv":
2825+
assert (
2826+
env.full_reward_spec[env.reward_key]
2827+
== env.output_spec[
2828+
("full_reward_spec", *_unravel_key_to_tuple(env.reward_key))
2829+
]
2830+
)
2831+
else:
28242832
assert (
28252833
env.reward_spec
28262834
== env.output_spec[

test/test_exploration.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,9 @@ def test_nested(
391391
exploratory_policy = TensorDictSequential(
392392
policy,
393393
OrnsteinUhlenbeckProcessModule(
394-
spec=action_spec, action_key=env.action_key, is_init_key=is_init_key
394+
spec=action_spec.clone(),
395+
action_key=env.action_key,
396+
is_init_key=is_init_key,
395397
).to(device),
396398
)
397399
else:

test/test_libs.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2833,7 +2833,7 @@ def test_collector(
28332833

28342834
env = maybe_fork_ParallelEnv(n_workers, env_fun)
28352835

2836-
n_actions_per_agent = env.action_spec.shape[-1]
2836+
n_actions_per_agent = env.full_action_spec[env.action_key].shape[-1]
28372837
n_observations_per_agent = env.observation_spec["agents", "observation"].shape[
28382838
-1
28392839
]
@@ -2845,7 +2845,7 @@ def test_collector(
28452845
),
28462846
in_keys=[("agents", "observation")],
28472847
out_keys=[env.action_key],
2848-
spec=env.action_spec,
2848+
spec=env.full_action_spec[env.action_key],
28492849
safe=True,
28502850
)
28512851
ccollector = SyncDataCollector(
@@ -4207,7 +4207,7 @@ def test_parallel_env(self, maybe_fork_ParallelEnv):
42074207
def test_collector(self):
42084208
env = SMACv2Env(map_name="MMM2", seed=0, categorical_actions=True)
42094209
in_feats = env.observation_spec["agents", "observation"].shape[-1]
4210-
out_feats = env.action_spec.space.n
4210+
out_feats = env.full_action_spec[env.action_key].space.n
42114211

42124212
module = TensorDictModule(
42134213
nn.Linear(in_feats, out_feats),

test/test_loggers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
version.parse(torchvision.__version__).base_version
3232
)
3333
else:
34-
TORCHVISION_VERSION = version.parse("0.0.1").base_version
34+
TORCHVISION_VERSION = version.parse("0.0.1")
3535

3636
if _has_tb:
3737
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

test/test_specs.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -820,12 +820,6 @@ def test_create_composite_nested(shape, device):
820820
class TestLock:
821821
@pytest.mark.parametrize("recurse", [None, True, False])
822822
def test_lock(self, recurse):
823-
catch_warn = (
824-
pytest.warns(DeprecationWarning, match="recurse")
825-
if recurse is None
826-
else contextlib.nullcontext()
827-
)
828-
829823
shape = [3, 4, 5]
830824
spec = Composite(
831825
a=Composite(b=Composite(shape=shape[:3], device="cpu"), shape=shape[:2]),
@@ -834,14 +828,13 @@ def test_lock(self, recurse):
834828
spec["a"] = spec["a"].clone()
835829
spec["a", "b"] = spec["a", "b"].clone()
836830
assert not spec.locked
837-
with catch_warn:
838-
spec.lock_(recurse=recurse)
831+
spec.lock_(recurse=recurse)
839832
assert spec.locked
840833
with pytest.raises(RuntimeError, match="Cannot modify a locked Composite."):
841834
spec["a"] = spec["a"].clone()
842835
with pytest.raises(RuntimeError, match="Cannot modify a locked Composite."):
843836
spec.set("a", spec["a"].clone())
844-
if recurse:
837+
if recurse in (None, True):
845838
assert spec["a"].locked
846839
with pytest.raises(RuntimeError, match="Cannot modify a locked Composite."):
847840
spec["a"].set("b", spec["a", "b"].clone())
@@ -851,8 +844,7 @@ def test_lock(self, recurse):
851844
assert not spec["a"].locked
852845
spec["a", "b"] = spec["a", "b"].clone()
853846
spec["a"].set("b", spec["a", "b"].clone())
854-
with catch_warn:
855-
spec.unlock_(recurse=recurse)
847+
spec.unlock_(recurse=recurse)
856848
spec["a"] = spec["a"].clone()
857849
spec["a", "b"] = spec["a", "b"].clone()
858850
spec["a"].set("b", spec["a", "b"].clone())

0 commit comments

Comments
 (0)