Skip to content

Commit 45c6129

Browse files
authored
[Feature] RenameTransform (#964)
1 parent d909444 commit 45c6129

File tree

5 files changed

+443
-1
lines changed

5 files changed

+443
-1
lines changed

docs/source/reference/envs.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,7 @@ to be able to create this other composition:
332332
PinMemoryTransform
333333
R3MTransform
334334
RandomCropTensorDict
335+
RenameTransform
335336
Resize
336337
RewardClipping
337338
RewardScaling

test/test_transforms.py

Lines changed: 289 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,14 @@
5252
FrameSkipTransform,
5353
GrayScale,
5454
gSDENoise,
55+
InitTracker,
5556
NoopResetEnv,
5657
ObservationNorm,
5758
ParallelEnv,
5859
PinMemoryTransform,
5960
R3MTransform,
6061
RandomCropTensorDict,
62+
RenameTransform,
6163
Resize,
6264
RewardClipping,
6365
RewardScaling,
@@ -76,7 +78,7 @@
7678
from torchrl.envs.libs.gym import _has_gym, GymEnv
7779
from torchrl.envs.transforms import VecNorm
7880
from torchrl.envs.transforms.r3m import _R3MNet
79-
from torchrl.envs.transforms.transforms import _has_tv, InitTracker
81+
from torchrl.envs.transforms.transforms import _has_tv
8082
from torchrl.envs.transforms.vip import _VIPNet, VIPRewardTransform
8183
from torchrl.envs.utils import check_env_specs, step_mdp
8284

@@ -6315,6 +6317,292 @@ def test_crop_mask(self, mask_key):
63156317
assert tensordict_crop[mask_key].all()
63166318

63176319

6320+
@pytest.mark.parametrize("create_copy", [True, False])
6321+
class TestRenameTransform(TransformBase):
6322+
def test_single_trans_env_check(self, create_copy):
6323+
env = TransformedEnv(
6324+
ContinuousActionVecMockEnv(),
6325+
RenameTransform(
6326+
[
6327+
"observation",
6328+
],
6329+
[
6330+
"stuff",
6331+
],
6332+
create_copy=create_copy,
6333+
),
6334+
)
6335+
check_env_specs(env)
6336+
env = TransformedEnv(
6337+
ContinuousActionVecMockEnv(),
6338+
RenameTransform(
6339+
["observation_orig"],
6340+
["stuff"],
6341+
["observation_orig"],
6342+
[
6343+
"stuff",
6344+
],
6345+
create_copy=create_copy,
6346+
),
6347+
)
6348+
check_env_specs(env)
6349+
6350+
def test_serial_trans_env_check(self, create_copy):
6351+
def make_env():
6352+
return TransformedEnv(
6353+
ContinuousActionVecMockEnv(),
6354+
RenameTransform(
6355+
[
6356+
"observation",
6357+
],
6358+
[
6359+
"stuff",
6360+
],
6361+
create_copy=create_copy,
6362+
),
6363+
)
6364+
6365+
env = SerialEnv(2, make_env)
6366+
check_env_specs(env)
6367+
6368+
def make_env():
6369+
return TransformedEnv(
6370+
ContinuousActionVecMockEnv(),
6371+
RenameTransform(
6372+
["observation_orig"],
6373+
["stuff"],
6374+
["observation_orig"],
6375+
[
6376+
"stuff",
6377+
],
6378+
create_copy=create_copy,
6379+
),
6380+
)
6381+
6382+
env = SerialEnv(2, make_env)
6383+
check_env_specs(env)
6384+
6385+
def test_parallel_trans_env_check(self, create_copy):
6386+
def make_env():
6387+
return TransformedEnv(
6388+
ContinuousActionVecMockEnv(),
6389+
RenameTransform(
6390+
[
6391+
"observation",
6392+
],
6393+
[
6394+
"stuff",
6395+
],
6396+
create_copy=create_copy,
6397+
),
6398+
)
6399+
6400+
env = ParallelEnv(2, make_env)
6401+
check_env_specs(env)
6402+
6403+
def make_env():
6404+
return TransformedEnv(
6405+
ContinuousActionVecMockEnv(),
6406+
RenameTransform(
6407+
["observation_orig"],
6408+
["stuff"],
6409+
["observation_orig"],
6410+
[
6411+
"stuff",
6412+
],
6413+
create_copy=create_copy,
6414+
),
6415+
)
6416+
6417+
env = ParallelEnv(2, make_env)
6418+
check_env_specs(env)
6419+
6420+
def test_trans_serial_env_check(self, create_copy):
6421+
def make_env():
6422+
return ContinuousActionVecMockEnv()
6423+
6424+
env = TransformedEnv(
6425+
SerialEnv(2, make_env),
6426+
RenameTransform(
6427+
[
6428+
"observation",
6429+
],
6430+
[
6431+
"stuff",
6432+
],
6433+
create_copy=create_copy,
6434+
),
6435+
)
6436+
check_env_specs(env)
6437+
env = TransformedEnv(
6438+
SerialEnv(2, make_env),
6439+
RenameTransform(
6440+
["observation_orig"],
6441+
["stuff"],
6442+
["observation_orig"],
6443+
[
6444+
"stuff",
6445+
],
6446+
create_copy=create_copy,
6447+
),
6448+
)
6449+
check_env_specs(env)
6450+
6451+
def test_trans_parallel_env_check(self, create_copy):
6452+
def make_env():
6453+
return ContinuousActionVecMockEnv()
6454+
6455+
env = TransformedEnv(
6456+
ParallelEnv(2, make_env),
6457+
RenameTransform(
6458+
[
6459+
"observation",
6460+
],
6461+
[
6462+
"stuff",
6463+
],
6464+
create_copy=create_copy,
6465+
),
6466+
)
6467+
check_env_specs(env)
6468+
env = TransformedEnv(
6469+
ParallelEnv(2, make_env),
6470+
RenameTransform(
6471+
["observation_orig"],
6472+
["stuff"],
6473+
["observation_orig"],
6474+
[
6475+
"stuff",
6476+
],
6477+
create_copy=create_copy,
6478+
),
6479+
)
6480+
check_env_specs(env)
6481+
6482+
@pytest.mark.parametrize("mode", ["forward", "_call"])
6483+
def test_transform_no_env(self, create_copy, mode):
6484+
t = RenameTransform(["a"], ["b"], create_copy=create_copy)
6485+
tensordict = TensorDict({"a": torch.randn(())}, [])
6486+
if mode == "forward":
6487+
t(tensordict)
6488+
elif mode == "_call":
6489+
t._call(tensordict)
6490+
else:
6491+
raise NotImplementedError
6492+
assert "b" in tensordict.keys()
6493+
if create_copy:
6494+
assert "a" in tensordict.keys()
6495+
else:
6496+
assert "a" not in tensordict.keys()
6497+
6498+
@pytest.mark.parametrize("mode", ["forward", "_call"])
6499+
def test_transform_compose(self, create_copy, mode):
6500+
t = Compose(RenameTransform(["a"], ["b"], create_copy=create_copy))
6501+
tensordict = TensorDict({"a": torch.randn(())}, [])
6502+
if mode == "forward":
6503+
t(tensordict)
6504+
elif mode == "_call":
6505+
t._call(tensordict)
6506+
else:
6507+
raise NotImplementedError
6508+
assert "b" in tensordict.keys()
6509+
if create_copy:
6510+
assert "a" in tensordict.keys()
6511+
else:
6512+
assert "a" not in tensordict.keys()
6513+
6514+
def test_transform_env(self, create_copy):
6515+
env = TransformedEnv(
6516+
ContinuousActionVecMockEnv(),
6517+
RenameTransform(
6518+
[
6519+
"observation",
6520+
],
6521+
[
6522+
"stuff",
6523+
],
6524+
create_copy=create_copy,
6525+
),
6526+
)
6527+
r = env.rollout(3)
6528+
if create_copy:
6529+
assert "observation" in r.keys()
6530+
assert ("next", "observation") in r.keys(True)
6531+
else:
6532+
assert "observation" not in r.keys()
6533+
assert ("next", "observation") not in r.keys(True)
6534+
assert "stuff" in r.keys()
6535+
assert ("next", "stuff") in r.keys(True)
6536+
6537+
env = TransformedEnv(
6538+
ContinuousActionVecMockEnv(),
6539+
RenameTransform(
6540+
["observation_orig"],
6541+
["stuff"],
6542+
["observation_orig"],
6543+
[
6544+
"stuff",
6545+
],
6546+
create_copy=create_copy,
6547+
),
6548+
)
6549+
r = env.rollout(3)
6550+
if create_copy:
6551+
assert "observation_orig" in r.keys()
6552+
assert ("next", "observation_orig") in r.keys(True)
6553+
else:
6554+
assert "observation_orig" not in r.keys()
6555+
assert ("next", "observation_orig") not in r.keys(True)
6556+
assert "stuff" in r.keys()
6557+
assert ("next", "stuff") in r.keys(True)
6558+
6559+
def test_transform_model(self, create_copy):
6560+
t = RenameTransform(["a"], ["b"], create_copy=create_copy)
6561+
tensordict = TensorDict({"a": torch.randn(())}, [])
6562+
model = nn.Sequential(t)
6563+
model(tensordict)
6564+
assert "b" in tensordict.keys()
6565+
if create_copy:
6566+
assert "a" in tensordict.keys()
6567+
else:
6568+
assert "a" not in tensordict.keys()
6569+
6570+
@pytest.mark.parametrize(
6571+
"inverse",
6572+
[
6573+
False,
6574+
True,
6575+
],
6576+
)
6577+
def test_transform_rb(self, create_copy, inverse):
6578+
if not inverse:
6579+
t = RenameTransform(["a"], ["b"], create_copy=create_copy)
6580+
tensordict = TensorDict({"a": torch.randn(())}, []).expand(10)
6581+
else:
6582+
t = RenameTransform(["a"], ["b"], ["a"], ["b"], create_copy=create_copy)
6583+
tensordict = TensorDict({"b": torch.randn(())}, []).expand(10)
6584+
rb = ReplayBuffer(LazyTensorStorage(20))
6585+
rb.append_transform(t)
6586+
rb.extend(tensordict)
6587+
assert "a" in rb._storage._storage.keys()
6588+
sample = rb.sample(2)
6589+
if create_copy:
6590+
assert "a" in sample.keys()
6591+
else:
6592+
assert "a" not in sample.keys()
6593+
assert "b" in sample.keys()
6594+
6595+
def test_transform_inverse(self, create_copy):
6596+
t = RenameTransform(["a"], ["b"], ["a"], ["b"], create_copy=create_copy)
6597+
tensordict = TensorDict({"b": torch.randn(())}, []).expand(10)
6598+
tensordict = t.inv(tensordict)
6599+
assert "a" in tensordict.keys()
6600+
if create_copy:
6601+
assert "b" in tensordict.keys()
6602+
else:
6603+
assert "b" not in tensordict.keys()
6604+
6605+
63186606
class TestInitTracker(TransformBase):
63196607
def test_single_trans_env_check(self):
63206608
env = CountingBatchedEnv(max_steps=torch.tensor([4, 5]), batch_size=[2])

torchrl/envs/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
PinMemoryTransform,
2929
R3MTransform,
3030
RandomCropTensorDict,
31+
RenameTransform,
3132
Resize,
3233
RewardClipping,
3334
RewardScaling,

torchrl/envs/transforms/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
ObservationTransform,
2525
PinMemoryTransform,
2626
RandomCropTensorDict,
27+
RenameTransform,
2728
Resize,
2829
RewardClipping,
2930
RewardScaling,

0 commit comments

Comments
 (0)