Skip to content

Commit f70bc1b

Browse files
author
Vincent Moens
committed
Update
[ghstack-poisoned]
2 parents 42ce732 + 043d578 commit f70bc1b

File tree

7 files changed

+446
-23
lines changed

7 files changed

+446
-23
lines changed

test/mocking_classes.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1070,17 +1070,20 @@ def _step(
10701070

10711071
class CountingEnvWithString(CountingEnv):
10721072
def __init__(self, *args, **kwargs):
1073+
self.max_size = kwargs.pop("max_size", 30)
1074+
self.min_size = kwargs.pop("min_size", 4)
10731075
super().__init__(*args, **kwargs)
10741076
self.observation_spec.set(
10751077
"string",
10761078
NonTensor(
10771079
shape=self.batch_size,
10781080
device=self.device,
1081+
example_data=self.get_random_string(),
10791082
),
10801083
)
10811084

10821085
def get_random_string(self):
1083-
size = random.randint(4, 30)
1086+
size = random.randint(self.min_size, self.max_size)
10841087
return "".join(random.choice(string.ascii_lowercase) for _ in range(size))
10851088

10861089
def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:

test/test_specs.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1402,12 +1402,13 @@ def test_multionehot(self, shape1, shape2):
14021402
assert spec2.zero().shape == spec2.shape
14031403

14041404
def test_non_tensor(self):
1405-
spec = NonTensor((3, 4), device="cpu")
1405+
spec = NonTensor((3, 4), device="cpu", example_data="example_data")
14061406
assert (
14071407
spec.expand(2, 3, 4)
14081408
== spec.expand((2, 3, 4))
1409-
== NonTensor((2, 3, 4), device="cpu")
1409+
== NonTensor((2, 3, 4), device="cpu", example_data="example_data")
14101410
)
1411+
assert spec.expand(2, 3, 4).example_data == "example_data"
14111412

14121413
@pytest.mark.parametrize("shape1", [None, (), (5,)])
14131414
@pytest.mark.parametrize("shape2", [(), (10,)])
@@ -1607,9 +1608,10 @@ def test_multionehot(
16071608
assert spec is not spec.clone()
16081609

16091610
def test_non_tensor(self):
1610-
spec = NonTensor(shape=(3, 4), device="cpu")
1611+
spec = NonTensor(shape=(3, 4), device="cpu", example_data="example_data")
16111612
assert spec.clone() == spec
16121613
assert spec.clone() is not spec
1614+
assert spec.clone().example_data == "example_data"
16131615

16141616
@pytest.mark.parametrize("shape1", [None, (), (5,)])
16151617
def test_onehot(
@@ -1840,9 +1842,10 @@ def test_multionehot(
18401842
spec.unbind(-1)
18411843

18421844
def test_non_tensor(self):
1843-
spec = NonTensor(shape=(3, 4), device="cpu")
1845+
spec = NonTensor(shape=(3, 4), device="cpu", example_data="example_data")
18441846
assert spec.unbind(1)[0] == spec[:, 0]
18451847
assert spec.unbind(1)[0] is not spec[:, 0]
1848+
assert spec.unbind(1)[0].example_data == "example_data"
18461849

18471850
@pytest.mark.parametrize("shape1", [(5,), (5, 6)])
18481851
def test_onehot(
@@ -2001,8 +2004,9 @@ def test_multionehot(self, shape1, device):
20012004
assert spec.to(device).device == device
20022005

20032006
def test_non_tensor(self, device):
2004-
spec = NonTensor(shape=(3, 4), device="cpu")
2007+
spec = NonTensor(shape=(3, 4), device="cpu", example_data="example_data")
20052008
assert spec.to(device).device == device
2009+
assert spec.to(device).example_data == "example_data"
20062010

20072011
@pytest.mark.parametrize("shape1", [(5,), (5, 6)])
20082012
def test_onehot(self, shape1, device):
@@ -2262,13 +2266,14 @@ def test_stack_multionehot_zero(self, shape, stack_dim):
22622266
assert r.shape == c.shape
22632267

22642268
def test_stack_non_tensor(self, shape, stack_dim):
2265-
spec0 = NonTensor(shape=shape, device="cpu")
2266-
spec1 = NonTensor(shape=shape, device="cpu")
2269+
spec0 = NonTensor(shape=shape, device="cpu", example_data="example_data")
2270+
spec1 = NonTensor(shape=shape, device="cpu", example_data="example_data")
22672271
new_spec = torch.stack([spec0, spec1], stack_dim)
22682272
shape_insert = list(shape)
22692273
shape_insert.insert(stack_dim, 2)
22702274
assert new_spec.shape == torch.Size(shape_insert)
22712275
assert new_spec.device == torch.device("cpu")
2276+
assert new_spec.example_data == "example_data"
22722277

22732278
def test_stack_onehot(self, shape, stack_dim):
22742279
n = 5
@@ -3642,10 +3647,18 @@ def test_expand(self):
36423647

36433648
class TestNonTensorSpec:
36443649
def test_sample(self):
3645-
nts = NonTensor(shape=(3, 4))
3650+
nts = NonTensor(shape=(3, 4), example_data="example_data")
36463651
assert nts.one((2,)).shape == (2, 3, 4)
36473652
assert nts.rand((2,)).shape == (2, 3, 4)
36483653
assert nts.zero((2,)).shape == (2, 3, 4)
3654+
assert nts.one((2,)).data == "example_data"
3655+
assert nts.rand((2,)).data == "example_data"
3656+
assert nts.zero((2,)).data == "example_data"
3657+
3658+
def test_example_data_ineq(self):
3659+
nts0 = NonTensor(shape=(3, 4), example_data="example_data")
3660+
nts1 = NonTensor(shape=(3, 4), example_data="example_data 2")
3661+
assert nts0 != nts1
36493662

36503663

36513664
@pytest.mark.skipif(not torch.cuda.is_available(), reason="not cuda device")

test/test_transforms.py

Lines changed: 218 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@
147147
TargetReturn,
148148
TensorDictPrimer,
149149
TimeMaxPool,
150+
Tokenizer,
150151
ToTensorImage,
151152
TrajCounter,
152153
TransformedEnv,
@@ -2420,7 +2421,223 @@ def test_transform_rb(self, rbclass):
24202421
assert ("next", "observation") in td.keys(True)
24212422

24222423
def test_transform_inverse(self):
2423-
raise pytest.skip("No inverse for Hash")
2424+
env = CountingEnv()
2425+
env = env.append_transform(
2426+
Hash(
2427+
in_keys=[],
2428+
out_keys=[],
2429+
in_keys_inv=["action"],
2430+
out_keys_inv=["action_hash"],
2431+
)
2432+
)
2433+
assert "action_hash" in env.action_keys
2434+
r = env.rollout(3)
2435+
env.check_env_specs()
2436+
assert "action_hash" in r
2437+
assert isinstance(r[0]["action_hash"], torch.Tensor)
2438+
2439+
2440+
class TestTokenizer(TransformBase):
2441+
@pytest.mark.parametrize("datatype", ["str", "NonTensorStack"])
2442+
def test_transform_no_env(self, datatype):
2443+
if datatype == "str":
2444+
obs = "abcdefg"
2445+
elif datatype == "NonTensorStack":
2446+
obs = torch.stack(
2447+
[
2448+
NonTensorData(data="abcde"),
2449+
NonTensorData(data="fghij"),
2450+
NonTensorData(data="klmno"),
2451+
]
2452+
)
2453+
else:
2454+
raise RuntimeError(f"please add a test case for datatype {datatype}")
2455+
2456+
td = TensorDict(
2457+
{
2458+
"observation": obs,
2459+
}
2460+
)
2461+
2462+
t = Tokenizer(in_keys=["observation"], out_keys=["tokens"])
2463+
td_tokenized = t(td)
2464+
t_inv = Tokenizer([], [], in_keys_inv=["tokens"], out_keys_inv=["observation"])
2465+
td_recon = t_inv.inv(td_tokenized.clone().exclude("observation"))
2466+
assert td_tokenized.get("observation") is td.get("observation")
2467+
assert td_recon["observation"] == td["observation"]
2468+
2469+
@pytest.mark.parametrize("datatype", ["str"])
2470+
def test_single_trans_env_check(self, datatype):
2471+
if datatype == "str":
2472+
t = Tokenizer(
2473+
in_keys=["string"],
2474+
out_keys=["tokens"],
2475+
max_length=5,
2476+
)
2477+
base_env = CountingEnvWithString(max_size=4, min_size=4)
2478+
env = TransformedEnv(base_env, t)
2479+
check_env_specs(env, return_contiguous=False)
2480+
2481+
@pytest.mark.parametrize("datatype", ["str"])
2482+
def test_serial_trans_env_check(self, datatype):
2483+
def make_env():
2484+
if datatype == "str":
2485+
t = Tokenizer(
2486+
in_keys=["string"],
2487+
out_keys=["tokens"],
2488+
max_length=5,
2489+
)
2490+
base_env = CountingEnvWithString(max_size=4, min_size=4)
2491+
2492+
return TransformedEnv(base_env, t)
2493+
2494+
env = SerialEnv(2, make_env)
2495+
check_env_specs(env, return_contiguous=False)
2496+
2497+
@pytest.mark.parametrize("datatype", ["str"])
2498+
def test_parallel_trans_env_check(self, maybe_fork_ParallelEnv, datatype):
2499+
def make_env():
2500+
if datatype == "str":
2501+
t = Tokenizer(
2502+
in_keys=["string"],
2503+
out_keys=["tokens"],
2504+
max_length=5,
2505+
)
2506+
base_env = CountingEnvWithString(max_size=4, min_size=4)
2507+
return TransformedEnv(base_env, t)
2508+
2509+
env = maybe_fork_ParallelEnv(2, make_env)
2510+
try:
2511+
check_env_specs(env, return_contiguous=False)
2512+
finally:
2513+
try:
2514+
env.close()
2515+
except RuntimeError:
2516+
pass
2517+
2518+
@pytest.mark.parametrize("datatype", ["str"])
2519+
def test_trans_serial_env_check(self, datatype):
2520+
if datatype == "str":
2521+
t = Tokenizer(
2522+
in_keys=["string"],
2523+
out_keys=["tokens"],
2524+
max_length=5,
2525+
)
2526+
base_env = partial(CountingEnvWithString, max_size=4, min_size=4)
2527+
2528+
env = TransformedEnv(SerialEnv(2, base_env), t)
2529+
check_env_specs(env, return_contiguous=False)
2530+
2531+
@pytest.mark.parametrize("datatype", ["str"])
2532+
def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv, datatype):
2533+
if datatype == "str":
2534+
t = Tokenizer(
2535+
in_keys=["string"],
2536+
out_keys=["tokens"],
2537+
max_length=5,
2538+
)
2539+
base_env = partial(CountingEnvWithString, max_size=4, min_size=4)
2540+
2541+
env = TransformedEnv(maybe_fork_ParallelEnv(2, base_env), t)
2542+
try:
2543+
check_env_specs(env, return_contiguous=False)
2544+
finally:
2545+
try:
2546+
env.close()
2547+
except RuntimeError:
2548+
pass
2549+
2550+
@pytest.mark.parametrize("datatype", ["str"])
2551+
def test_transform_compose(self, datatype):
2552+
if datatype == "str":
2553+
obs = "abcdefg"
2554+
2555+
td = TensorDict(
2556+
{
2557+
"observation": obs,
2558+
}
2559+
)
2560+
t = Tokenizer(
2561+
in_keys=["observation"],
2562+
out_keys=["tokens"],
2563+
max_length=5,
2564+
)
2565+
t = Compose(t)
2566+
td_tokenized = t(td)
2567+
2568+
assert td_tokenized["observation"] is td["observation"]
2569+
assert td_tokenized["tokens"] == t[0].tokenizer(obs, return_tensor="pt")
2570+
2571+
# TODO
2572+
def test_transform_model(self):
2573+
t = Hash(
2574+
in_keys=[("next", "observation"), ("observation",)],
2575+
out_keys=[("next", "hashing"), ("hashing",)],
2576+
hash_fn=hash,
2577+
)
2578+
model = nn.Sequential(t, nn.Identity())
2579+
td = TensorDict(
2580+
{("next", "observation"): torch.randn(3), "observation": torch.randn(3)}, []
2581+
)
2582+
td_out = model(td)
2583+
assert ("next", "hashing") in td_out.keys(True)
2584+
assert ("hashing",) in td_out.keys(True)
2585+
assert td_out["next", "hashing"] == hash(td["next", "observation"])
2586+
assert td_out["hashing"] == hash(td["observation"])
2587+
2588+
@pytest.mark.skipif(not _has_gym, reason="Gym not found")
2589+
def test_transform_env(self):
2590+
t = Hash(
2591+
in_keys=["observation"],
2592+
out_keys=["hashing"],
2593+
hash_fn=hash,
2594+
)
2595+
env = TransformedEnv(GymEnv(PENDULUM_VERSIONED()), t)
2596+
assert env.observation_spec["hashing"]
2597+
assert "observation" in env.observation_spec
2598+
assert "observation" in env.base_env.observation_spec
2599+
check_env_specs(env)
2600+
2601+
@pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer])
2602+
def test_transform_rb(self, rbclass):
2603+
t = Hash(
2604+
in_keys=[("next", "observation"), ("observation",)],
2605+
out_keys=[("next", "hashing"), ("hashing",)],
2606+
hash_fn=lambda x: [hash(x[0]), hash(x[1])],
2607+
)
2608+
rb = rbclass(storage=LazyTensorStorage(10))
2609+
rb.append_transform(t)
2610+
td = TensorDict(
2611+
{
2612+
"observation": torch.randn(3, 4),
2613+
"next": TensorDict(
2614+
{"observation": torch.randn(3, 4)},
2615+
[],
2616+
),
2617+
},
2618+
[],
2619+
).expand(10)
2620+
rb.extend(td)
2621+
td = rb.sample(2)
2622+
assert "hashing" in td.keys()
2623+
assert "observation" in td.keys()
2624+
assert ("next", "observation") in td.keys(True)
2625+
2626+
def test_transform_inverse(self):
2627+
env = CountingEnv()
2628+
env = env.append_transform(
2629+
Hash(
2630+
in_keys=[],
2631+
out_keys=[],
2632+
in_keys_inv=["action"],
2633+
out_keys_inv=["action_hash"],
2634+
)
2635+
)
2636+
assert "action_hash" in env.action_keys
2637+
r = env.rollout(3)
2638+
env.check_env_specs()
2639+
assert "action_hash" in r
2640+
assert isinstance(r[0]["action_hash"], torch.Tensor)
24242641

24252642

24262643
class TestStack(TransformBase):

torchrl/data/tensor_specs.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2452,6 +2452,8 @@ class NonTensor(TensorSpec):
24522452
(same will go for :meth:`.zero` and :meth:`.one`).
24532453
"""
24542454

2455+
example_data: Any = None
2456+
24552457
def __init__(
24562458
self,
24572459
shape: Union[torch.Size, int] = _DEFAULT_SHAPE,
@@ -2470,6 +2472,11 @@ def __init__(
24702472
)
24712473
self.example_data = example_data
24722474

2475+
def __eq__(self, other):
2476+
eq = super().__eq__(other)
2477+
eq = eq & (self.example_data == getattr(other, "example_data", None))
2478+
return eq
2479+
24732480
def cardinality(self) -> Any:
24742481
raise RuntimeError("Cannot enumerate a NonTensorSpec.")
24752482

@@ -2555,6 +2562,16 @@ def expand(self, *shape):
25552562
shape=shape, device=self.device, dtype=None, example_data=self.example_data
25562563
)
25572564

2565+
def unsqueeze(self, dim: int) -> NonTensor:
2566+
unsq = super().unsqueeze(dim=dim)
2567+
unsq.example_data = self.example_data
2568+
return unsq
2569+
2570+
def squeeze(self, dim: int | None = None) -> NonTensor:
2571+
sq = super().squeeze(dim=dim)
2572+
sq.example_data = self.example_data
2573+
return sq
2574+
25582575
def _reshape(self, shape):
25592576
return self.__class__(
25602577
shape=shape,

torchrl/envs/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@
9494
TargetReturn,
9595
TensorDictPrimer,
9696
TimeMaxPool,
97+
Tokenizer,
9798
ToTensorImage,
9899
TrajCounter,
99100
Transform,

torchrl/envs/transforms/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
TargetReturn,
5656
TensorDictPrimer,
5757
TimeMaxPool,
58+
Tokenizer,
5859
ToTensorImage,
5960
TrajCounter,
6061
Transform,

0 commit comments

Comments
 (0)