diff --git a/.github/unittest/linux_libs/scripts_rlhf/environment.yml b/.github/unittest/linux_libs/scripts_llm/environment.yml similarity index 91% rename from .github/unittest/linux_libs/scripts_rlhf/environment.yml rename to .github/unittest/linux_libs/scripts_llm/environment.yml index 6b8800a2531..b0897796779 100644 --- a/.github/unittest/linux_libs/scripts_rlhf/environment.yml +++ b/.github/unittest/linux_libs/scripts_llm/environment.yml @@ -17,5 +17,6 @@ dependencies: - pyyaml - scipy - hydra-core - - transformers<4.42.0 + - transformers - datasets + - vllm diff --git a/.github/unittest/linux_libs/scripts_rlhf/install.sh b/.github/unittest/linux_libs/scripts_llm/install.sh similarity index 59% rename from .github/unittest/linux_libs/scripts_rlhf/install.sh rename to .github/unittest/linux_libs/scripts_llm/install.sh index 1e927e08df6..68e63bf58ca 100755 --- a/.github/unittest/linux_libs/scripts_rlhf/install.sh +++ b/.github/unittest/linux_libs/scripts_llm/install.sh @@ -26,23 +26,24 @@ fi # submodules git submodule sync && git submodule update --init --recursive -printf "Installing PyTorch with cu128" -if [[ "$TORCH_VERSION" == "nightly" ]]; then - if [ "${CU_VERSION:-}" == cpu ] ; then - pip3 install --pre torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/nightly/cpu -U - else - pip3 install --pre torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/nightly/cu128 -U - fi -elif [[ "$TORCH_VERSION" == "stable" ]]; then - if [ "${CU_VERSION:-}" == cpu ] ; then - pip3 install torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/cpu - else - pip3 install torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/cu128 - fi -else - printf "Failed to install pytorch" - exit 1 -fi +# We skip pytorch install due to vllm requirements +#printf "Installing PyTorch with cu128" +#if [[ "$TORCH_VERSION" == "nightly" ]]; then +# if [ "${CU_VERSION:-}" == cpu ] ; then +# pip3 install --pre torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/nightly/cpu -U +# else +# pip3 install --pre torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/nightly/cu128 -U +# fi +#elif [[ "$TORCH_VERSION" == "stable" ]]; then +# if [ "${CU_VERSION:-}" == cpu ] ; then +# pip3 install torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/cpu +# else +# pip3 install torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/cu128 +# fi +#else +# printf "Failed to install pytorch" +# exit 1 +#fi # install tensordict if [[ "$RELEASE" == 0 ]]; then diff --git a/.github/unittest/linux_libs/scripts_rlhf/post_process.sh b/.github/unittest/linux_libs/scripts_llm/post_process.sh similarity index 100% rename from .github/unittest/linux_libs/scripts_rlhf/post_process.sh rename to .github/unittest/linux_libs/scripts_llm/post_process.sh diff --git a/.github/unittest/linux_libs/scripts_rlhf/run-clang-format.py b/.github/unittest/linux_libs/scripts_llm/run-clang-format.py similarity index 100% rename from .github/unittest/linux_libs/scripts_rlhf/run-clang-format.py rename to .github/unittest/linux_libs/scripts_llm/run-clang-format.py diff --git a/.github/unittest/linux_libs/scripts_rlhf/run_test.sh b/.github/unittest/linux_libs/scripts_llm/run_test.sh similarity index 86% rename from .github/unittest/linux_libs/scripts_rlhf/run_test.sh rename to .github/unittest/linux_libs/scripts_llm/run_test.sh index dcfc686ade0..a563e2e329b 100755 --- a/.github/unittest/linux_libs/scripts_rlhf/run_test.sh +++ b/.github/unittest/linux_libs/scripts_llm/run_test.sh 
@@ -24,6 +24,8 @@ python -c "import transformers, datasets" python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_rlhf.py --instafail -v --durations 200 --capture no --error-for-skips +python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_actors.py -k llm --instafail -v --durations 200 --capture no --error-for-skips --runslow + python .github/unittest/helpers/coverage_run_parallel.py examples/rlhf/train_rlhf.py \ sys.device=cuda:0 sys.ref_device=cuda:0 \ model.name_or_path=gpt2 train.max_epochs=2 \ diff --git a/.github/unittest/linux_libs/scripts_rlhf/setup_env.sh b/.github/unittest/linux_libs/scripts_llm/setup_env.sh similarity index 100% rename from .github/unittest/linux_libs/scripts_rlhf/setup_env.sh rename to .github/unittest/linux_libs/scripts_llm/setup_env.sh diff --git a/.github/workflows/test-linux-rlhf.yml b/.github/workflows/test-linux-llm.yml similarity index 83% rename from .github/workflows/test-linux-rlhf.yml rename to .github/workflows/test-linux-llm.yml index 994667b8a66..4de8b8165d9 100644 --- a/.github/workflows/test-linux-rlhf.yml +++ b/.github/workflows/test-linux-llm.yml @@ -1,4 +1,4 @@ -name: RLHF Tests on Linux +name: LLM Tests on Linux on: pull_request: @@ -50,7 +50,7 @@ jobs: export TF_CPP_MIN_LOG_LEVEL=0 export TD_GET_DEFAULTS_TO_NONE=1 - bash .github/unittest/linux_libs/scripts_rlhf/setup_env.sh - bash .github/unittest/linux_libs/scripts_rlhf/install.sh - bash .github/unittest/linux_libs/scripts_rlhf/run_test.sh - bash .github/unittest/linux_libs/scripts_rlhf/post_process.sh + bash .github/unittest/linux_libs/scripts_llm/setup_env.sh + bash .github/unittest/linux_libs/scripts_llm/install.sh + bash .github/unittest/linux_libs/scripts_llm/run_test.sh + bash .github/unittest/linux_libs/scripts_llm/post_process.sh diff --git a/examples/rlhf/models/actor_critic.py b/examples/rlhf/models/actor_critic.py index b5be188fbd9..93c4d285b3e 100644 --- a/examples/rlhf/models/actor_critic.py +++ b/examples/rlhf/models/actor_critic.py @@ -34,4 +34,4 @@ def init_actor_critic(model_cfg, sys_cfg): critic = model.get_value_operator() critic_head = model.get_value_head() - return actor, VmapModule(critic), critic_head, base_model + return actor, VmapModule(critic, mock=True), critic_head, base_model diff --git a/test/test_actors.py b/test/test_actors.py index 629da3cbf7d..bd0e3b5ed6f 100644 --- a/test/test_actors.py +++ b/test/test_actors.py @@ -3,19 +3,27 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
import argparse +import importlib.util import os import pytest import torch - from tensordict import NonTensorStack, TensorDict from tensordict.nn import CompositeDistribution, TensorDictModule from tensordict.nn.distributions import NormalParamExtractor from torch import distributions as dist, nn from torchrl.data import Binary, Bounded, Categorical, Composite, MultiOneHot, OneHot +from torchrl.data.llm import LLMData from torchrl.data.llm.dataset import _has_transformers -from torchrl.modules import from_hf_transformers, MLP, SafeModule, TanhDelta, TanhNormal +from torchrl.modules import ( + from_hf_transformers, + from_vllm, + MLP, + SafeModule, + TanhDelta, + TanhNormal, +) from torchrl.modules.tensordict_module.actors import ( _process_action_space_spec, ActorValueOperator, @@ -37,6 +45,8 @@ from _utils_internal import get_default_devices from mocking_classes import NestedCountingEnv +_has_vllm = importlib.util.find_spec("vllm") is not None + @pytest.mark.parametrize( "log_prob_key", @@ -908,52 +918,253 @@ def test_lmhead_actorvalueoperator(device): @pytest.mark.skipif(not _has_transformers, reason="missing transformers dependencies") -class TestTransformerActor: +@pytest.mark.skipif(not _has_vllm, reason="missing vllm dependencies") +class TestLLMActor: @pytest.mark.parametrize( - "from_text, generate, tokens, attention_mask", + "from_text, generate, return_log_probs, tokens, attention_mask", [ - (True, True, None, None), - (True, False, None, None), + (True, True, True, None, None), + (True, True, False, None, None), + (True, False, None, None, None), + ( + False, + True, + True, + torch.randint(1024, (1, 10)), + torch.ones(1, 10, dtype=torch.int64), + ), + (False, True, True, torch.randint(1024, (1, 10)), None), ( False, True, + False, torch.randint(1024, (1, 10)), torch.ones(1, 10, dtype=torch.int64), ), - (False, True, torch.randint(1024, (1, 10)), None), + (False, True, False, torch.randint(1024, (1, 10)), None), ], ) - def test_from_hf_transformers(self, from_text, generate, tokens, attention_mask): - from torchrl.data.llm import LLMData + def test_from_hf_transformers( + self, from_text, generate, return_log_probs, tokens, attention_mask + ): from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel + model_name = "distilbert-base-uncased" # or "minilm" or "albert-tiny" + # Load the model and tokenizer + # model = AutoModel.from_pretrained(model_name) + # tokenizer = AutoTokenizer.from_pretrained(model_name) + tokenizer = AutoTokenizer.from_pretrained("gpt2") - tokenizer.pad_token = tokenizer.eos_token model = GPT2LMHeadModel(GPT2Config()) + + tokenizer.pad_token = tokenizer.eos_token tokenizer.padding_side = "left" + m = from_hf_transformers( - model, tokenizer=tokenizer, from_text=from_text, generate=generate + model, + tokenizer=tokenizer, + from_text=from_text, + generate=generate, + return_log_probs=return_log_probs, + ) + self._run_check( + m, + tokens, + attention_mask, + generate, + return_log_probs, + from_text, + has_logits=True, + ) + + @pytest.mark.parametrize( + "from_text, generate, return_log_probs, tokens, attention_mask", + [ + (True, True, True, None, None), + (True, True, False, None, None), + (True, False, None, None, None), + ( + False, + True, + True, + torch.randint(1024, (1, 10)), + torch.ones(1, 10, dtype=torch.int64), + ), + (False, True, True, torch.randint(1024, (1, 10)), None), + ( + False, + True, + False, + torch.randint(1024, (1, 10)), + torch.ones(1, 10, dtype=torch.int64), + ), + (False, True, False, torch.randint(1024, (1, 10)), 
None), + ], + ) + def test_from_vllm( + self, from_text, generate, return_log_probs, tokens, attention_mask + ): + from vllm import LLM + + model = LLM(model="facebook/opt-125m") + m = from_vllm( + model, + from_text=from_text, + generate=generate, + return_log_probs=return_log_probs, + ) + self._run_check( + m, + tokens, + attention_mask, + generate, + return_log_probs, + from_text, + has_logits=False, ) + + def _make_data( + self, + m, + tokens, + attention_mask, + generate, + from_text, + has_logits, + text_response=None, + tokens_response=None, + ): + lp_kwargs = {} if from_text: - tdin = LLMData(text=NonTensorStack("a text"), batch_size=1) + if not generate: + text_response = ( + NonTensorStack(" and another text that follows") + if text_response is None + else text_response + ) + if not isinstance(text_response, NonTensorStack): + if isinstance(text_response, list): + text_response = NonTensorStack(*text_response) + else: + text_response = NonTensorStack(text_response) + lp_kwargs.update({"text_response": text_response}) + tdin = LLMData(text=NonTensorStack("a text"), **lp_kwargs, batch_size=1) else: - tdin = LLMData(tokens=tokens, attention_mask=attention_mask, batch_size=1) + if not generate: + if tokens_response is None: + shape_response = tokens.shape + shape_response = shape_response[:-1] + (shape_response[-1] * 2,) + tokens_response = torch.randint(1024, shape_response) + lp_kwargs.update({"tokens_response": tokens_response}) + tdin = LLMData( + tokens=tokens, attention_mask=attention_mask, **lp_kwargs, batch_size=1 + ) + return tdin + + def _run_check( + self, + m, + tokens, + attention_mask, + generate, + return_log_probs, + from_text, + has_logits, + ): + tdin = self._make_data( + m, tokens, attention_mask, generate, from_text, has_logits + ) + if from_text and generate: + assert tdin.text_response is None + elif from_text and not generate: + assert tdin.text_response is not None + td = m(tdin) assert td is tdin assert isinstance(td, LLMData) if from_text and generate: assert td.text_response is not None + if generate and (attention_mask is not None or from_text): + assert td.attention_mask is not None, (generate, generate, from_text) else: - assert td.text_response is None - if attention_mask is not None or from_text: - assert td.attention_mask is not None - else: - assert td.attention_mask is None + assert td.attention_mask is None, (generate, from_text) if not generate: - assert td.text_response is None - assert td.tokens_response is None + # logprobs are computed on text response of tokens_response + assert td.text_response is not None or td.tokens_response is not None assert td.log_probs is not None - assert td.logits is not None + if has_logits: + assert td.logits is not None + if generate: + if return_log_probs: + assert td.log_probs is not None + assert td.log_probs.shape[-2] == td.tokens_response.shape[-1] + else: + assert td.log_probs is None + + # Test the shapes + assert td.tokens_response is not None, (generate, has_logits, from_text) + + # If from text and not generating, the tokens are not returned for now + if not (from_text and not generate): + assert td.tokens_response.shape[:-1] == td.tokens.shape[:-1] + # The convention is that the response only has new tokens + assert ( + td.tokens_response[..., : td.tokens.shape[-1]] + != td.tokens[..., : td.tokens_response.shape[-1]] + ).any(), (generate, from_text) + + @pytest.mark.parametrize( + "from_text, tokens, attention_mask", + [ + (True, None, None), + ( + False, + torch.randint(1024, (1, 10)), + torch.ones(1, 
10, dtype=torch.int64), + ), + (False, torch.randint(1024, (1, 10)), None), + ], + ) + def test_from_vllm_logprobs(self, from_text, tokens, attention_mask): + from vllm import LLM + + model = LLM(model="facebook/opt-125m") + m_generate = from_vllm( + model, from_text=from_text, generate=True, return_log_probs=True + ) + m_logprobs = from_vllm(model, from_text=from_text, generate=False) + self._check_lps( + m_generate, m_logprobs, tokens, attention_mask, from_text, has_logits=False + ) + + def _check_lps( + self, + model_generate, + model_logprobs, + tokens, + attention_mask, + from_text, + has_logits, + ): + # Checks that the log-probs gathered with generate=False equate those with generate=True + tdin_genetate = self._make_data( + model_generate, tokens, attention_mask, True, from_text, has_logits + ) + td_generate = model_generate(tdin_genetate) + tdin_logprobs = self._make_data( + model_logprobs, + tokens, + attention_mask, + False, + from_text, + has_logits, + tokens_response=td_generate.tokens_response, + text_response=td_generate.text_response, + ) + td_logprobs = model_logprobs(tdin_logprobs) + torch.testing.assert_close( + td_generate.log_probs, td_logprobs.log_probs, rtol=1e-2, atol=1e-2 + ) if __name__ == "__main__": diff --git a/torchrl/envs/custom/llm.py b/torchrl/envs/custom/llm.py index 07a0880ba5b..e3e35219c65 100644 --- a/torchrl/envs/custom/llm.py +++ b/torchrl/envs/custom/llm.py @@ -143,7 +143,11 @@ def __init__( # self.action_key = unravel_key(action_key) if str2str: self.full_observation_spec_unbatched = Composite( - {self.str_key: NonTensor(example_data="a string", batched=True, shape=())} + { + self.str_key: NonTensor( + example_data="a string", batched=True, shape=() + ) + } ) self.full_action_spec_unbatched = Composite( {action_key: NonTensor(example_data="a string", batched=True, shape=())} diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index 0befa7ef138..eb236e56c4b 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -942,7 +942,9 @@ def make_shape(shape): ) if is_tensor_collection(tensor) and not is_non_tensor(tensor) else NonTensor( - shape=tensor.shape, example_data=tensor.data, device=tensor.device + shape=tensor.shape, + example_data=tensor.data, + device=tensor.device, ) if is_non_tensor(tensor) else Unbounded( diff --git a/torchrl/modules/__init__.py b/torchrl/modules/__init__.py index e5b52a8a1f0..b5c5b699815 100644 --- a/torchrl/modules/__init__.py +++ b/torchrl/modules/__init__.py @@ -93,7 +93,7 @@ ) from .utils import get_primers_from_module from .planners import CEMPlanner, MPCPlannerBase, MPPIPlanner # usort:skip -from .llm import from_hf_transformers +from .llm import from_hf_transformers, from_vllm __all__ = [ "Actor", @@ -178,6 +178,7 @@ "WorldModelWrapper", "distributions_maps", "from_hf_transformers", + "from_vllm", "get_primers_from_module", "recurrent_mode", "reset_noise", diff --git a/torchrl/modules/llm/__init__.py b/torchrl/modules/llm/__init__.py index 467ecfd24aa..5d70748aeff 100644 --- a/torchrl/modules/llm/__init__.py +++ b/torchrl/modules/llm/__init__.py @@ -4,5 +4,6 @@ # LICENSE file in the root directory of this source tree. 
from .transformers_policy import from_hf_transformers +from .vllm_policy import from_vllm -__all__ = ["from_hf_transformers"] +__all__ = ["from_hf_transformers", "from_vllm"] diff --git a/torchrl/modules/llm/transformers_policy.py b/torchrl/modules/llm/transformers_policy.py index 7494fe8b10b..c506481746e 100644 --- a/torchrl/modules/llm/transformers_policy.py +++ b/torchrl/modules/llm/transformers_policy.py @@ -2,12 +2,10 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. - -# TODO: lazy imports +from __future__ import annotations import torch -import transformers from tensordict import NestedKey, TensorDictBase from tensordict.nn import ( TensorDictModule as Mod, @@ -17,7 +15,6 @@ ) from tensordict.tensorclass import NonTensorData, NonTensorStack from torchrl.data.llm import LLMData -from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel def _maybe_clear_device(td): @@ -57,16 +54,18 @@ def log_probs_from_scores(td: TensorDictBase) -> TensorDictBase: """ # TODO: how do we avoid getting these? + tokens_out = td["tokens_out", "sequences"] + seq_len = tokens_out.shape[1] + del td["tokens_out", "past_key_values"] scores = dict(td["tokens_out", "scores"].items()) scores = torch.stack( [scores[str(k)] for k in range(len(scores))], 1 ) # shape (B, seq-len, vocab_size) logits = scores - scores.logsumexp(dim=-1, keepdim=True) - td["logits"] = scores + td["logits"] = scores[..., -seq_len:, :] del td["tokens_out", "scores"] - seq_len = scores.shape[1] - tokens = td["tokens_out", "sequences"][..., -seq_len:] # shape (B, seq-len) + tokens = tokens_out[..., -seq_len:] # shape (B, seq-len) log_probs = logits.gather(-1, tokens.unsqueeze(-1)) td["log_probs"] = log_probs return td @@ -105,22 +104,84 @@ def log_probs_from_logits(td: TensorDictBase) -> TensorDictBase: def from_hf_transformers( - model: transformers.modeling_utils.PreTrainedModel, + model: transformers.modeling_utils.PreTrainedModel, # noqa *, generate: bool = True, return_log_probs: bool = True, - tokenizer: transformers.tokenization_utils.PreTrainedTokenizer | None = None, + tokenizer: transformers.tokenization_utils.PreTrainedTokenizer + | None = None, # noqa from_text: bool = False, device: torch.device | None = None, - # Keys: - text_key: NestedKey = "text", - token_key: NestedKey = "tokens", - attention_mask_key: NestedKey = "attention_mask", kwargs: dict | None = None, tokenizer_kwargs: dict | None = None, ) -> TensorDictModuleBase: + """Creates a TensorDictModule from a Hugging Face Transformers model. + + This allows for a consistent interface across various LLM engines. + This function facilitates text generation and log probability computation. + + Args: + model (transformers.modeling_utils.PreTrainedModel): The Hugging Face model to wrap. + generate (bool, optional): Whether to generate text. Defaults to `True`. + return_log_probs (bool, optional): Whether to return log probabilities. Defaults to `True`. + tokenizer (transformers.tokenization_utils.PreTrainedTokenizer, optional): The tokenizer to use. Defaults to `None`. + from_text (bool, optional): Whether the input is text. Defaults to `False`. + device (torch.device, optional): The device to use for computation. Defaults to `None`. + kwargs (dict, optional): Additional arguments for the model's generate method. Defaults to `None`. + tokenizer_kwargs (dict, optional): Additional arguments for the tokenizer. Defaults to `None`. 
+ + Returns: + TensorDictModuleBase: A configured TensorDictModule for the specified model. + + Input Keys: + + - If `from_text` is `True`: + + - "text": The input text to be tokenized. + + - If `from_text` is `False`: + + - "tokens": The input token sequences. + - "attention_mask": The attention mask for the tokens. + + Output Keys: + + - "tokens_response": The generated token sequences. + - "log_probs": The log probabilities of the generated tokens (if `return_log_probs` is `True`). + - "logits": The logits of the generated tokens (if applicable). + - "text_response": The generated text (if `from_text` is `True` and `generate` is `True`). + + Example: + >>> from tensordict.tensorclass import NonTensorStack + >>> from transformers import AutoTokenizer, GPT2LMHeadModel, GPT2Config + >>> + >>> from torchrl.data import LLMData + >>> from torchrl.modules import from_hf_transformers + >>> + >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") + >>> model = GPT2LMHeadModel(GPT2Config()) + >>> tokenizer.pad_token = tokenizer.eos_token + >>> + >>> module = from_hf_transformers( + ... model, + ... tokenizer=tokenizer, + ... from_text=True, + ... generate=True + ... ) + >>> input_data = LLMData(text=NonTensorStack("Hello, world!"), batch_size=1) + >>> output_data = module(input_data) + >>> print(output_data.text_response) + [' heritageillon rooft rooft Pear Tes grantingalde 58ocrocrocrocrcubecubecubecubecubecubecube'] + + .. seealso:: :func:`~torchrl.modules.from_vllm` for a similar interface using the vLLM library. + + """ # TODO: Seq should have a return_log_prob and be of ProbabilisticTDSequential type for instance checks + # Keys: + text_key: NestedKey = "text" + token_key: NestedKey = "tokens" + attention_mask_key: NestedKey = "attention_mask" module_dict = {} if device: @@ -138,13 +199,59 @@ def from_hf_transformers( if tokenizer_kwargs.setdefault("padding_side", "left") != "left": raise RuntimeError - module_dict["encode"] = Mod( - tokenizer, - in_keys=[text_key], - out_keys=["tokens_in"], - method_kwargs=tokenizer_kwargs, - strict=True, - # We don't need the text after this + if generate: + module_dict["encode"] = Mod( + tokenizer, + in_keys=[text_key], + out_keys=["tokens_in"], + method_kwargs=tokenizer_kwargs, + strict=True, + # We don't need the text after this + inplace=False, + ) + else: + module_dict["encode"] = Mod( + # TODO: make this work with many strings + # Tokenize both strings, and only the first + lambda x, y: ( + tokenizer([_x + _y for _x, _y in zip(x, y)], **tokenizer_kwargs), + tokenizer(x, **tokenizer_kwargs), + ), + in_keys=[text_key, "text_response"], + out_keys=["tokens_in", "tokens_response"], + strict=True, + inplace=False, + ) + + def select(x, y): + return x.apply(lambda _x, _y: _x[..., _y.shape[-1] :], y) + + module_dict["stack_response"] = Mod( + # Remove the init from the total tokens to get only the response tokens + select, + in_keys=["tokens_in", "tokens_response"], + out_keys=["tokens_response"], + strict=True, + ) + elif not generate: + + def stack_for_logprobs(tokens, tokens_response, attention_mask=None): + tokens = torch.cat([tokens, tokens_response], -1) + if attention_mask is not None: + attention_mask = torch.cat( + [attention_mask, attention_mask.new_ones(tokens_response.shape)], -1 + ) + return tokens, tokens_response, attention_mask + + module_dict["stack_response"] = Mod( + stack_for_logprobs, + in_keys=["tokens", "tokens_response", "attention_mask"], + out_keys=[ + ("tokens_in", "input_ids"), + ("tokens_response", "input_ids"), + ("tokens_in", 
"attention_mask"), + ], + strict=False, inplace=False, ) else: @@ -191,6 +298,17 @@ def from_hf_transformers( out_to_in_map=True, strict=True, ) + + # Keep only the new tokens + def remove_input_seq(tokens_in, tokens_out): + return tokens_out[..., tokens_in.shape[-1] :] + + module_dict["remove_input_seq"] = Mod( + remove_input_seq, + in_keys=[("tokens_in", "input_ids"), ("tokens_out", "sequences")], + out_keys=[("tokens_out", "sequences")], + strict=True, + ) if return_log_probs: module_dict["extract_log_probs"] = WrapModule( log_probs_from_scores, @@ -206,6 +324,7 @@ def from_hf_transformers( ) if device: module_dict["to_source_device"] = _maybe_set_device + module_dict["rebuild"] = Mod( lambda *x: x, in_keys=[ @@ -224,7 +343,8 @@ def from_hf_transformers( "log_probs", "logits", ], - strict=True, + # There may not be log_probs and logits + strict=False, inplace=False, ) else: @@ -241,8 +361,10 @@ def from_hf_transformers( kwargs = {} if not kwargs.setdefault("return_dict", True): raise RuntimeError - if not return_log_probs: - raise RuntimeError + if return_log_probs not in (True, None): + raise RuntimeError( + "return_log_probs should be True or None when not generating." + ) module_dict["get_logprobs"] = Mod( model, method_kwargs=kwargs, @@ -264,8 +386,8 @@ def from_hf_transformers( if from_text: module_dict["rebuild"] = Mod( lambda *x: x, - in_keys=["log_probs", "logits", ("tokens_in", "attention_mask")], - out_keys=["log_probs", "logits", "attention_mask"], + in_keys=["log_probs", "logits", "tokens_response"], + out_keys=["log_probs", "logits", "tokens_response"], inplace=False, ) else: @@ -280,6 +402,9 @@ def from_hf_transformers( if __name__ == "__main__": + import transformers + from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel + max_seq_length = 50000 tokenizer = AutoTokenizer.from_pretrained("gpt2") diff --git a/torchrl/modules/llm/vllm_policy.py b/torchrl/modules/llm/vllm_policy.py new file mode 100644 index 00000000000..daab91c76d0 --- /dev/null +++ b/torchrl/modules/llm/vllm_policy.py @@ -0,0 +1,432 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+from __future__ import annotations + +import collections +import importlib.util + +import torch +from tensordict import ( + from_dataclass, + maybe_dense_stack, + NestedKey, + NonTensorData, + NonTensorStack, + TensorClass, +) +from tensordict.nn import ( + TensorDictModule as Mod, + TensorDictModuleBase, + TensorDictSequential as Seq, +) +from tensordict.utils import _zip_strict + +from torchrl.data import LLMData + +_has_vllm = importlib.util.find_spec("vllm") + +if _has_vllm: + import vllm + + CompletionOutput_tc = from_dataclass(vllm.outputs.CompletionOutput) +else: + CompletionOutput_tc = None + + +def _maybe_clear_device(td): + if td.device is None: + return td + return td.set(NonTensorData("_source_device"), td.device).clear_device_() + + +def _maybe_set_device(td): + device = td.pop("_source_device", None) + if device is None: + return td + elif isinstance(device, NonTensorData): + device: torch.device = device.data + return td.to(device) + + +def from_vllm( + model: vllm.LLM, # noqa + *, + return_log_probs: bool = False, + tokenizer: transformers.tokenization_utils.PreTrainedTokenizer # noqa + | None = None, # noqa + from_text: bool = False, + device: torch.device | None = None, + generate: bool = True, + generate_kwargs: dict | None = None, + tokenizer_kwargs: dict | None = None, +) -> TensorDictModuleBase: + """Creates a TensorDictModule from a vLLM model. + + This function provides a consistent interface across various LLM engines. + + It supports text generation and log probability computation, similar to the Hugging Face Transformers interface. + + Args: + model (LLM): The vLLM model to wrap. + return_log_probs (bool, optional): Whether to return log probabilities. Defaults to `False`. + tokenizer (transformers.tokenization_utils.PreTrainedTokenizer, optional): The tokenizer to use. Defaults to `None`. + from_text (bool, optional): Whether the input is text. Defaults to `False`. + device (torch.device, optional): The device to use for computation. Defaults to `None`. + generate (bool, optional): Whether to generate text. Defaults to `True`. + generate_kwargs (dict, optional): Additional arguments for the model's generate method. Defaults to `None`. + tokenizer_kwargs (dict, optional): Additional arguments for the tokenizer. Defaults to `None`. + + Returns: + TensorDictModuleBase: A configured TensorDictModule for the specified model. + + Input Keys: + + - If `from_text` is `True`: + + - "text": The input text to be tokenized. + + - If `from_text` is False: + + - "tokens": The input token sequences. + - "attention_mask": The attention mask for the tokens. + + Output Keys: + + - "tokens_response": The generated token sequences. + - "log_probs": The log probabilities of the generated tokens (if `return_log_probs` is True). + - "text_response": The generated text (if `from_text` is True and `generate` is True). + + Example: + >>> from vllm import LLM + >>> from transformers import AutoTokenizer + >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") + >>> model = LLM(model="facebook/opt-125m") + >>> module = from_vllm( + ... model, + ... tokenizer=tokenizer, + ... from_text=True, + ... generate=True + ... ) + >>> input_data = LLMData(text=NonTensorStack("Hello, world!"), batch_size=1) + >>> output_data = module(input_data) + >>> print(output_data.text_response) + + .. seealso:: :func:`~torchrl.modules.from_hf_transformers` for a similar interface using the Hugging Face + Transformers library. 
+ + """ + try: + from vllm import SamplingParams + except ImportError: + raise ImportError("Please install `vllm` to use `from_vllm`.") + + text_key: NestedKey = ("text",) + token_key: NestedKey = ("tokens",) + attention_mask_key: NestedKey = ("attention_mask",) + + # TODO: Seq should have a return_log_prob and be of ProbabilisticTDSequential type for instance checks + if tokenizer is None: + tokenizer = model.get_tokenizer() + module_dict = {} + if device: + module_dict["clear_device"] = _maybe_clear_device + if from_text: + if not tokenizer_kwargs: + tokenizer_kwargs = {} + if not tokenizer_kwargs.setdefault("return_attention_mask", True): + raise RuntimeError + if tokenizer_kwargs.setdefault("return_tensors", "pt") != "pt": + raise RuntimeError + if tokenizer_kwargs.setdefault("padding", True) not in (True,): + raise RuntimeError + if tokenizer_kwargs.setdefault("padding_side", "left") != "left": + raise RuntimeError + + if generate: + module_dict["encode"] = Mod( + tokenizer, + in_keys=[text_key], + out_keys=["tokens_in"], + method_kwargs=tokenizer_kwargs, + strict=True, + inplace=False, + ) + else: + module_dict["encode"] = Mod( + # TODO: make this work with many strings + # Tokenize both strings, and only the first + lambda x, y: ( + tokenizer([_x + _y for _x, _y in zip(x, y)], **tokenizer_kwargs), + tokenizer(x, **tokenizer_kwargs), + ), + in_keys=[text_key, "text_response"], + out_keys=["tokens_in", "tokens_response"], + strict=True, + inplace=False, + ) + + def select(x, y): + return x.apply(lambda _x, _y: _x[..., _y.shape[-1] :], y) + + module_dict["stack_response"] = Mod( + # Remove the init from the total tokens to get only the response tokens + select, + in_keys=["tokens_in", "tokens_response"], + out_keys=["tokens_response"], + strict=True, + ) + elif not generate: + + def stack_for_logprobs(tokens, tokens_response, attention_mask=None): + tokens = torch.cat([tokens, tokens_response], -1) + if attention_mask is not None: + attention_mask = torch.cat( + [attention_mask, attention_mask.new_ones(tokens_response.shape)], -1 + ) + return tokens, tokens_response, attention_mask + + module_dict["stack_response"] = Mod( + stack_for_logprobs, + in_keys=["tokens", "tokens_response", "attention_mask"], + out_keys=[ + ("tokens_in", "input_ids"), + ("tokens_response", "input_ids"), + ("tokens_in", "attention_mask"), + ], + strict=False, + inplace=False, + ) + else: + module_dict["move_inputs"] = Mod( + lambda *x: x, + in_keys=["tokens", "attention_mask"], + out_keys=[("tokens_in", "input_ids"), ("tokens_in", "attention_mask")], + # It's ok if there's no mask + strict=False, + inplace=False, + ) + + def to_list(tokens, attention_mask): + """Converts a tensor of integer in a masked list (of lists) of integers.""" + if isinstance(tokens, torch.Tensor): + # TODO: make this an ND NonTensorStack + parent = [] + queue = collections.deque() + if attention_mask is None: + attention_mask = torch.ones_like(tokens) + queue.append((tokens, attention_mask.bool(), parent)) + while queue: + token, amask, _parent = queue.popleft() + if token.ndim == 1: + _parent.extend(token[amask].tolist()) + else: + _parent.extend([[] for _ in range(token.shape[0])]) + queue.extend( + [ + (t, m, local_parent) + for t, m, local_parent in zip(token, amask, _parent) + ] + ) + tokens = parent + return NonTensorStack(*tokens) + + module_dict["to_list"] = Mod( + to_list, + in_keys=[("tokens_in", "input_ids"), ("tokens_in", "attention_mask")], + out_keys=[("tokens_in", "input_ids_list")], + strict=False, + ) + + if 
generate_kwargs is None: + generate_kwargs = { + "detokenize": False, + "prompt_logprobs": not generate, + "logprobs": return_log_probs, + } + if not generate: + generate_kwargs["max_tokens"] = 1 + sampling_params = SamplingParams(**generate_kwargs) + + module_dict["generate"] = Mod( + model, + method="generate", + method_kwargs={"sampling_params": sampling_params}, + in_keys={ + "prompt_token_ids": ("tokens_in", "input_ids_list"), + }, + out_keys=["tokens_out"], + out_to_in_map=True, + strict=True, + ) + + def get_output_tokens_and_log_probs(td): + td["tokens_out"] = _RequestOutput_tc.from_request_output(td["tokens_out"]) + if generate: + # When not generate, we don't want to overwrite this + td["tokens_response"] = td["tokens_out"].outputs.token_ids + if return_log_probs: + td["log_probs"] = td["tokens_out"].outputs.logprobs.unsqueeze(-1) + elif not generate: + td["prompt_logprobs"] = td["tokens_out"].prompt_logprobs.unsqueeze(-1) + return td + + module_dict["get_output_tokens_and_log_probs"] = get_output_tokens_and_log_probs + + if not generate: + + def translate_lps(tokens_response, x): + # we disregard the tokens from the prompt to focus on those of the response + return x[..., -tokens_response.shape[-1] :, :] + + module_dict["translate_lps"] = Mod( + translate_lps, + in_keys=[("tokens_response", "input_ids"), "prompt_logprobs"], + out_keys=["log_probs"], + ) + elif from_text: + module_dict["decode"] = Mod( + tokenizer.batch_decode, + in_keys=["tokens_response"], + out_keys=["text_response"], + ) + + if device: + module_dict["to_source_device"] = _maybe_set_device + + if generate: + module_dict["format"] = Mod( + lambda *x: x, + in_keys=[ + "log_probs", + "tokens_response", + ("tokens_in", "input_ids"), + ("tokens_in", "attention_mask"), + "text_response", + ], + out_keys=[ + "log_probs", + "tokens_response", + token_key, + attention_mask_key, + "text_response", + ], + strict=False, + inplace=False, + ) + else: + module_dict["format"] = Mod( + lambda *x: x, + in_keys=["log_probs", "tokens_response"], + out_keys=["log_probs", "tokens_response"], + strict=False, + inplace=False, + ) + + return Seq(module_dict, inplace=True) + + +class _RequestOutput_tc(TensorClass["nocast"]): + request_id: str + prompt: str + prompt_token_ids: str + prompt_logprobs: str + outputs: str + finished: str + metrics: str + lora_request: str + encoder_prompt: str + encoder_prompt_token_ids: str + num_cached_tokens: str + + def __post_init__(self): + def postproc(output): + def get_logprob(output): + t = [] + for v, tid in zip(output.logprobs, output.token_ids): + t.append( + v[tid]["logprob"] if v[tid].get("logprob") is not None else 0.0 + ) + return torch.tensor(t) + + if output.logprobs: + output.logprobs = get_logprob(output) + output.token_ids = torch.tensor(output.token_ids) + return output + + if isinstance(self.outputs, list): + outputs = self.outputs + outputs = [ + postproc(from_dataclass(output, dest_cls=CompletionOutput_tc)) + for output in outputs + ] + if len(outputs) == 1: + self.outputs = outputs[0] + else: + self.outputs = torch.stack(outputs) + self.prompt_logprobs = torch.tensor( + [ + v[tid].logprob if v is not None else 0.0 + for v, tid in _zip_strict( + self.prompt_logprobs, self.prompt_token_ids + ) + ] + ) + self.prompt_token_ids = torch.tensor(self.prompt_token_ids) + self.num_cached_tokens = torch.tensor(self.num_cached_tokens) + + @classmethod + def from_request_output(cls, requests): + out = maybe_dense_stack( + [ + cls( + request_id=request.request_id, + prompt=request.prompt, + 
prompt_token_ids=request.prompt_token_ids, + prompt_logprobs=request.prompt_logprobs, + outputs=request.outputs, + finished=request.finished, + metrics=request.metrics, + lora_request=request.lora_request, + encoder_prompt=request.encoder_prompt, + encoder_prompt_token_ids=request.encoder_prompt_token_ids, + num_cached_tokens=request.num_cached_tokens, + ) + for request in requests + ] + ) + return out + + +if __name__ == "__main__": + from vllm import LLM, SamplingParams + + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + llm = LLM(model="facebook/opt-125m") + outputs = llm.generate(prompts, sampling_params) + m = from_vllm(llm, from_text=True) + + td = m(LLMData(text=NonTensorStack("a text"), batch_size=1)) + + td = m(LLMData(text=NonTensorData("a text"), batch_size=())) + + td = m(LLMData(text=NonTensorStack("a text"), batch_size=1)) + m = from_vllm(llm, from_text=True, generate=False) + assert td.copy().text == ["a text"] + td_lp = LLMData( + text=NonTensorStack("a text"), + text_response=NonTensorStack(*td.text_response), + batch_size=(1,), + ) + td_lp = m(td_lp) + # torch.testing.assert_close(td.log_probs, td_lp.log_probs) + + m = from_vllm(llm, from_text=True, generate=True) + td = m(LLMData(text=NonTensorStack("a text", "another text here"), batch_size=2)) diff --git a/torchrl/modules/tensordict_module/common.py b/torchrl/modules/tensordict_module/common.py index 2bd09e81e81..81d96fb7cec 100644 --- a/torchrl/modules/tensordict_module/common.py +++ b/torchrl/modules/tensordict_module/common.py @@ -436,7 +436,7 @@ class VmapModule(TensorDictModuleBase): >>> assert (sample_in_td["x"][:, 0] == sample_in_td["y"]).all() """ - def __init__(self, module: TensorDictModuleBase, vmap_dim=None): + def __init__(self, module: TensorDictModuleBase, vmap_dim=None, mock: bool = False): if not _has_functorch: raise ImportError("VmapModule requires torch>=2.0.") super().__init__() @@ -444,6 +444,7 @@ def __init__(self, module: TensorDictModuleBase, vmap_dim=None): self.out_keys = module.out_keys self.module = module self.vmap_dim = vmap_dim + self.mock = mock if torch.__version__ >= "2.0": self._vmap = torch.vmap else: @@ -451,6 +452,9 @@ def __init__(self, module: TensorDictModuleBase, vmap_dim=None): self._vmap = functorch.vmap + def mock_(self, value: bool = True): + self.mock = value + def forward(self, tensordict): # TODO: there is a risk of segfault if input is not a tensordict. # We should investigate (possibly prevent it c++ side?) @@ -458,7 +462,12 @@ def forward(self, tensordict): if vmap_dim is None: ndim = tensordict.ndim vmap_dim = ndim - 1 - td = self._vmap(self.module, (vmap_dim,), (vmap_dim,))(tensordict) + if self.mock: + td = torch.stack( + [self.module(_td) for _td in tensordict.unbind(vmap_dim)], vmap_dim + ) + else: + td = self._vmap(self.module, (vmap_dim,), (vmap_dim,))(tensordict) return tensordict.update(td)
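
Usage sketch (not part of the patch): the snippet below mirrors the `__main__` block of the new `torchrl/modules/llm/vllm_policy.py` and the `test_from_vllm_logprobs` test, showing how the same vLLM engine is wrapped once for generation and once for log-prob scoring. It assumes the optional `vllm` dependency and the `facebook/opt-125m` checkpoint used in the tests.

```python
# Sketch only: mirrors the patch's tests and __main__ block; requires `pip install vllm`.
from tensordict.tensorclass import NonTensorStack
from vllm import LLM

from torchrl.data.llm import LLMData
from torchrl.modules import from_vllm

llm = LLM(model="facebook/opt-125m")

# Generation: text in -> tokens_response / text_response / log_probs out.
policy = from_vllm(llm, from_text=True, generate=True, return_log_probs=True)
td = policy(LLMData(text=NonTensorStack("a text"), batch_size=1))
print(td.text_response)

# Scoring: generate=False re-evaluates the log-probs of a given response.
scorer = from_vllm(llm, from_text=True, generate=False)
td_lp = scorer(
    LLMData(
        text=NonTensorStack("a text"),
        text_response=NonTensorStack(*td.text_response),
        batch_size=(1,),
    )
)
print(td_lp.log_probs.shape)
```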
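The patch extends `from_hf_transformers` with the same `generate=False` scoring path; a minimal sketch following `TestLLMActor._make_data`, reusing the randomly initialised GPT-2 model and the example response string from the tests:

```python
# Sketch only: exercises the new generate=False branch of from_hf_transformers.
from tensordict.tensorclass import NonTensorStack
from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel

from torchrl.data.llm import LLMData
from torchrl.modules import from_hf_transformers

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
model = GPT2LMHeadModel(GPT2Config())  # randomly initialised, as in the tests

scorer = from_hf_transformers(
    model, tokenizer=tokenizer, from_text=True, generate=False
)
td = scorer(
    LLMData(
        text=NonTensorStack("a text"),
        text_response=NonTensorStack(" and another text that follows"),
        batch_size=1,
    )
)
# Log-probs are computed over the response tokens; logits are also exposed here.
print(td.log_probs.shape, td.logits.shape)
```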
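The new `mock` flag on `VmapModule` (used by `examples/rlhf/models/actor_critic.py` as `VmapModule(critic, mock=True)`) replaces `torch.vmap` with a plain unbind/stack loop, which helps when the wrapped module is not vmap-compatible. A minimal sketch, where the inner linear module and shapes are illustrative assumptions rather than anything from the patch:

```python
# Sketch only: the inner module and shapes are illustrative, not from the patch.
import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from torchrl.modules.tensordict_module.common import VmapModule

inner = TensorDictModule(torch.nn.Linear(3, 4), in_keys=["x"], out_keys=["y"])
vmapped = VmapModule(inner, mock=True)  # loop over the vmap dim instead of calling torch.vmap

td = TensorDict({"x": torch.randn(2, 5, 3)}, batch_size=[2, 5])
out = vmapped(td)  # equivalent to stacking inner(_td) for _td in td.unbind(1)
assert out["y"].shape == (2, 5, 4)
```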