Skip to content

Commit 73c7b0a

Browse files
author
Vincent Moens
committed
[Refactor] LLM data structures
ghstack-source-id: 8483fe0 Pull Request resolved: #2834
1 parent f852b1c commit 73c7b0a

File tree

4 files changed

+152
-54
lines changed

4 files changed

+152
-54
lines changed

docs/source/reference/data.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1133,6 +1133,9 @@ efficient sampling.
11331133
get_dataloader
11341134
ConstantKLController
11351135
AdaptiveKLController
1136+
LLMData
1137+
LLMInput
1138+
LLMOutput
11361139

11371140

11381141
Utils

torchrl/data/__init__.py

Lines changed: 51 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
ConstantKLController,
99
create_infinite_iterator,
1010
get_dataloader,
11+
LLMData,
12+
LLMInput,
13+
LLMOutput,
1114
PairwiseDataset,
1215
PromptData,
1316
PromptTensorDictTokenizer,
@@ -103,96 +106,99 @@
103106
from .utils import check_no_exclusive_keys, consolidate_spec, contains_lazy_spec
104107

105108
__all__ = [
109+
"AdaptiveKLController",
110+
"Binary",
111+
"BinaryDiscreteTensorSpec",
106112
"BinaryToDecimal",
107-
"HashToInt",
108-
"MCTSForest",
109-
"QueryModule",
110-
"RandomProjectionHash",
111-
"SipHash",
112-
"TensorDictMap",
113-
"TensorMap",
114-
"Tree",
115-
"MultiStep",
113+
"Bounded",
114+
"BoundedContinuous",
115+
"BoundedTensorSpec",
116+
"Categorical",
117+
"Choice",
118+
"Composite",
119+
"CompositeSpec",
120+
"ConstantKLController",
121+
"DEVICE_TYPING",
122+
"DiscreteTensorSpec",
116123
"Flat2TED",
117124
"FlatStorageCheckpointer",
118125
"H5Combine",
119126
"H5Split",
120127
"H5StorageCheckpointer",
128+
"HashToInt",
121129
"ImmutableDatasetWriter",
130+
"LLMData",
131+
"LLMInput",
132+
"LLMOutput",
122133
"LazyMemmapStorage",
123134
"LazyStackStorage",
135+
"LazyStackedCompositeSpec",
136+
"LazyStackedTensorSpec",
124137
"LazyTensorStorage",
125138
"ListStorage",
126139
"ListStorageCheckpointer",
140+
"MCTSForest",
141+
"MultiCategorical",
142+
"MultiDiscreteTensorSpec",
143+
"MultiOneHot",
144+
"MultiOneHotDiscreteTensorSpec",
145+
"MultiStep",
127146
"Nested2TED",
128147
"NestedStorageCheckpointer",
148+
"NonTensor",
149+
"NonTensorSpec",
150+
"OneHot",
151+
"OneHotDiscreteTensorSpec",
152+
"PairwiseDataset",
129153
"PrioritizedReplayBuffer",
130154
"PrioritizedSampler",
131155
"PrioritizedSliceSampler",
156+
"PromptData",
157+
"PromptTensorDictTokenizer",
158+
"QueryModule",
159+
"RandomProjectionHash",
132160
"RandomSampler",
133161
"RemoteTensorDictReplayBuffer",
134162
"ReplayBuffer",
135163
"ReplayBufferEnsemble",
164+
"RewardData",
165+
"RolloutFromModel",
136166
"RoundRobinWriter",
137167
"SamplerEnsemble",
138168
"SamplerWithoutReplacement",
169+
"SipHash",
139170
"SliceSampler",
140171
"SliceSamplerWithoutReplacement",
172+
"Stacked",
173+
"StackedComposite",
141174
"Storage",
142175
"StorageCheckpointerBase",
143176
"StorageEnsemble",
144177
"StorageEnsembleCheckpointer",
145178
"TED2Flat",
146179
"TED2Nested",
180+
"TensorDictMap",
147181
"TensorDictMaxValueWriter",
148182
"TensorDictPrioritizedReplayBuffer",
149183
"TensorDictReplayBuffer",
150184
"TensorDictRoundRobinWriter",
185+
"TensorDictTokenizer",
186+
"TensorMap",
187+
"TensorSpec",
151188
"TensorStorage",
152189
"TensorStorageCheckpointer",
153-
"Writer",
154-
"WriterEnsemble",
155-
"AdaptiveKLController",
156-
"ConstantKLController",
157-
"create_infinite_iterator",
158-
"get_dataloader",
159-
"PairwiseDataset",
160-
"PromptData",
161-
"PromptTensorDictTokenizer",
162-
"RewardData",
163-
"RolloutFromModel",
164-
"TensorDictTokenizer",
165190
"TokenizedDatasetLoader",
166-
"Binary",
167-
"BinaryDiscreteTensorSpec",
168-
"Bounded",
169-
"BoundedContinuous",
170-
"BoundedTensorSpec",
171-
"Categorical",
172-
"Choice",
173-
"Composite",
174-
"CompositeSpec",
175-
"DEVICE_TYPING",
176-
"DiscreteTensorSpec",
177-
"LazyStackedCompositeSpec",
178-
"LazyStackedTensorSpec",
179-
"MultiCategorical",
180-
"MultiDiscreteTensorSpec",
181-
"MultiOneHot",
182-
"MultiOneHotDiscreteTensorSpec",
183-
"NonTensor",
184-
"NonTensorSpec",
185-
"OneHot",
186-
"OneHotDiscreteTensorSpec",
187-
"Stacked",
188-
"StackedComposite",
189-
"TensorSpec",
191+
"Tree",
190192
"Unbounded",
191193
"UnboundedContinuous",
192194
"UnboundedContinuousTensorSpec",
193195
"UnboundedDiscrete",
194196
"UnboundedDiscreteTensorSpec",
197+
"Writer",
198+
"WriterEnsemble",
195199
"check_no_exclusive_keys",
196200
"consolidate_spec",
197201
"contains_lazy_spec",
202+
"create_infinite_iterator",
203+
"get_dataloader",
198204
]

torchrl/data/llm/__init__.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,21 @@
1111
)
1212
from .prompt import PromptData, PromptTensorDictTokenizer
1313
from .reward import PairwiseDataset, RewardData
14-
from .utils import AdaptiveKLController, ConstantKLController, RolloutFromModel
14+
from .utils import AdaptiveKLController, ConstantKLController, RolloutFromModel, LLMData, LLMOutput, LLMInput
1515

1616
__all__ = [
17-
"create_infinite_iterator",
18-
"get_dataloader",
19-
"TensorDictTokenizer",
20-
"TokenizedDatasetLoader",
17+
"AdaptiveKLController",
18+
"ConstantKLController",
19+
"LLMData",
20+
"LLMInput",
21+
"LLMOutput",
22+
"PairwiseDataset",
2123
"PromptData",
2224
"PromptTensorDictTokenizer",
23-
"PairwiseDataset",
2425
"RewardData",
25-
"AdaptiveKLController",
26-
"ConstantKLController",
2726
"RolloutFromModel",
27+
"TensorDictTokenizer",
28+
"TokenizedDatasetLoader",
29+
"create_infinite_iterator",
30+
"get_dataloader",
2831
]

torchrl/data/llm/utils.py

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@
77
import abc
88
import collections
99
import importlib
10+
from typing import TypeVar
1011

1112
import numpy as np
1213
import torch
13-
from tensordict import TensorDict
14+
from tensordict import TensorClass, TensorDict
1415
from torch import nn, Tensor
1516
from torch.nn import functional as F
1617

@@ -541,3 +542,88 @@ def step_scheduler(self):
541542
# remove all values
542543
while len(self._kl_queue):
543544
self._kl_queue.remove(self._kl_queue[0])
545+
546+
# TypeVar shared by the LLM tensorclasses so alternate constructors
# (e.g. ``from_vllm_output``) can be typed to return the calling subclass.
LLMInpOut = TypeVar("LLMInpOut")


class LLMInput(TensorClass["nocast"]):
    """Represents the input to a Large Language Model (LLM).

    Attributes:
        tokens (torch.Tensor): The input tokens as a tensor.
        attention_mask (torch.Tensor, optional): The attention mask for the input tokens.
            Defaults to `None`.
        token_list (list[int] | list[list[int]], optional): The input tokens as a list of
            integers or a list of lists of integers. Defaults to `None`.
        text (str | list[str], optional): The input text as a string or a list of strings.
            Defaults to `None`.

    .. seealso:: :class:`~torchrl.data.LLMOutput` and :class:`~torchrl.data.LLMData`.

    """

    tokens: torch.Tensor
    # Optional views of the same input: mask, raw token lists and decoded text.
    attention_mask: torch.Tensor | None = None
    token_list: list[int] | list[list[int]] | None = None
    text: str | list[str] | None = None
565+
class LLMOutput(TensorClass["nocast"]):
    """Represents the output from a Large Language Model (LLM).

    Attributes:
        tokens (torch.Tensor): The output tokens as a tensor.
        tokens_response (torch.Tensor, optional): The response tokens generated by the model.
            Defaults to `None`.

            .. note:: the response is the sequence of tokens output by a model, excluding the
                input tokens.

        token_list (list[int] | list[list[int]], optional): The output tokens as a list of
            integers or a list of lists of integers. Defaults to `None`.
        tokens_response_list (list[list[int]], optional): The response tokens generated by the
            model as a list of lists of integers. Defaults to `None`.
        logits (torch.Tensor, optional): The logits of the output tokens. Defaults to `None`.
        log_probs (torch.Tensor, optional): The log probabilities of the output tokens.
            Defaults to `None`.
        text (str | list[str], optional): The output text as a string or a list of strings.
            Defaults to `None`.

    .. seealso:: :class:`~torchrl.data.LLMInput` and :class:`~torchrl.data.LLMData`.

    """

    tokens: torch.Tensor
    # Optional companions of the output: generated response, list/text views and
    # per-token statistics.
    tokens_response: torch.Tensor | None = None
    token_list: list[int] | list[list[int]] | None = None
    tokens_response_list: list[list[int]] | None = None
    logits: torch.Tensor | None = None
    log_probs: torch.Tensor | None = None
    text: str | list[str] | None = None

    @classmethod
    def from_vllm_output(cls: type[LLMInpOut], vllm_output) -> LLMInpOut:
        """Builds an :class:`LLMOutput` instance from a vLLM generation result.

        Args:
            vllm_output: the raw output object returned by a vLLM generate call.
                NOTE(review): exact expected structure not visible here — confirm
                against the vLLM API before implementing.

        Raises:
            NotImplementedError: this is a placeholder; conversion is not implemented yet.
        """
        # placeholder
        raise NotImplementedError
class LLMData(TensorClass["nocast"]):
    """Represents the input or output of a Large Language Model (LLM).

    Other algorithm-specific attributes such as `reward`, `advantages` or done states are
    handled automatically by the envs and, therefore, are not included in this class.

    Attributes:
        tokens (torch.Tensor): The input/output tokens as a tensor.
        attention_mask (torch.Tensor, optional): The attention mask for the input tokens.
            Defaults to `None`.
        tokens_response (torch.Tensor, optional): The response tokens generated by the model.
            Defaults to `None`.

            .. note:: the response is the sequence of tokens output by a model, excluding the
                input tokens.

        token_list (list[int] | list[list[int]], optional): The output tokens as a list of
            integers or a list of lists of integers. Defaults to `None`.
        tokens_response_list (list[list[int]], optional): The response tokens generated by the
            model as a list of lists of integers. Defaults to `None`.
        logits (torch.Tensor, optional): The logits of the output tokens. Defaults to `None`.
        log_probs (torch.Tensor, optional): The log probabilities of the output tokens.
            Defaults to `None`.
        text (str | list[str], optional): The output text as a string or a list of strings.
            Defaults to `None`.

    .. seealso:: :class:`~torchrl.data.LLMInput` and :class:`~torchrl.data.LLMOutput`.

    """

    tokens: torch.Tensor
    # Union of the input-side (attention_mask) and output-side (response, logits,
    # log-probs) fields so a single instance can carry a full generation round-trip.
    tokens_response: torch.Tensor | None = None
    attention_mask: torch.Tensor | None = None
    token_list: list[int] | list[list[int]] | None = None
    tokens_response_list: list[list[int]] | None = None
    logits: torch.Tensor | None = None
    log_probs: torch.Tensor | None = None
    text: str | list[str] | None = None

0 commit comments

Comments
 (0)