Skip to content

Commit 30d21e5

Browse files
author
Vincent Moens
committed
[Feature] LLMHashingEnv
ghstack-source-id: d1a20ec
Pull Request resolved: #2635
1 parent 57dc25a commit 30d21e5

File tree

9 files changed

+293
-17
lines changed

9 files changed

+293
-17
lines changed

docs/source/reference/envs.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,8 @@ TorchRL offers a series of custom built-in environments.
347347

348348
PendulumEnv
349349
TicTacToeEnv
350+
LLMHashingEnv
351+
350352

351353
Multi-agent environments
352354
------------------------

docs/source/reference/trainers.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ Hooks can be split into 3 categories: **data processing** (``"batch_process"`` a
7979

8080
- **Logging** hooks take a batch of data presented as a ``TensorDict`` and write in the logger
8181
some information retrieved from that data. Examples include the ``LogValidationReward`` hook, the reward
82-
logger (``LogScaler``) and such. Hooks should return a dictionary (or a None value) containing the
82+
logger (``LogScalar``) and such. Hooks should return a dictionary (or a None value) containing the
8383
data to log. The key ``"log_pbar"`` is reserved to boolean values indicating if the logged value
8484
should be displayed on the progression bar printed on the training log.
8585

@@ -174,7 +174,7 @@ Trainer and hooks
174174
BatchSubSampler
175175
ClearCudaCache
176176
CountFramesLog
177-
LogScaler
177+
LogScalar
178178
OptimizerHook
179179
LogValidationReward
180180
ReplayBufferTrainer

test/test_env.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import functools
99
import gc
1010
import os.path
11+
import random
1112
import re
1213
from collections import defaultdict
1314
from functools import partial
@@ -114,6 +115,7 @@
114115
DoubleToFloat,
115116
EnvBase,
116117
EnvCreator,
118+
LLMHashingEnv,
117119
ParallelEnv,
118120
PendulumEnv,
119121
SerialEnv,
@@ -3419,6 +3421,29 @@ def test_pendulum_env(self, device):
34193421
r = env.rollout(10, tensordict=TensorDict(batch_size=[5], device=device))
34203422
assert r.shape == torch.Size((5, 10))
34213423

3424+
def test_llm_hashing_env(self):
3425+
vocab_size = 5
3426+
3427+
class Tokenizer:
3428+
def __call__(self, obj):
3429+
return torch.randint(vocab_size, (len(obj.split(" ")),)).tolist()
3430+
3431+
def decode(self, obj):
3432+
words = ["apple", "banana", "cherry", "date", "elderberry"]
3433+
return " ".join(random.choice(words) for _ in obj)
3434+
3435+
def batch_decode(self, obj):
3436+
return [self.decode(_obj) for _obj in obj]
3437+
3438+
def encode(self, obj):
3439+
return self(obj)
3440+
3441+
tokenizer = Tokenizer()
3442+
env = LLMHashingEnv(tokenizer=tokenizer, vocab_size=vocab_size)
3443+
td = env.make_tensordict("some sentence")
3444+
assert isinstance(td, TensorDict)
3445+
env.check_env_specs(tensordict=td)
3446+
34223447

34233448
@pytest.mark.parametrize("device", [None, *get_default_devices()])
34243449
@pytest.mark.parametrize("env_device", [None, *get_default_devices()])

torchrl/data/map/tree.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -135,35 +135,40 @@ def make_node(
135135
def full_observation_spec(self):
136136
"""The observation spec of the tree.
137137
138-
This is an alias for `Tree.specs['output_spec', 'full_observation_spec']`."""
138+
This is an alias for `Tree.specs['output_spec', 'full_observation_spec']`.
139+
"""
139140
return self.specs["output_spec", "full_observation_spec"]
140141

141142
@property
142143
def full_reward_spec(self):
143144
"""The reward spec of the tree.
144145
145-
This is an alias for `Tree.specs['output_spec', 'full_reward_spec']`."""
146+
This is an alias for `Tree.specs['output_spec', 'full_reward_spec']`.
147+
"""
146148
return self.specs["output_spec", "full_reward_spec"]
147149

148150
@property
149151
def full_done_spec(self):
150152
"""The done spec of the tree.
151153
152-
This is an alias for `Tree.specs['output_spec', 'full_done_spec']`."""
154+
This is an alias for `Tree.specs['output_spec', 'full_done_spec']`.
155+
"""
153156
return self.specs["output_spec", "full_done_spec"]
154157

155158
@property
156159
def full_state_spec(self):
157160
"""The state spec of the tree.
158161
159-
This is an alias for `Tree.specs['input_spec', 'full_state_spec']`."""
162+
This is an alias for `Tree.specs['input_spec', 'full_state_spec']`.
163+
"""
160164
return self.specs["input_spec", "full_state_spec"]
161165

162166
@property
163167
def full_action_spec(self):
164168
"""The action spec of the tree.
165169
166-
This is an alias for `Tree.specs['input_spec', 'full_action_spec']`."""
170+
This is an alias for `Tree.specs['input_spec', 'full_action_spec']`.
171+
"""
167172
return self.specs["input_spec", "full_action_spec"]
168173

169174
@property

torchrl/envs/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from .batched_envs import ParallelEnv, SerialEnv
77
from .common import EnvBase, EnvMetaData, make_tensordict
8-
from .custom import PendulumEnv, TicTacToeEnv
8+
from .custom import LLMHashingEnv, PendulumEnv, TicTacToeEnv
99
from .env_creator import env_creator, EnvCreator, get_env_metadata
1010
from .gym_like import default_info_dict_reader, GymLikeEnv
1111
from .libs import (

torchrl/envs/common.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,14 @@
1414
import numpy as np
1515
import torch
1616
import torch.nn as nn
17-
from tensordict import LazyStackedTensorDict, TensorDictBase, unravel_key
18-
from tensordict.utils import NestedKey
17+
from tensordict import (
18+
is_tensor_collection,
19+
LazyStackedTensorDict,
20+
TensorDictBase,
21+
unravel_key,
22+
)
23+
from tensordict.base import _is_leaf_nontensor
24+
from tensordict.utils import is_non_tensor, NestedKey
1925
from torchrl._utils import (
2026
_ends_with,
2127
_make_ordinal_device,
@@ -25,7 +31,13 @@
2531
seed_generator,
2632
)
2733

28-
from torchrl.data.tensor_specs import Categorical, Composite, TensorSpec, Unbounded
34+
from torchrl.data.tensor_specs import (
35+
Categorical,
36+
Composite,
37+
NonTensor,
38+
TensorSpec,
39+
Unbounded,
40+
)
2941
from torchrl.data.utils import DEVICE_TYPING
3042
from torchrl.envs.utils import (
3143
_make_compatible_policy,
@@ -430,7 +442,6 @@ def auto_specs_(
430442
done_key: NestedKey | List[NestedKey] | None = None,
431443
observation_key: NestedKey | List[NestedKey] = "observation",
432444
reward_key: NestedKey | List[NestedKey] = "reward",
433-
batch_size: torch.Size | None = None,
434445
):
435446
"""Automatically sets the specifications (specs) of the environment based on a random rollout using a given policy.
436447
@@ -484,6 +495,7 @@ def auto_specs_(
484495
tensordict2,
485496
named=True,
486497
nested_keys=True,
498+
is_leaf=_is_leaf_nontensor,
487499
)
488500
input_spec = Composite(input_spec_stack, batch_size=batch_size)
489501
if not self.batch_locked and batch_size != self.batch_size:
@@ -501,6 +513,7 @@ def auto_specs_(
501513
nexts_1,
502514
named=True,
503515
nested_keys=True,
516+
is_leaf=_is_leaf_nontensor,
504517
)
505518

506519
output_spec = Composite(output_spec_stack, batch_size=batch_size)
@@ -523,7 +536,8 @@ def auto_specs_(
523536
full_observation_spec = output_spec.separates(*observation_key, default=None)
524537
if not output_spec.is_empty(recurse=True):
525538
raise RuntimeError(
526-
f"Keys {list(output_spec.keys(True, True))} are unaccounted for."
539+
f"Keys {list(output_spec.keys(True, True))} are unaccounted for. "
540+
f"Make sure you have passed all the leaf names to the auto_specs_ method."
527541
)
528542

529543
if full_action_spec is not None:
@@ -541,6 +555,8 @@ def auto_specs_(
541555

542556
@wraps(check_env_specs_func)
543557
def check_env_specs(self, *args, **kwargs):
558+
return_contiguous = kwargs.pop("return_contiguous", not self._has_dynamic_specs)
559+
kwargs["return_contiguous"] = return_contiguous
544560
return check_env_specs_func(self, *args, **kwargs)
545561

546562
check_env_specs.__doc__ = check_env_specs_func.__doc__
@@ -3206,7 +3222,10 @@ def maybe_reset(self, tensordict: TensorDictBase) -> TensorDictBase:
32063222
"""
32073223
if self._simple_done:
32083224
done = tensordict._get_str("done", default=None)
3209-
any_done = done.any()
3225+
if done is not None:
3226+
any_done = done.any()
3227+
else:
3228+
any_done = False
32103229
if any_done:
32113230
tensordict._set_str(
32123231
"_reset",
@@ -3572,6 +3591,12 @@ def _has_dynamic_specs(spec: Composite):
35723591

35733592

35743593
def _tensor_to_spec(name, leaf, leaf_compare=None, *, stack):
3594+
if not (isinstance(leaf, torch.Tensor) or is_tensor_collection(leaf)):
3595+
stack[name] = NonTensor(shape=())
3596+
return
3597+
elif is_non_tensor(leaf):
3598+
stack[name] = NonTensor(shape=leaf.shape)
3599+
return
35753600
shape = leaf.shape
35763601
if leaf_compare is not None:
35773602
shape_compare = leaf_compare.shape

torchrl/envs/custom/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,6 @@
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
from .llm import LLMHashingEnv
67
from .pendulum import PendulumEnv
78
from .tictactoeenv import TicTacToeEnv

0 commit comments

Comments
 (0)