
Commit 4135a83

[Feature] LLM collector (#2879)
1 parent fd10fe2

File tree: 4 files changed, +321 −1 lines

docs/source/reference/collectors.rst
Lines changed: 15 additions & 0 deletions
@@ -319,6 +319,21 @@ node or across multiple nodes.
     submitit_delayed_launcher
     RayCollector
 
+LLM Collectors
+---------------------------
+TorchRL also provides data collectors for large language models. These collectors
+are meant to include a subset of the functionality of other data collectors, targeted
+at supporting researchers in fine-tuning large language models. These classes
+currently derive from the :class:`~torchrl.collectors.SyncDataCollector` class.
+These classes are experimental and subject to change.
+
+.. currentmodule:: torchrl.collectors.llm_collectors
+
+.. autosummary::
+    :toctree: generated/
+    :template: rl_template.rst
+
+    LLMCollector
 
 Helper functions
 ----------------
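
For orientation, here is a minimal usage sketch assembled from the tests added in this commit (see test/test_collector.py below). This is experimental, undocumented API and may change; the PromptLoader stand-in and its payload format are assumptions on my part, where the tests use their own DummyStrDataLoader helper.

from torchrl.collectors.llm_collector import LLMCollector
from torchrl.envs import LLMEnv
from torchrl.modules import TransformersWrapper
from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel


class PromptLoader:
    # Hypothetical stand-in for the tests' DummyStrDataLoader: an endless
    # iterable of prompt strings (the exact payload format is an assumption)
    def __iter__(self):
        while True:
            yield "Tell me a story."


# Small, randomly initialized model so the sketch stays lightweight
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
model = GPT2LMHeadModel(GPT2Config()).eval()

# Wrap the model so it reads and writes the text fields of a TensorDict
policy = TransformersWrapper(
    model,
    tokenizer=tokenizer,
    from_text=True,
    generate=True,
    return_log_probs=True,
)

# Build an environment that draws prompts from the dataloader
env = LLMEnv.from_dataloader(
    dataloader=PromptLoader(),
    tokenizer=tokenizer,
    str2str=True,
    batch_size=1,
)

collector = LLMCollector(
    env=env,
    policy_factory=lambda: policy,
    steps_per_batch=env.batch_size[0],
    total_steps=10,
)
for data in collector:
    ...  # each batch carries prompts, generated responses and log-probs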

test/test_collector.py
Lines changed: 118 additions & 1 deletion
@@ -8,6 +8,8 @@
 import contextlib
 import functools
 import gc
+
+import importlib
 import os
 import subprocess
 import sys

@@ -49,21 +51,26 @@
     MultiaSyncDataCollector,
     MultiSyncDataCollector,
 )
+
+from torchrl.collectors.llm_collector import LLMCollector
 from torchrl.collectors.utils import split_trajectories
 from torchrl.data import (
     Composite,
     LazyMemmapStorage,
+    LazyStackStorage,
     LazyTensorStorage,
     NonTensor,
     ReplayBuffer,
     TensorSpec,
     Unbounded,
 )
+from torchrl.data.llm.dataset import _has_transformers
 from torchrl.data.utils import CloudpickleWrapper
 from torchrl.envs import (
     EnvBase,
     EnvCreator,
     InitTracker,
+    LLMEnv,
     ParallelEnv,
     SerialEnv,
     StepCounter,

@@ -77,7 +84,13 @@
     PARTIAL_MISSING_ERR,
     RandomPolicy,
 )
-from torchrl.modules import Actor, OrnsteinUhlenbeckProcessModule, SafeModule
+from torchrl.modules import (
+    Actor,
+    OrnsteinUhlenbeckProcessModule,
+    SafeModule,
+    TransformersWrapper,
+    vLLMWrapper,
+)
 
 if os.getenv("PYTORCH_TEST_FBCODE"):
     IS_FB = True

@@ -102,6 +115,7 @@
     DiscreteActionConvPolicy,
     DiscreteActionVecMockEnv,
     DiscreteActionVecPolicy,
+    DummyStrDataLoader,
     EnvThatErrorsAfter10Iters,
     EnvWithDynamicSpec,
     HeterogeneousCountingEnv,

@@ -134,6 +148,7 @@
     DiscreteActionConvPolicy,
     DiscreteActionVecMockEnv,
     DiscreteActionVecPolicy,
+    DummyStrDataLoader,
     EnvThatErrorsAfter10Iters,
     EnvWithDynamicSpec,
     HeterogeneousCountingEnv,

@@ -151,6 +166,7 @@
 PYTHON_3_7 = sys.version_info.major == 3 and sys.version_info.minor == 7
 TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)
 _has_cuda = torch.cuda.is_available()
+_has_vllm = importlib.util.find_spec("vllm") is not None
 
 
 class WrappablePolicy(nn.Module):

@@ -3544,6 +3560,107 @@ def test_weight_update(self):
         collector.shutdown()
 
 
+@pytest.mark.skipif(not _has_transformers, reason="missing transformers dependencies")
+@pytest.mark.skipif(not _has_vllm, reason="missing vllm dependencies")
+class TestLLMCollector:
+    @pytest.fixture(scope="module")
+    def vllm_instance(self):
+        try:
+            import vllm
+        except ImportError:
+            pytest.skip(reason="missing vllm")
+
+        llm_model = vllm.LLM("gpt2")
+        tokenizer = llm_model.get_tokenizer()
+        tokenizer.pad_token = tokenizer.eos_token
+        return llm_model
+
+    @pytest.fixture(scope="module")
+    def transformers_instance(self):
+        from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel
+
+        tokenizer = AutoTokenizer.from_pretrained("gpt2")
+        model = GPT2LMHeadModel(GPT2Config()).eval()
+        # tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
+        # model = OPTModel(OPTConfig("facebook/opt-125m"))
+        # tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
+        # model = OPTForCausalLM(OPTConfig())
+
+        tokenizer.pad_token = tokenizer.eos_token
+        tokenizer.padding_side = "left"
+
+        return model, tokenizer
+
+    @pytest.mark.slow
+    @pytest.mark.parametrize("rb", [True, False])
+    @pytest.mark.parametrize("total_steps", [1, 10, 20])
+    def test_llm_collector_with_vllm(self, rb, total_steps, vllm_instance):
+        # NOTE: if vLLM fails with CUDA multiprocessing, try setting
+        # `export VLLM_WORKER_MULTIPROC_METHOD=spawn`
+        policy = vLLMWrapper(vllm_instance)
+        tokenizer = vllm_instance.get_tokenizer()
+        self._run_collector_test(total_steps, rb, policy, tokenizer)
+
+    @pytest.mark.slow
+    @pytest.mark.parametrize("rb", [True, False])
+    @pytest.mark.parametrize("total_steps", [1, 10, 20])
+    def test_llm_collector_with_transformers(
+        self, rb, total_steps, transformers_instance
+    ):
+        model, tokenizer = transformers_instance
+        policy = TransformersWrapper(
+            model,
+            tokenizer=tokenizer,
+            from_text=True,
+            generate=True,
+            return_log_probs=True,
+        )
+        self._run_collector_test(total_steps, rb, policy, tokenizer)
+
+    def _run_collector_test(self, total_steps, rb, policy, tokenizer):
+        bsz = 1
+        dataloader = DummyStrDataLoader(bsz)
+
+        env = LLMEnv.from_dataloader(
+            dataloader=dataloader,
+            tokenizer=tokenizer,
+            str2str=True,
+            batch_size=bsz,
+            group_repeats=True,
+        )
+        if rb:
+            rb = ReplayBuffer(storage=LazyStackStorage(max_size=total_steps * 2))
+        else:
+            rb = None
+        collector = LLMCollector(
+            env=env,
+            policy_factory=lambda: policy,
+            steps_per_batch=env.batch_size[0],
+            replay_buffer=rb,
+            total_steps=total_steps,
+        )
+
+        stack = []
+        for data in collector:
+            # With a replay buffer, batches are written to the buffer and
+            # the iterator yields None
+            if rb is not None:
+                assert data is None
+            else:
+                stack.append(data)
+
+        if rb is not None:
+            # Now check the buffer
+            assert len(rb) == total_steps
+            sample = rb.sample(1)
+            # The sampled batch dimension should match the sample size
+            assert len(sample["text"]) == 1
+            # The generated response should be non-empty
+            assert sample["text_response"] is not None
+        else:
+            stack = torch.cat(stack)
+            assert stack.numel() == total_steps
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
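
One behavior _run_collector_test pins down deserves a callout: when a replay_buffer is passed to LLMCollector, collected batches are written straight into the buffer and iteration yields None. A condensed sketch of that mode, continuing the usage sketch after the docs diff above (same assumptions):

from torchrl.data import LazyStackStorage, ReplayBuffer

rb = ReplayBuffer(storage=LazyStackStorage(max_size=100))
collector = LLMCollector(
    env=env,  # built as in the earlier sketch
    policy_factory=lambda: policy,
    steps_per_batch=env.batch_size[0],
    replay_buffer=rb,
    total_steps=10,
)
for data in collector:
    assert data is None  # batches land in the buffer, not in the iterator
sample = rb.sample(1)  # retrieve collected prompt/response pairs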

test/test_cost.py
Lines changed: 4 additions & 0 deletions
@@ -7,6 +7,7 @@
 import argparse
 import contextlib
 import functools
+import importlib.util
 import itertools
 import operator
 import os

@@ -169,6 +170,8 @@
     _has_functorch = False
     FUNCTORCH_ERR = str(err)
 
+_has_transformers = bool(importlib.util.find_spec("transformers"))
+
 TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)
 IS_WINDOWS = sys.platform == "win32"
 

@@ -7998,6 +8001,7 @@ def test_dcql_reduction(self, reduction):
         assert loss[key].shape == torch.Size([])
 
 
+@pytest.mark.skipif(not _has_transformers, reason="requires transformers lib")
 class TestPPO(LossModuleTestBase):
     seed = 0
 