Skip to content

Commit 50cb2b4

Browse files
author
Vincent Moens
committed
Update
[ghstack-poisoned]
1 parent b2cec25 commit 50cb2b4

File tree

3 files changed

+86
-52
lines changed

3 files changed

+86
-52
lines changed

test/test_transforms.py

Lines changed: 80 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -13206,28 +13206,7 @@ def test_single_trans_env_check(self):
1320613206
ConditionalPolicySwitch(condition=condition, policy=policy_even),
1320713207
)
1320813208
env = base_env.append_transform(transforms)
13209-
r = env.rollout(1000, policy_odd, break_when_all_done=True)
13210-
assert r.shape[0] == 15
13211-
assert (r["action"] == 0).all()
13212-
assert (
13213-
r["step_count"] == torch.arange(1, r.numel() * 2, 2).unsqueeze(-1)
13214-
).all()
13215-
assert r["next", "done"].any()
13216-
13217-
# Player 1
13218-
condition = lambda td: ((td.get("step_count") % 2) == 1).all()
13219-
transforms = Compose(
13220-
StepCounter(),
13221-
ConditionalPolicySwitch(condition=condition, policy=policy_odd),
13222-
)
13223-
env = base_env.append_transform(transforms)
13224-
r = env.rollout(1000, policy_even, break_when_all_done=True)
13225-
assert r.shape[0] == 16
13226-
assert (r["action"] == 1).all()
13227-
assert (
13228-
r["step_count"] == torch.arange(0, r.numel() * 2, 2).unsqueeze(-1)
13229-
).all()
13230-
assert r["next", "done"].any()
13209+
env.check_env_specs()
1323113210

1323213211
def _create_policy_odd(self, base_env):
1323313212
return WrapModule(
@@ -13324,43 +13303,95 @@ def make_env(max_count):
1332413303
self._test_env(env, policy_odd)
1332513304

1332613305
def test_transform_no_env(self):
13327-
"""tests the transform on dummy data, without an env."""
13328-
raise NotImplementedError
13306+
policy_odd = lambda td: td
13307+
policy_even = lambda td: td
13308+
condition = lambda td: True
13309+
transforms = ConditionalPolicySwitch(condition=condition, policy=policy_even)
13310+
with pytest.raises(
13311+
RuntimeError,
13312+
match="ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional.",
13313+
):
13314+
transforms(TensorDict())
1332913315

1333013316
def test_transform_compose(self):
13331-
"""tests the transform on dummy data, without an env but inside a Compose."""
13332-
raise NotImplementedError
13317+
policy_odd = lambda td: td
13318+
policy_even = lambda td: td
13319+
condition = lambda td: True
13320+
transforms = Compose(
13321+
ConditionalPolicySwitch(condition=condition, policy=policy_even),
13322+
)
13323+
with pytest.raises(
13324+
RuntimeError,
13325+
match="ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional.",
13326+
):
13327+
transforms(TensorDict())
1333313328

1333413329
def test_transform_env(self):
13335-
"""tests the transform on a real env.
13336-
13337-
If possible, do not use a mock env, as bugs may go unnoticed if the dynamic is too
13338-
simplistic. A call to reset() and step() should be tested independently, ie
13339-
a check that reset produces the desired output and that step() does too.
13330+
base_env = CountingEnv(max_steps=15)
13331+
condition = lambda td: ((td.get("step_count") % 2) == 0).all()
13332+
# Player 0
13333+
policy_odd = lambda td: td.set("action", env.action_spec.zero())
13334+
policy_even = lambda td: td.set("action", env.action_spec.one())
13335+
transforms = Compose(
13336+
StepCounter(),
13337+
ConditionalPolicySwitch(condition=condition, policy=policy_even),
13338+
)
13339+
env = base_env.append_transform(transforms)
13340+
env.check_env_specs()
13341+
r = env.rollout(1000, policy_odd, break_when_all_done=True)
13342+
assert r.shape[0] == 15
13343+
assert (r["action"] == 0).all()
13344+
assert (
13345+
r["step_count"] == torch.arange(1, r.numel() * 2, 2).unsqueeze(-1)
13346+
).all()
13347+
assert r["next", "done"].any()
1334013348

13341-
"""
13342-
raise NotImplementedError
13349+
# Player 1
13350+
condition = lambda td: ((td.get("step_count") % 2) == 1).all()
13351+
transforms = Compose(
13352+
StepCounter(),
13353+
ConditionalPolicySwitch(condition=condition, policy=policy_odd),
13354+
)
13355+
env = base_env.append_transform(transforms)
13356+
r = env.rollout(1000, policy_even, break_when_all_done=True)
13357+
assert r.shape[0] == 16
13358+
assert (r["action"] == 1).all()
13359+
assert (
13360+
r["step_count"] == torch.arange(0, r.numel() * 2, 2).unsqueeze(-1)
13361+
).all()
13362+
assert r["next", "done"].any()
1334313363

1334413364
def test_transform_model(self):
13345-
"""tests the transform before an nn.Module that reads the output."""
13346-
raise NotImplementedError
13347-
13348-
def test_transform_rb(self):
13349-
"""tests the transform when used with a replay buffer.
13350-
13351-
If your transform is not supposed to work with a replay buffer, test that
13352-
an error will be raised when called or appended to a RB.
13365+
policy_odd = lambda td: td
13366+
policy_even = lambda td: td
13367+
condition = lambda td: True
13368+
transforms = nn.Sequential(
13369+
ConditionalPolicySwitch(condition=condition, policy=policy_even),
13370+
)
13371+
with pytest.raises(
13372+
RuntimeError,
13373+
match="ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional.",
13374+
):
13375+
transforms(TensorDict())
1335313376

13354-
"""
13355-
raise NotImplementedError
13377+
@pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer])
13378+
def test_transform_rb(self, rbclass):
13379+
policy_odd = lambda td: td
13380+
policy_even = lambda td: td
13381+
condition = lambda td: True
13382+
rb = rbclass(storage=LazyTensorStorage(10))
13383+
rb.append_transform(
13384+
ConditionalPolicySwitch(condition=condition, policy=policy_even)
13385+
)
13386+
rb.extend(TensorDict(batch_size=[2]))
13387+
with pytest.raises(
13388+
RuntimeError,
13389+
match="ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional.",
13390+
):
13391+
rb.sample(2)
1335613392

1335713393
def test_transform_inverse(self):
13358-
"""tests the inverse transform. If not applicable, simply skip this test.
13359-
13360-
If your transform is not supposed to work offline, test that
13361-
an error will be raised when called in a nn.Module.
13362-
"""
13363-
raise NotImplementedError
13394+
return
1336413395

1336513396

1336613397
if __name__ == "__main__":

torchrl/envs/custom/chess.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -153,9 +153,7 @@ class ChessEnv(EnvBase, metaclass=_HashMeta):
153153
batch_size=torch.Size([352]),
154154
device=None,
155155
is_shared=False)
156-
157-
158-
"""
156+
""" # noqa: D301
159157

160158
_hash_table: Dict[int, str] = {}
161159
_PNG_RESTART = """[Event "?"]

torchrl/envs/transforms/transforms.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10210,3 +10210,8 @@ def _reset(
1021010210
return tensordict_reset
1021110211

1021210212
return tensordict_reset
10213+
10214+
def forward(self, tensordict: TensorDictBase) -> Any:
10215+
raise RuntimeError(
10216+
"ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional."
10217+
)

0 commit comments

Comments
 (0)