
Commit f852b1c

Author: Vincent Moens (committed)
[Refactor] Rename RLHF files to LLM
ghstack-source-id: ff99de9
Pull Request resolved: #2833
1 parent 413571b commit f852b1c

File tree

26 files changed: +934, -849 lines changed

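The rename is mechanical: every import from torchrl.data.rlhf.*, torchrl.modules.models.rlhf, and torchrl.envs.transforms.rlhf moves to the matching llm module, and a new torchrl.data.rlhf shim (added near the end of this diff) keeps the old top-level names importable behind a DeprecationWarning. A minimal before/after sketch, using only names that appear in the hunks below:

    # Before this commit:
    from torchrl.data.rlhf.prompt import PromptData
    from torchrl.modules.models.rlhf import GPT2RewardModel

    # After this commit:
    from torchrl.data.llm.prompt import PromptData
    from torchrl.modules.models.llm import GPT2RewardModel

    # Still importable, but now re-exported by the torchrl/data/rlhf.py shim
    # and accompanied by a DeprecationWarning:
    from torchrl.data.rlhf import PromptData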

examples/rlhf/data/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
-from torchrl.data.rlhf.prompt import get_prompt_dataloader_tldr
+from torchrl.data.llm.prompt import get_prompt_dataloader_tldr
 
 __all__ = ["get_prompt_dataloader_tldr"]

examples/rlhf/models/reward.py

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@
 from tensordict.nn import TensorDictModule
 from torchrl._utils import logger as torchrl_logger
 
-from torchrl.modules.models.rlhf import GPT2RewardModel
+from torchrl.modules.models.llm import GPT2RewardModel
 
 
 def init_reward_model(

examples/rlhf/train.py

Lines changed: 2 additions & 2 deletions
@@ -17,8 +17,8 @@
 from torch.optim.lr_scheduler import CosineAnnealingLR
 from torchrl._utils import logger as torchrl_logger
 
-from torchrl.data.rlhf.dataset import get_dataloader
-from torchrl.data.rlhf.prompt import PromptData
+from torchrl.data.llm.dataset import get_dataloader
+from torchrl.data.llm.prompt import PromptData
 from utils import get_file_logger, resolve_name_or_path, setup
 
 

examples/rlhf/train_reward.py

Lines changed: 2 additions & 2 deletions
@@ -9,8 +9,8 @@
 from models.reward import init_reward_model
 from torch.optim.lr_scheduler import CosineAnnealingLR
 from torchrl._utils import logger as torchrl_logger
-from torchrl.data.rlhf.dataset import get_dataloader
-from torchrl.data.rlhf.reward import PairwiseDataset
+from torchrl.data.llm.dataset import get_dataloader
+from torchrl.data.llm.reward import PairwiseDataset
 from utils import get_file_logger, resolve_name_or_path, setup
 
 

examples/rlhf/train_rlhf.py

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
 import hydra
 import torch
 from models.actor_critic import init_actor_critic
-from torchrl.data.rlhf.utils import AdaptiveKLController, RolloutFromModel
+from torchrl.data.llm.utils import AdaptiveKLController, RolloutFromModel
 
 from torchrl.record.loggers import get_logger
 

examples/rlhf/utils.py

Lines changed: 2 additions & 2 deletions
@@ -22,9 +22,9 @@
     TensorDictReplayBuffer,
     TensorStorage,
 )
+from torchrl.data.llm.dataset import get_dataloader
+from torchrl.data.llm.prompt import PromptData
 from torchrl.data.replay_buffers import SamplerWithoutReplacement
-from torchrl.data.rlhf.dataset import get_dataloader
-from torchrl.data.rlhf.prompt import PromptData
 from torchrl.objectives import ClipPPOLoss
 from torchrl.objectives.value import GAE
 

test/assets/generate.py

Lines changed: 8 additions & 4 deletions
@@ -5,6 +5,7 @@
 
 """Script used to generate the mini datasets."""
 import multiprocessing as mp
+import pathlib
 
 try:
     mp.set_start_method("spawn")
@@ -14,8 +15,8 @@
 
 from datasets import Dataset, DatasetDict, load_dataset
 
-from torchrl.data.rlhf.dataset import get_dataloader
-from torchrl.data.rlhf.prompt import PromptData
+from torchrl.data.llm.dataset import get_dataloader
+from torchrl.data.llm.prompt import PromptData
 
 
 def generate_small_dataset(comparison=True):
@@ -42,7 +43,7 @@ def get_minibatch():
         batch_size=16,
         block_size=33,
         tensorclass_type=PromptData,
-        dataset_name="../datasets_mini/openai_summarize_tldr",
+        dataset_name=f"{pathlib.Path(__file__).parent}/../datasets_mini/openai_summarize_tldr",
         device="cpu",
         num_workers=2,
         infinite=False,
@@ -52,9 +53,12 @@ def get_minibatch():
         root_dir=tmpdir,
     )
     for data in dl:
-        data = data.clone().memmap_("test/datasets_mini/tldr_batch/")
+        data = data.clone().memmap_(
+            f"{pathlib.Path(__file__).parent}/../datasets_mini/tldr_batch/"
+        )
         break
 
 
 if __name__ == "__main__":
+    generate_small_dataset(False)
     get_minibatch()
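
Besides the import rename, the generate.py hunks anchor the dataset paths at the script's own location instead of the current working directory, and the __main__ block now calls generate_small_dataset(False) before get_minibatch(). A small sketch of the path pattern adopted above (directory names taken from the hunk, not verified here):

    import pathlib

    # Resolve the mini-dataset directory relative to this file, not the CWD.
    HERE = pathlib.Path(__file__).parent
    dataset_dir = (HERE / ".." / "datasets_mini" / "openai_summarize_tldr").resolve()
    print(dataset_dir)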

test/assets/tldr_batch.zip

Binary file not shown (2 Bytes).

test/test_actors.py

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@
 
 from torch import distributions as dist, nn
 from torchrl.data import Binary, Bounded, Categorical, Composite, MultiOneHot, OneHot
-from torchrl.data.rlhf.dataset import _has_transformers
+from torchrl.data.llm.dataset import _has_transformers
 from torchrl.modules import MLP, SafeModule, TanhDelta, TanhNormal
 from torchrl.modules.tensordict_module.actors import (
     _process_action_space_spec,

test/test_env.py

Lines changed: 1 addition & 1 deletion
@@ -61,7 +61,7 @@
 from torchrl.envs.libs.dm_control import _has_dmc, DMControlEnv
 from torchrl.envs.libs.gym import _has_gym, gym_backend, GymEnv, GymWrapper
 from torchrl.envs.transforms import Compose, StepCounter, TransformedEnv
-from torchrl.envs.transforms.rlhf import as_padded_tensor
+from torchrl.envs.transforms.llm import as_padded_tensor
 from torchrl.envs.transforms.transforms import (
     AutoResetEnv,
     AutoResetTransform,

test/test_rlhf.py

Lines changed: 6 additions & 6 deletions
@@ -21,17 +21,17 @@
     TensorDictBase,
 )
 from tensordict.nn import TensorDictModule
-from torchrl.data.rlhf import TensorDictTokenizer
-from torchrl.data.rlhf.dataset import (
+from torchrl.data.llm import TensorDictTokenizer
+from torchrl.data.llm.dataset import (
     _has_datasets,
     _has_transformers,
     get_dataloader,
     TokenizedDatasetLoader,
 )
-from torchrl.data.rlhf.prompt import PromptData, PromptTensorDictTokenizer
-from torchrl.data.rlhf.reward import PairwiseDataset, pre_tokenization_hook
-from torchrl.data.rlhf.utils import RolloutFromModel
-from torchrl.modules.models.rlhf import GPT2RewardModel
+from torchrl.data.llm.prompt import PromptData, PromptTensorDictTokenizer
+from torchrl.data.llm.reward import PairwiseDataset, pre_tokenization_hook
+from torchrl.data.llm.utils import RolloutFromModel
+from torchrl.modules.models.llm import GPT2RewardModel
 
 if os.getenv("PYTORCH_TEST_FBCODE"):
     from pytorch.rl.test._utils_internal import get_default_devices

test/test_transforms.py

Lines changed: 1 addition & 1 deletion
@@ -117,8 +117,8 @@
 from torchrl.envs.libs.gym import _has_gym, GymEnv, set_gym_backend
 from torchrl.envs.libs.unity_mlagents import _has_unity_mlagents
 from torchrl.envs.transforms import VecNorm
+from torchrl.envs.transforms.llm import KLRewardTransform
 from torchrl.envs.transforms.r3m import _R3MNet
-from torchrl.envs.transforms.rlhf import KLRewardTransform
 from torchrl.envs.transforms.transforms import (
     _has_tv,
     ActionDiscretizer,

torchrl/data/__init__.py

Lines changed: 13 additions & 13 deletions
@@ -3,6 +3,19 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
+from .llm import (
+    AdaptiveKLController,
+    ConstantKLController,
+    create_infinite_iterator,
+    get_dataloader,
+    PairwiseDataset,
+    PromptData,
+    PromptTensorDictTokenizer,
+    RewardData,
+    RolloutFromModel,
+    TensorDictTokenizer,
+    TokenizedDatasetLoader,
+)
 from .map import (
     BinaryToDecimal,
     HashToInt,
@@ -56,19 +69,6 @@
     Writer,
     WriterEnsemble,
 )
-from .rlhf import (
-    AdaptiveKLController,
-    ConstantKLController,
-    create_infinite_iterator,
-    get_dataloader,
-    PairwiseDataset,
-    PromptData,
-    PromptTensorDictTokenizer,
-    RewardData,
-    RolloutFromModel,
-    TensorDictTokenizer,
-    TokenizedDatasetLoader,
-)
 from .tensor_specs import (
     Binary,
     BinaryDiscreteTensorSpec,
File renamed without changes.
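
Because the package-level re-exports in torchrl/data/__init__.py now come from .llm instead of .rlhf, code that imports these names from the top-level package is unaffected by the rename. A short check, assuming the torchrl build at this commit:

    # Unchanged call sites: these names are still re-exported by torchrl.data.
    from torchrl.data import PromptData, RolloutFromModel, TensorDictTokenizer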

torchrl/data/rlhf/dataset.py renamed to torchrl/data/llm/dataset.py

Lines changed: 4 additions & 4 deletions
@@ -31,7 +31,7 @@ class TokenizedDatasetLoader:
         max_length (int): the maximum sequence length.
         dataset_name (str): the name of the dataset.
         tokenizer_fn (callable): the tokeinizing method constructor, such as
-            :class:`torchrl.data.rlhf.TensorDictTokenizer`. When called,
+            :class:`torchrl.data.llm.TensorDictTokenizer`. When called,
             it should return a :class:`tensordict.TensorDict` instance
             or a dictionary-like structure with the tokenized data.
         pre_tokenization_hook (callable, optional): called on
@@ -62,8 +62,8 @@ class TokenizedDatasetLoader:
     The dataset will be stored in ``<root_dir>/<split>/<max_length>/``.
 
     Examples:
-        >>> from torchrl.data.rlhf import TensorDictTokenizer
-        >>> from torchrl.data.rlhf.reward import pre_tokenization_hook
+        >>> from torchrl.data.llm import TensorDictTokenizer
+        >>> from torchrl.data.llm.reward import pre_tokenization_hook
         >>> split = "train"
         >>> max_length = 550
         >>> dataset_name = "CarperAI/openai_summarize_comparisons"
@@ -359,7 +359,7 @@ def get_dataloader(
             Defaults to ``max(os.cpu_count() // 2, 1)``.
 
     Examples:
-        >>> from torchrl.data.rlhf.reward import PairwiseDataset
+        >>> from torchrl.data.llm.reward import PairwiseDataset
         >>> dataloader = get_dataloader(
         ...     batch_size=256, block_size=550, tensorclass_type=PairwiseDataset, device="cpu")
         >>> for d in dataloader:

torchrl/data/rlhf/prompt.py renamed to torchrl/data/llm/prompt.py

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@
 import torch
 from tensordict import tensorclass, TensorDict
 
-from torchrl.data.rlhf.dataset import TensorDictTokenizer, TokenizedDatasetLoader
+from torchrl.data.llm.dataset import TensorDictTokenizer, TokenizedDatasetLoader
 
 DEFAULT_DATASET = "CarperAI/openai_summarize_tldr"
 

torchrl/data/rlhf/reward.py renamed to torchrl/data/llm/reward.py

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@
 
 import torch
 from tensordict import tensorclass
-from torchrl.data.rlhf.dataset import TensorDictTokenizer, TokenizedDatasetLoader
+from torchrl.data.llm.dataset import TensorDictTokenizer, TokenizedDatasetLoader
 
 DEFAULT_DATASET = "CarperAI/openai_summarize_comparisons"
 _has_datasets = importlib.util.find_spec("datasets") is not None

torchrl/data/rlhf/utils.py renamed to torchrl/data/llm/utils.py

Lines changed: 5 additions & 5 deletions
@@ -14,7 +14,7 @@
 from torch import nn, Tensor
 from torch.nn import functional as F
 
-from torchrl.data.rlhf.prompt import PromptData
+from torchrl.data.llm.prompt import PromptData
 
 _has_transformers = importlib.util.find_spec("transformers") is not None
 
@@ -154,10 +154,10 @@ class RolloutFromModel:
 
     Examples:
         >>> from tensordict.nn import TensorDictModule
-        >>> from torchrl.modules.models.rlhf import GPT2RewardModel
-        >>> from torchrl.data.rlhf.utils import RolloutFromModel
-        >>> from torchrl.data.rlhf.dataset import get_dataloader
-        >>> from torchrl.data.rlhf.prompt import PromptData
+        >>> from torchrl.modules.models.llm import GPT2RewardModel
+        >>> from torchrl.data.llm.utils import RolloutFromModel
+        >>> from torchrl.data.llm.dataset import get_dataloader
+        >>> from torchrl.data.llm.prompt import PromptData
         >>> from transformers import GPT2LMHeadModel
         >>>
        >>> dl = get_dataloader(

torchrl/data/rlhf.py

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import warnings
+
+from torchrl.data.llm import (
+    AdaptiveKLController,
+    ConstantKLController,
+    create_infinite_iterator,
+    get_dataloader,
+    PairwiseDataset,
+    PromptData,
+    PromptTensorDictTokenizer,
+    RewardData,
+    RolloutFromModel,
+    TensorDictTokenizer,
+    TokenizedDatasetLoader,
+)
+
+__all__ = [
+    "create_infinite_iterator",
+    "get_dataloader",
+    "TensorDictTokenizer",
+    "TokenizedDatasetLoader",
+    "PromptData",
+    "PromptTensorDictTokenizer",
+    "PairwiseDataset",
+    "RewardData",
+    "AdaptiveKLController",
+    "ConstantKLController",
+    "RolloutFromModel",
+]
+
+warnings.warn(
+    "Imports from torchrl.data.rlhf have moved to torchrl.data.llm. "
+    "torchrl.data.rlhf will be deprecated in v0.10.",
+    category=DeprecationWarning,
+)
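
A quick way to exercise the shim is to import a name from the old path in a fresh interpreter: the re-export resolves through torchrl.data.llm and the module-level DeprecationWarning fires. A rough sketch (the warning is only emitted on the first import of torchrl.data.rlhf in a session):

    import warnings

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        # Old path, re-exported by the shim above.
        from torchrl.data.rlhf import PromptData  # noqa: F401

    assert any(issubclass(w.category, DeprecationWarning) for w in caught)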

torchrl/envs/transforms/__init__.py

Lines changed: 3 additions & 3 deletions
@@ -4,14 +4,14 @@
 # LICENSE file in the root directory of this source tree.
 
 from .gym_transforms import EndOfLifeTransform
-from .r3m import R3MTransform
-from .rb_transforms import MultiStepTransform
-from .rlhf import (
+from .llm import (
     as_nested_tensor,
     as_padded_tensor,
     DataLoadingPrimer,
     KLRewardTransform,
 )
+from .r3m import R3MTransform
+from .rb_transforms import MultiStepTransform
 from .transforms import (
     ActionDiscretizer,
     ActionMask,
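
The transforms package follows the same pattern: the .rlhf submodule becomes .llm, while the names re-exported from torchrl.envs.transforms (KLRewardTransform, DataLoadingPrimer, and the padding helpers) stay available at the package level, so only direct submodule imports need updating, as the test_env.py and test_transforms.py hunks above show. A short check under that assumption:

    # Still valid after the rename; only torchrl.envs.transforms.rlhf -> .llm changes.
    from torchrl.envs.transforms import DataLoadingPrimer, KLRewardTransform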
