
Commit 1ba8c84

Author: Vincent Moens

[Feature] LLM Tooling

ghstack-source-id: 2eb02d4
Pull-Request-resolved: #2966

1 parent: b1d2dc2 · commit: 1ba8c84


54 files changed: +12079, -7 lines

.github/unittest/linux/scripts/run_all.sh

Lines changed: 2 additions & 0 deletions
@@ -208,11 +208,13 @@ pytest test/smoke_test_deps.py -v --durations 200 -k 'test_gym or test_dm_contro
 if [ "${CU_VERSION:-}" != cpu ] ; then
   python .github/unittest/helpers/coverage_run_parallel.py -m pytest test \
     --instafail --durations 200 -vv --capture no --ignore test/test_rlhf.py \
+    --ignore test/llm \
     --timeout=120 --mp_fork_if_no_cuda
 else
   python .github/unittest/helpers/coverage_run_parallel.py -m pytest test \
     --instafail --durations 200 -vv --capture no --ignore test/test_rlhf.py \
     --ignore test/test_distributed.py \
+    --ignore test/llm \
     --timeout=120 --mp_fork_if_no_cuda
 fi

.github/unittest/linux_optdeps/scripts/run_all.sh

Lines changed: 1 addition & 0 deletions
@@ -159,6 +159,7 @@ export BATCHED_PIPE_TIMEOUT=60
 python .github/unittest/helpers/coverage_run_parallel.py -m pytest test \
   --instafail --durations 200 -vv --capture no --ignore test/test_rlhf.py \
   --ignore test/test_distributed.py \
+  --ignore test/llm \
   --timeout=120 --mp_fork_if_no_cuda

 coverage combine

docs/source/reference/index.rst

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@ API Reference
    collectors
    data
    envs
+   llms
    modules
    objectives
    trainers

docs/source/reference/llms.rst

Lines changed: 113 additions & 1 deletion
@@ -1,4 +1,4 @@
-.. currentmodule:: torchrl.trainers
+.. currentmodule:: torchrl

 LLM interface
 =============
@@ -7,13 +7,125 @@ LLM interface

 TorchRL offers a set of tools for LLM post-training, as well as some examples for training or setup.

+Collectors
+----------
+
+TorchRL offers a specialized collector class (:class:`~torchrl.collectors.llm.LLMCollector`) that is tailored for LLM
+use cases. We also provide dedicated updaters for some inference engines.
+
+.. currentmodule:: torchrl.collectors.llm
+
+.. autosummary::
+    :toctree: generated/
+    :template: rl_template.rst
+
+    vLLMUpdater
+    LLMCollector
+
 Data structures
 ---------------

+To handle text-based data structures (such as conversations, etc.), we offer a few data structures dedicated to carrying
+data for LLM post-training.
+
 .. currentmodule:: torchrl.data.llm

 .. autosummary::
     :toctree: generated/
     :template: rl_template.rst

     History
+    LLMData
+
+Environments
+------------
+
+When fine-tuning an LLM using TorchRL, the environment is a crucial component of the inference pipeline, alongside the
+policy and collector. Environments manage operations that are not handled by the LLM itself, such as interacting with
+tools, loading prompts from datasets, computing rewards (when necessary), and formatting data.
+
+The design of environments in TorchRL allows for flexibility and modularity. By framing tasks as environments, users can
+easily extend or modify existing environments using transforms. This approach enables the isolation of individual
+components within specific :class:`~torchrl.envs.EnvBase` or :class:`~torchrl.envs.Transform` subclasses, making it
+simpler to augment or alter the environment logic.
+
+Available Environment Classes and Utilities
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+TorchRL provides various environment classes and utilities for working with LLMs, including:
+
+- Environment classes (:class:`~torchrl.envs.llm.ChatEnv`, :class:`~torchrl.envs.llm.DatasetChatEnv`,
+  :class:`~torchrl.envs.llm.GSM8KEnv`, etc.)
+- Utility functions (:class:`~torchrl.envs.make_gsm8k_env`, :class:`~torchrl.envs.make_mlgym`, etc.)
+- Transforms and other supporting classes (:class:`~torchrl.envs.KLRewardTransform`,
+  :class:`~torchrl.envs.TemplateTransform`, :class:`~torchrl.envs.Tokenizer`, etc.)
+
+These components can be used to create customized environments tailored to specific use cases and requirements.
+
+.. currentmodule:: torchrl.envs.llm
+
+.. autosummary::
+    :toctree: generated/
+    :template: rl_template.rst
+
+    ChatEnv
+    DatasetChatEnv
+    GSM8KEnv
+    make_gsm8k_env
+    GSM8KPrepareQuestion
+    GSM8KEnv
+    IFEvalEnv
+    IfEvalScorer
+    IFEvalScoreData
+    LLMEnv
+    LLMHashingEnv
+    make_mlgym
+    MLGymWrapper
+    GSM8KRewardParser
+    IfEvalScorer
+    as_nested_tensor
+    as_padded_tensor
+    DataLoadingPrimer
+    KLRewardTransform
+    TemplateTransform
+    Tokenizer
+
+Modules
+-------
+
+The :mod:`~torchrl.modules.llm` section provides a set of wrappers and utility functions for popular training and
+inference backends. The main goals of these primitives are to:
+
+- Unify the input / output data format across training and inference pipelines;
+- Unify the input / output data format across backends (to be able to use different backends across losses and
+  collectors, for instance);
+- Provide appropriate tooling to construct these objects in typical RL settings (resource allocation, async execution,
+  weight update, etc.).
+
+Wrappers
+~~~~~~~~
+
+.. currentmodule:: torchrl.modules.llm
+
+.. autosummary::
+    :toctree: generated/
+    :template: rl_template.rst
+
+    TransformersWrapper
+    vLLMWrapper
+
+Utils
+~~~~~
+
+.. currentmodule:: torchrl.modules.llm
+
+.. autosummary::
+    :toctree: generated/
+    :template: rl_template.rst
+
+    CategoricalSequential
+    LLMOnDevice
+    make_vllm_worker
+    stateless_init_process_group
+    vLLMWorker
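
To see how these pieces fit together, here is a minimal rollout sketch assembled from the classes documented above. The TransformersWrapper / make_mlgym / SerialEnv calls mirror the usage in test/llm/libs/test_mlgym.py below; the model name and generation settings are placeholders, and running it requires the transformers and MLGym dependencies (and a GPU for a 7B checkpoint).

    from functools import partial

    from torchrl.envs import SerialEnv
    from torchrl.envs.llm import make_mlgym
    from torchrl.modules.llm import TransformersWrapper
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "Qwen/Qwen2.5-7B-Instruct"  # placeholder checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Wrap the Hugging Face model so it reads and writes TorchRL's text-based
    # data format (from_text=True) and generates new tokens when called.
    policy = TransformersWrapper(
        AutoModelForCausalLM.from_pretrained(model_name),
        tokenizer=tokenizer,
        from_text=True,
        generate=True,
        generate_kwargs={"max_new_tokens": 256},
    )

    # The environment handles everything the LLM itself does not:
    # loading task prompts, tool interaction, and reward computation.
    env = SerialEnv(1, [partial(make_mlgym, task="prisonersDilemma", tokenizer=tokenizer)])
    rollout = env.rollout(3, policy)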

test/llm/libs/test_mlgym.py

Lines changed: 101 additions & 0 deletions
@@ -0,0 +1,101 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import argparse
+
+from functools import partial
+
+import pytest
+
+from torchrl import logger as torchrl_logger
+from torchrl.envs import SerialEnv
+
+from torchrl.envs.llm import make_mlgym
+from torchrl.modules.llm import TransformersWrapper
+
+
+class TestMLGYM:
+    def test_mlgym_specs(self):
+        from transformers import AutoModelForCausalLM, AutoTokenizer
+
+        model_name = "Qwen/Qwen2.5-7B-Instruct"
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        tokenizer.eos_token = "<|im_end|>"
+        policy = TransformersWrapper(
+            AutoModelForCausalLM.from_pretrained(model_name).cuda(),
+            tokenizer=tokenizer,
+            from_text=True,
+            generate=True,
+            device="cuda:0",
+            generate_kwargs={
+                # "temperature": 0.8,
+                # "repetition_penalty": 1.5,
+                "max_new_tokens": 1024
+            },
+        )
+
+        env = SerialEnv(
+            1,
+            [
+                partial(
+                    make_mlgym,
+                    task="prisonersDilemma",
+                    tokenizer=tokenizer,
+                    device="cuda:0",
+                )
+            ],
+        )
+        rollout = env.rollout(3, policy)
+        torchrl_logger.info(f"{rollout=}")
+        env.check_env_specs(break_when_any_done="both")
+
+    def test_mlgym_task_reset(self):
+        from transformers import AutoModelForCausalLM, AutoTokenizer
+
+        model_name = "Qwen/Qwen2.5-7B-Instruct"
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        tokenizer.eos_token = "<|im_end|>"
+        policy = TransformersWrapper(
+            AutoModelForCausalLM.from_pretrained(model_name).cuda(),
+            tokenizer=tokenizer,
+            from_text=True,
+            generate=True,
+            device="cuda:0",
+            generate_kwargs={
+                # "temperature": 0.8,
+                # "repetition_penalty": 1.5,
+                "max_new_tokens": 1024
+            },
+        )
+
+        env = SerialEnv(
+            1,
+            [
+                partial(
+                    make_mlgym,
+                    tasks=[
+                        "prisonersDilemma",
+                        "regressionKaggleHousePrice",
+                        "battleOfSexes",
+                    ],
+                    tokenizer=tokenizer,
+                    device="cuda:0",
+                )
+            ],
+        )
+        # We should get at least two tasks
+        rollout = env.rollout(100, policy, break_when_any_done=False)
+        torchrl_logger.info(f"{rollout=}")
+        torchrl_logger.info(rollout["task"])
+
+    def test_mlgym_wrong_format(self):
+        # A vanilla policy will not output anything useful, yet the env should run without error
+        ...
+
+
+if __name__ == "__main__":
+    args, unknown = argparse.ArgumentParser().parse_known_args()
+    pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

test/llm/mocking_classes.py

Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import random
+import string
+
+import torch
+
+
+class DummyStrDataLoader:
+    def __init__(self, batch_size=0):
+        if isinstance(batch_size, tuple):
+            batch_size = torch.Size(batch_size).numel()
+        self.batch_size = batch_size
+
+    def generate_random_string(self, length=10):
+        """Generate a random string of a given length."""
+        return "".join(random.choice(string.ascii_lowercase) for _ in range(length))
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        if self.batch_size == 0:
+            return {"text": self.generate_random_string()}
+        else:
+            return {
+                "text": [self.generate_random_string() for _ in range(self.batch_size)]
+            }
+
+
+class DummyTensorDataLoader:
+    def __init__(self, batch_size=0, max_length=10, padding=False):
+        if isinstance(batch_size, tuple):
+            batch_size = torch.Size(batch_size).numel()
+        self.batch_size = batch_size
+        self.max_length = max_length
+        self.padding = padding
+
+    def generate_random_tensor(self):
+        """Generate a tensor of random int64 values."""
+        length = random.randint(1, self.max_length)
+        rt = torch.randint(1, 10000, (length,))
+        return rt
+
+    def pad_tensor(self, tensor):
+        """Pad a tensor to the maximum length."""
+        padding_length = self.max_length - len(tensor)
+        return torch.cat((torch.zeros(padding_length, dtype=torch.int64), tensor))
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        if self.batch_size == 0:
+            tensor = self.generate_random_tensor()
+            tokens = self.pad_tensor(tensor) if self.padding else tensor
+        else:
+            tensors = [self.generate_random_tensor() for _ in range(self.batch_size)]
+            if self.padding:
+                tensors = [self.pad_tensor(tensor) for tensor in tensors]
+                tokens = torch.stack(tensors)
+            else:
+                tokens = tensors
+        return {"tokens": tokens, "attention_mask": tokens != 0}

test/llm/smoke_test.py

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import argparse
+
+import pytest
+
+
+def test_import():
+    pass
+
+
+if __name__ == "__main__":
+    args, unknown = argparse.ArgumentParser().parse_known_args()
+    pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

test/llm/smoke_test_deps.py

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import argparse
+
+import pytest
+
+
+if __name__ == "__main__":
+    args, unknown = argparse.ArgumentParser().parse_known_args()
+    pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
