Skip to content

Commit 4136fb1

Browse files
author
Vincent Moens
committed
Update
[ghstack-poisoned]
2 parents 33bcade + e67e06b commit 4136fb1

File tree

27 files changed

+1057
-875
lines changed

27 files changed

+1057
-875
lines changed

docs/source/reference/data.rst

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1133,6 +1133,9 @@ efficient sampling.
11331133
get_dataloader
11341134
ConstantKLController
11351135
AdaptiveKLController
1136+
LLMData
1137+
LLMInput
1138+
LLMOutput
11361139

11371140

11381141
Utils

examples/rlhf/data/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,3 @@
1-
from torchrl.data.rlhf.prompt import get_prompt_dataloader_tldr
1+
from torchrl.data.llm.prompt import get_prompt_dataloader_tldr
22

33
__all__ = ["get_prompt_dataloader_tldr"]

examples/rlhf/models/reward.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@
88
from tensordict.nn import TensorDictModule
99
from torchrl._utils import logger as torchrl_logger
1010

11-
from torchrl.modules.models.rlhf import GPT2RewardModel
11+
from torchrl.modules.models.llm import GPT2RewardModel
1212

1313

1414
def init_reward_model(

examples/rlhf/train.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -17,8 +17,8 @@
1717
from torch.optim.lr_scheduler import CosineAnnealingLR
1818
from torchrl._utils import logger as torchrl_logger
1919

20-
from torchrl.data.rlhf.dataset import get_dataloader
21-
from torchrl.data.rlhf.prompt import PromptData
20+
from torchrl.data.llm.dataset import get_dataloader
21+
from torchrl.data.llm.prompt import PromptData
2222
from utils import get_file_logger, resolve_name_or_path, setup
2323

2424

examples/rlhf/train_reward.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -9,8 +9,8 @@
99
from models.reward import init_reward_model
1010
from torch.optim.lr_scheduler import CosineAnnealingLR
1111
from torchrl._utils import logger as torchrl_logger
12-
from torchrl.data.rlhf.dataset import get_dataloader
13-
from torchrl.data.rlhf.reward import PairwiseDataset
12+
from torchrl.data.llm.dataset import get_dataloader
13+
from torchrl.data.llm.reward import PairwiseDataset
1414
from utils import get_file_logger, resolve_name_or_path, setup
1515

1616

examples/rlhf/train_rlhf.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@
66
import hydra
77
import torch
88
from models.actor_critic import init_actor_critic
9-
from torchrl.data.rlhf.utils import AdaptiveKLController, RolloutFromModel
9+
from torchrl.data.llm.utils import AdaptiveKLController, RolloutFromModel
1010

1111
from torchrl.record.loggers import get_logger
1212

examples/rlhf/utils.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -22,9 +22,9 @@
2222
TensorDictReplayBuffer,
2323
TensorStorage,
2424
)
25+
from torchrl.data.llm.dataset import get_dataloader
26+
from torchrl.data.llm.prompt import PromptData
2527
from torchrl.data.replay_buffers import SamplerWithoutReplacement
26-
from torchrl.data.rlhf.dataset import get_dataloader
27-
from torchrl.data.rlhf.prompt import PromptData
2828
from torchrl.objectives import ClipPPOLoss
2929
from torchrl.objectives.value import GAE
3030

test/assets/generate.py

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55

66
"""Script used to generate the mini datasets."""
77
import multiprocessing as mp
8+
import pathlib
89

910
try:
1011
mp.set_start_method("spawn")
@@ -14,8 +15,8 @@
1415

1516
from datasets import Dataset, DatasetDict, load_dataset
1617

17-
from torchrl.data.rlhf.dataset import get_dataloader
18-
from torchrl.data.rlhf.prompt import PromptData
18+
from torchrl.data.llm.dataset import get_dataloader
19+
from torchrl.data.llm.prompt import PromptData
1920

2021

2122
def generate_small_dataset(comparison=True):
@@ -42,7 +43,7 @@ def get_minibatch():
4243
batch_size=16,
4344
block_size=33,
4445
tensorclass_type=PromptData,
45-
dataset_name="../datasets_mini/openai_summarize_tldr",
46+
dataset_name=f"{pathlib.Path(__file__).parent}/../datasets_mini/openai_summarize_tldr",
4647
device="cpu",
4748
num_workers=2,
4849
infinite=False,
@@ -52,9 +53,12 @@ def get_minibatch():
5253
root_dir=tmpdir,
5354
)
5455
for data in dl:
55-
data = data.clone().memmap_("test/datasets_mini/tldr_batch/")
56+
data = data.clone().memmap_(
57+
f"{pathlib.Path(__file__).parent}/../datasets_mini/tldr_batch/"
58+
)
5659
break
5760

5861

5962
if __name__ == "__main__":
63+
generate_small_dataset(False)
6064
get_minibatch()

test/assets/tldr_batch.zip

2 Bytes
Binary file not shown.

test/test_actors.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
1414

1515
from torch import distributions as dist, nn
1616
from torchrl.data import Binary, Bounded, Categorical, Composite, MultiOneHot, OneHot
17-
from torchrl.data.rlhf.dataset import _has_transformers
17+
from torchrl.data.llm.dataset import _has_transformers
1818
from torchrl.modules import MLP, SafeModule, TanhDelta, TanhNormal
1919
from torchrl.modules.tensordict_module.actors import (
2020
_process_action_space_spec,

0 commit comments

Comments
 (0)