@@ -8,6 +8,8 @@
 import contextlib
 import functools
 import gc
+
+import importlib
 import os
 import subprocess
 import sys
@@ -49,21 +51,26 @@
     MultiaSyncDataCollector,
     MultiSyncDataCollector,
 )
+
+from torchrl.collectors.llm_collector import LLMCollector
 from torchrl.collectors.utils import split_trajectories
 from torchrl.data import (
     Composite,
     LazyMemmapStorage,
+    LazyStackStorage,
     LazyTensorStorage,
     NonTensor,
     ReplayBuffer,
     TensorSpec,
     Unbounded,
 )
+from torchrl.data.llm.dataset import _has_transformers
 from torchrl.data.utils import CloudpickleWrapper
 from torchrl.envs import (
     EnvBase,
     EnvCreator,
     InitTracker,
+    LLMEnv,
     ParallelEnv,
     SerialEnv,
     StepCounter,
@@ -77,7 +84,13 @@
     PARTIAL_MISSING_ERR,
     RandomPolicy,
 )
-from torchrl.modules import Actor, OrnsteinUhlenbeckProcessModule, SafeModule
+from torchrl.modules import (
+    Actor,
+    OrnsteinUhlenbeckProcessModule,
+    SafeModule,
+    TransformersWrapper,
+    vLLMWrapper,
+)
 
 if os.getenv("PYTORCH_TEST_FBCODE"):
     IS_FB = True
@@ -102,6 +115,7 @@
     DiscreteActionConvPolicy,
     DiscreteActionVecMockEnv,
     DiscreteActionVecPolicy,
+    DummyStrDataLoader,
     EnvThatErrorsAfter10Iters,
     EnvWithDynamicSpec,
     HeterogeneousCountingEnv,
@@ -134,6 +148,7 @@
     DiscreteActionConvPolicy,
     DiscreteActionVecMockEnv,
     DiscreteActionVecPolicy,
+    DummyStrDataLoader,
     EnvThatErrorsAfter10Iters,
     EnvWithDynamicSpec,
     HeterogeneousCountingEnv,
@@ -151,6 +166,7 @@
 PYTHON_3_7 = sys.version_info.major == 3 and sys.version_info.minor == 7
 TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)
 _has_cuda = torch.cuda.is_available()
+_has_vllm = importlib.util.find_spec("vllm") is not None
 
 
 class WrappablePolicy(nn.Module):
@@ -3544,6 +3560,115 @@ def test_weight_update(self):
         collector.shutdown()
 
 
+@pytest.mark.skipif(not _has_transformers, reason="missing transformers dependencies")
+@pytest.mark.skipif(not _has_vllm, reason="missing vllm dependencies")
+class TestLLMCollector:
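+    # End-to-end tests for LLMCollector with vLLM and transformers policies,
+    # with and without a replay buffer.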
+    @pytest.fixture(scope="module")
+    def vllm_instance(self):
+        try:
+            import vllm
+        except ImportError:
+            pytest.skip(reason="missing vllm")
+
+        llm_model = vllm.LLM("gpt2")
+        tokenizer = llm_model.get_tokenizer()
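+        # GPT-2 has no pad token; reuse the EOS token for padding.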
+        tokenizer.pad_token = tokenizer.eos_token
+        return llm_model
+
+    @pytest.fixture(scope="module")
+    def transformers_instance(self):
+        from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel
+
+        tokenizer = AutoTokenizer.from_pretrained("gpt2")
+        model = GPT2LMHeadModel(GPT2Config()).eval()
+        # A small OPT model (e.g. OPTForCausalLM with "facebook/opt-125m")
+        # would also work here.
+
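+        # Decoder-only models are padded on the left so generation follows the prompt.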
+        tokenizer.pad_token = tokenizer.eos_token
+        tokenizer.padding_side = "left"
+
+        return model, tokenizer
+
+    @pytest.mark.slow
+    @pytest.mark.parametrize("rb", [True, False])
+    @pytest.mark.parametrize("total_steps", [1, 10, 20])
+    def test_llm_collector_with_vllm(self, rb, total_steps, vllm_instance):
+        # NOTE: if vLLM fails with CUDA multiprocessing, try setting
+        # `export VLLM_WORKER_MULTIPROC_METHOD=spawn`
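+        # (spawn may also be selected in-process by assigning
+        # os.environ["VLLM_WORKER_MULTIPROC_METHOD"] before vllm starts its workers)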
+        policy = vLLMWrapper(vllm_instance)
+        tokenizer = vllm_instance.get_tokenizer()
+        self._run_collector_test(total_steps, rb, policy, tokenizer)
+
+    @pytest.mark.slow
+    @pytest.mark.parametrize("rb", [True, False])
+    @pytest.mark.parametrize("total_steps", [1, 10, 20])
+    def test_llm_collector_with_transformers(
+        self, rb, total_steps, transformers_instance
+    ):
+        model, tokenizer = transformers_instance
+        policy = TransformersWrapper(
+            model,
+            tokenizer=tokenizer,
+            from_text=True,
+            generate=True,
+            return_log_probs=True,
+        )
+        self._run_collector_test(total_steps, rb, policy, tokenizer)
+
+    def _run_collector_test(self, total_steps, rb, policy, tokenizer):
+        bsz = 1
+        dataloader = DummyStrDataLoader(bsz)
+
+        env = LLMEnv.from_dataloader(
+            dataloader=dataloader,
+            tokenizer=tokenizer,
+            str2str=True,
+            batch_size=bsz,
+            group_repeats=True,
+        )
+        if rb:
+            rb = ReplayBuffer(storage=LazyStackStorage(max_size=total_steps * 2))
+        else:
+            rb = None
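+        # The collector yields env.batch_size[0] steps per iteration and
+        # stops once total_steps frames have been collected.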
+        collector = LLMCollector(
+            env=env,
+            policy_factory=lambda: policy,
+            steps_per_batch=env.batch_size[0],
+            replay_buffer=rb,
+            total_steps=total_steps,
+        )
+
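+        # With a replay buffer, batches are written straight to the buffer
+        # and the iterator yields None; otherwise we accumulate them.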
+        stack = []
+        for data in collector:
+            # Data should have been written to the replay buffer already
+            if rb is not None:
+                assert data is None
+            else:
+                stack.append(data)
+
+        if rb is not None:
+            # Now check the buffer
+            assert len(rb) == total_steps
+            sample = rb.sample(1)
+            # Sample length should match
+            assert len(sample["text"]) == 1
+            # Response should be non-empty
+            assert sample["text_response"] is not None
+        else:
+            stack = torch.cat(stack)
+            assert stack.numel() == total_steps
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)