From dbb9af1edbdc50d717774a2de415c6bc3f0d96e8 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 5 Mar 2025 11:15:21 -0800 Subject: [PATCH 1/7] Update [ghstack-poisoned] --- torchrl/modules/llm/vllm.py | 129 ++++++++++++++++++++++++++++++++++++ 1 file changed, 129 insertions(+) create mode 100644 torchrl/modules/llm/vllm.py diff --git a/torchrl/modules/llm/vllm.py b/torchrl/modules/llm/vllm.py new file mode 100644 index 00000000000..4c400666c32 --- /dev/null +++ b/torchrl/modules/llm/vllm.py @@ -0,0 +1,129 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +import transformers +from tensordict import NestedKey, NonTensorData, NonTensorStack, TensorDict +from tensordict.nn import ( + TensorDictModule as Mod, + TensorDictModuleBase, + TensorDictSequential as Seq, +) +from transformers import AutoTokenizer +from vllm import LLM, SamplingParams + + +def _maybe_clear_device(td): + if td.device is None: + return td + return td.set(NonTensorData("_source_device"), td.device).clear_device_() + + +def _maybe_set_device(td): + device = td.pop("_source_device", None) + if device is None: + return td + elif isinstance(device, NonTensorData): + device: torch.device = device.data + return td.to(device) + + +def from_vllm( + model: LLM, + return_log_probs: bool = False, + tokenizer: transformers.tokenization_utils.PreTrainedTokenizer | None = None, + from_text: bool = False, + device: torch.device | None = None, + text_key: NestedKey = "text", + generate_kwargs: dict | None = None, + tokenizer_kwargs: dict | None = None, +) -> TensorDictModuleBase: + # TODO: Seq should have a return_log_prob and be of ProbabilisticTDSequential type for instance checks + module_dict = {} + if device: + module_dict["clear_device"] = _maybe_clear_device + if from_text: + if not tokenizer_kwargs: + tokenizer_kwargs = {} + if not tokenizer_kwargs.setdefault("return_attention_mask", True): + raise RuntimeError + if tokenizer_kwargs.setdefault("return_tensors", "pt") != "pt": + raise RuntimeError + if tokenizer_kwargs.setdefault("padding", True) not in (True,): + raise RuntimeError + if tokenizer_kwargs.setdefault("padding_side", "left") != "left": + raise RuntimeError + module_dict["encode"] = Mod( + tokenizer, + in_keys=[text_key], + out_keys=["tokens_in"], # method_kwargs=tokenizer_kwargs, + strict=True, + ) + + # FIXME: this is not great! 
+ def f(td): + td["tokens_in", "input_ids"] = NonTensorStack( + *td["tokens_in", "input_ids"].tolist() + ) + print("td['tokens_in', 'input_ids']", td["tokens_in", "input_ids"]) + return td + + module_dict["to_list"] = f + + if generate_kwargs is None: + generate_kwargs = { + "detokenize": False, + "prompt_logprobs": return_log_probs, + "logprobs": return_log_probs, + } + sampling_params = SamplingParams(**generate_kwargs) + + module_dict["generate"] = Mod( + model, + method="generate", + method_kwargs={"sampling_params": sampling_params}, + in_keys={ + "prompt_token_ids": ("tokens_in", "input_ids"), + # "attention_mask": ("tokens_in", "attention_mask"), + }, + out_keys=["tokens_out"], + out_to_in_map=True, + strict=True, + ) + + def get_output_tokens_and_log_probs(td): + # FIXME: shouldn't have to be doing 0 index here to make sure this works with batches + td["output_tokens"] = td["tokens_out"][0].outputs[0].token_ids + # FIXME: this is not in a tensor form yet but uses their own LogProb object + td["log_probs"] = td["tokens_out"][0].outputs[0].logprobs + return td + + module_dict["get_output_tokens_and_log_probs"] = get_output_tokens_and_log_probs + + # module_dict["extract_log_probs"] = WrapModule(log_probs_from_logits, in_keys=[("tokens_in", "sequences"), ("tokens_in", "scores")], out_keys=["logits", "log_probs"]) + if from_text: + module_dict["decode"] = Mod( + tokenizer.batch_decode, + in_keys=["output_tokens"], # in_keys=["tokens_out", "sequences"], + out_keys=["action"], # strict=True, + ) + if device: + module_dict["to_source_device"] = _maybe_set_device + + return Seq(module_dict) + + +if __name__ == "__main__": + max_seq_length = 50000 + model_name = "Qwen/Qwen2.5-7B-Instruct" + model = LLM(model_name, skip_tokenizer_init=True, device="cuda:0") + model.llm_engine.model_executor.driver_worker.worker.model_runner.model.sampler.include_gpu_probs_tensor = ( + True + ) + tokenizer = AutoTokenizer.from_pretrained(model_name, device="cuda:0") + # tokenizer.padding_side = "left" + m = from_vllm(model, tokenizer=tokenizer, from_text=True, device="cuda:0") + print(m(TensorDict(text="a text is a text"))) From bef7a20d63b8cc1cf7297327b610eb0b4dc3d1a7 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 11 Mar 2025 10:32:44 +0000 Subject: [PATCH 2/7] Update [ghstack-poisoned] --- .../environment.yml | 3 +- .../{scripts_rlhf => scripts_llm}/install.sh | 35 ++-- .../post_process.sh | 0 .../run-clang-format.py | 0 .../{scripts_rlhf => scripts_llm}/run_test.sh | 2 + .../setup_env.sh | 0 ...test-linux-rlhf.yml => test-linux-llm.yml} | 10 +- test/test_actors.py | 112 +++++++++--- torchrl/envs/custom/llm.py | 6 +- torchrl/modules/llm/transformers_policy.py | 163 ++++++++++++++++-- torchrl/modules/llm/vllm_policy.py | 90 ++++++++-- 11 files changed, 340 insertions(+), 81 deletions(-) rename .github/unittest/linux_libs/{scripts_rlhf => scripts_llm}/environment.yml (91%) rename .github/unittest/linux_libs/{scripts_rlhf => scripts_llm}/install.sh (59%) rename .github/unittest/linux_libs/{scripts_rlhf => scripts_llm}/post_process.sh (100%) rename .github/unittest/linux_libs/{scripts_rlhf => scripts_llm}/run-clang-format.py (100%) rename .github/unittest/linux_libs/{scripts_rlhf => scripts_llm}/run_test.sh (86%) rename .github/unittest/linux_libs/{scripts_rlhf => scripts_llm}/setup_env.sh (100%) rename .github/workflows/{test-linux-rlhf.yml => test-linux-llm.yml} (83%) diff --git a/.github/unittest/linux_libs/scripts_rlhf/environment.yml b/.github/unittest/linux_libs/scripts_llm/environment.yml 
similarity index 91% rename from .github/unittest/linux_libs/scripts_rlhf/environment.yml rename to .github/unittest/linux_libs/scripts_llm/environment.yml index 6b8800a2531..b0897796779 100644 --- a/.github/unittest/linux_libs/scripts_rlhf/environment.yml +++ b/.github/unittest/linux_libs/scripts_llm/environment.yml @@ -17,5 +17,6 @@ dependencies: - pyyaml - scipy - hydra-core - - transformers<4.42.0 + - transformers - datasets + - vllm diff --git a/.github/unittest/linux_libs/scripts_rlhf/install.sh b/.github/unittest/linux_libs/scripts_llm/install.sh similarity index 59% rename from .github/unittest/linux_libs/scripts_rlhf/install.sh rename to .github/unittest/linux_libs/scripts_llm/install.sh index 1e927e08df6..68e63bf58ca 100755 --- a/.github/unittest/linux_libs/scripts_rlhf/install.sh +++ b/.github/unittest/linux_libs/scripts_llm/install.sh @@ -26,23 +26,24 @@ fi # submodules git submodule sync && git submodule update --init --recursive -printf "Installing PyTorch with cu128" -if [[ "$TORCH_VERSION" == "nightly" ]]; then - if [ "${CU_VERSION:-}" == cpu ] ; then - pip3 install --pre torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/nightly/cpu -U - else - pip3 install --pre torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/nightly/cu128 -U - fi -elif [[ "$TORCH_VERSION" == "stable" ]]; then - if [ "${CU_VERSION:-}" == cpu ] ; then - pip3 install torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/cpu - else - pip3 install torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/cu128 - fi -else - printf "Failed to install pytorch" - exit 1 -fi +# We skip pytorch install due to vllm requirements +#printf "Installing PyTorch with cu128" +#if [[ "$TORCH_VERSION" == "nightly" ]]; then +# if [ "${CU_VERSION:-}" == cpu ] ; then +# pip3 install --pre torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/nightly/cpu -U +# else +# pip3 install --pre torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/nightly/cu128 -U +# fi +#elif [[ "$TORCH_VERSION" == "stable" ]]; then +# if [ "${CU_VERSION:-}" == cpu ] ; then +# pip3 install torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/cpu +# else +# pip3 install torch "numpy<2.0.0" --index-url https://download.pytorch.org/whl/cu128 +# fi +#else +# printf "Failed to install pytorch" +# exit 1 +#fi # install tensordict if [[ "$RELEASE" == 0 ]]; then diff --git a/.github/unittest/linux_libs/scripts_rlhf/post_process.sh b/.github/unittest/linux_libs/scripts_llm/post_process.sh similarity index 100% rename from .github/unittest/linux_libs/scripts_rlhf/post_process.sh rename to .github/unittest/linux_libs/scripts_llm/post_process.sh diff --git a/.github/unittest/linux_libs/scripts_rlhf/run-clang-format.py b/.github/unittest/linux_libs/scripts_llm/run-clang-format.py similarity index 100% rename from .github/unittest/linux_libs/scripts_rlhf/run-clang-format.py rename to .github/unittest/linux_libs/scripts_llm/run-clang-format.py diff --git a/.github/unittest/linux_libs/scripts_rlhf/run_test.sh b/.github/unittest/linux_libs/scripts_llm/run_test.sh similarity index 86% rename from .github/unittest/linux_libs/scripts_rlhf/run_test.sh rename to .github/unittest/linux_libs/scripts_llm/run_test.sh index dcfc686ade0..a563e2e329b 100755 --- a/.github/unittest/linux_libs/scripts_rlhf/run_test.sh +++ b/.github/unittest/linux_libs/scripts_llm/run_test.sh @@ -24,6 +24,8 @@ python -c "import transformers, datasets" python .github/unittest/helpers/coverage_run_parallel.py -m pytest 
test/test_rlhf.py --instafail -v --durations 200 --capture no --error-for-skips +python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_actors.py -k llm --instafail -v --durations 200 --capture no --error-for-skips --runslow + python .github/unittest/helpers/coverage_run_parallel.py examples/rlhf/train_rlhf.py \ sys.device=cuda:0 sys.ref_device=cuda:0 \ model.name_or_path=gpt2 train.max_epochs=2 \ diff --git a/.github/unittest/linux_libs/scripts_rlhf/setup_env.sh b/.github/unittest/linux_libs/scripts_llm/setup_env.sh similarity index 100% rename from .github/unittest/linux_libs/scripts_rlhf/setup_env.sh rename to .github/unittest/linux_libs/scripts_llm/setup_env.sh diff --git a/.github/workflows/test-linux-rlhf.yml b/.github/workflows/test-linux-llm.yml similarity index 83% rename from .github/workflows/test-linux-rlhf.yml rename to .github/workflows/test-linux-llm.yml index 994667b8a66..4de8b8165d9 100644 --- a/.github/workflows/test-linux-rlhf.yml +++ b/.github/workflows/test-linux-llm.yml @@ -1,4 +1,4 @@ -name: RLHF Tests on Linux +name: LLM Tests on Linux on: pull_request: @@ -50,7 +50,7 @@ jobs: export TF_CPP_MIN_LOG_LEVEL=0 export TD_GET_DEFAULTS_TO_NONE=1 - bash .github/unittest/linux_libs/scripts_rlhf/setup_env.sh - bash .github/unittest/linux_libs/scripts_rlhf/install.sh - bash .github/unittest/linux_libs/scripts_rlhf/run_test.sh - bash .github/unittest/linux_libs/scripts_rlhf/post_process.sh + bash .github/unittest/linux_libs/scripts_llm/setup_env.sh + bash .github/unittest/linux_libs/scripts_llm/install.sh + bash .github/unittest/linux_libs/scripts_llm/run_test.sh + bash .github/unittest/linux_libs/scripts_llm/post_process.sh diff --git a/test/test_actors.py b/test/test_actors.py index 8f03ca1d56a..bd0e3b5ed6f 100644 --- a/test/test_actors.py +++ b/test/test_actors.py @@ -919,54 +919,108 @@ def test_lmhead_actorvalueoperator(device): @pytest.mark.skipif(not _has_transformers, reason="missing transformers dependencies") @pytest.mark.skipif(not _has_vllm, reason="missing vllm dependencies") -class TestTransformerActor: +class TestLLMActor: @pytest.mark.parametrize( - "from_text, generate, tokens, attention_mask", + "from_text, generate, return_log_probs, tokens, attention_mask", [ - (True, True, None, None), - (True, False, None, None), + (True, True, True, None, None), + (True, True, False, None, None), + (True, False, None, None, None), ( False, True, + True, torch.randint(1024, (1, 10)), torch.ones(1, 10, dtype=torch.int64), ), - (False, True, torch.randint(1024, (1, 10)), None), + (False, True, True, torch.randint(1024, (1, 10)), None), + ( + False, + True, + False, + torch.randint(1024, (1, 10)), + torch.ones(1, 10, dtype=torch.int64), + ), + (False, True, False, torch.randint(1024, (1, 10)), None), ], ) - def test_from_hf_transformers(self, from_text, generate, tokens, attention_mask): + def test_from_hf_transformers( + self, from_text, generate, return_log_probs, tokens, attention_mask + ): from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel + model_name = "distilbert-base-uncased" # or "minilm" or "albert-tiny" + # Load the model and tokenizer + # model = AutoModel.from_pretrained(model_name) + # tokenizer = AutoTokenizer.from_pretrained(model_name) + tokenizer = AutoTokenizer.from_pretrained("gpt2") - tokenizer.pad_token = tokenizer.eos_token model = GPT2LMHeadModel(GPT2Config()) + + tokenizer.pad_token = tokenizer.eos_token tokenizer.padding_side = "left" + m = from_hf_transformers( - model, tokenizer=tokenizer, from_text=from_text, 
generate=generate + model, + tokenizer=tokenizer, + from_text=from_text, + generate=generate, + return_log_probs=return_log_probs, + ) + self._run_check( + m, + tokens, + attention_mask, + generate, + return_log_probs, + from_text, + has_logits=True, ) - self._run_check(m, tokens, attention_mask, generate, from_text, has_logits=True) @pytest.mark.parametrize( - "from_text, generate, tokens, attention_mask", + "from_text, generate, return_log_probs, tokens, attention_mask", [ - (True, True, None, None), - (True, False, None, None), + (True, True, True, None, None), + (True, True, False, None, None), + (True, False, None, None, None), + ( + False, + True, + True, + torch.randint(1024, (1, 10)), + torch.ones(1, 10, dtype=torch.int64), + ), + (False, True, True, torch.randint(1024, (1, 10)), None), ( False, True, + False, torch.randint(1024, (1, 10)), torch.ones(1, 10, dtype=torch.int64), ), - (False, True, torch.randint(1024, (1, 10)), None), + (False, True, False, torch.randint(1024, (1, 10)), None), ], ) - def test_from_vllm(self, from_text, generate, tokens, attention_mask): + def test_from_vllm( + self, from_text, generate, return_log_probs, tokens, attention_mask + ): from vllm import LLM model = LLM(model="facebook/opt-125m") - m = from_vllm(model, from_text=from_text, generate=generate) + m = from_vllm( + model, + from_text=from_text, + generate=generate, + return_log_probs=return_log_probs, + ) self._run_check( - m, tokens, attention_mask, generate, from_text, has_logits=False + m, + tokens, + attention_mask, + generate, + return_log_probs, + from_text, + has_logits=False, ) def _make_data( @@ -1007,7 +1061,16 @@ def _make_data( ) return tdin - def _run_check(self, m, tokens, attention_mask, generate, from_text, has_logits): + def _run_check( + self, + m, + tokens, + attention_mask, + generate, + return_log_probs, + from_text, + has_logits, + ): tdin = self._make_data( m, tokens, attention_mask, generate, from_text, has_logits ) @@ -1024,13 +1087,19 @@ def _run_check(self, m, tokens, attention_mask, generate, from_text, has_logits) if generate and (attention_mask is not None or from_text): assert td.attention_mask is not None, (generate, generate, from_text) else: - assert td.attention_mask is None + assert td.attention_mask is None, (generate, from_text) if not generate: # logprobs are computed on text response of tokens_response assert td.text_response is not None or td.tokens_response is not None assert td.log_probs is not None if has_logits: assert td.logits is not None + if generate: + if return_log_probs: + assert td.log_probs is not None + assert td.log_probs.shape[-2] == td.tokens_response.shape[-1] + else: + assert td.log_probs is None # Test the shapes assert td.tokens_response is not None, (generate, has_logits, from_text) @@ -1042,7 +1111,7 @@ def _run_check(self, m, tokens, attention_mask, generate, from_text, has_logits) assert ( td.tokens_response[..., : td.tokens.shape[-1]] != td.tokens[..., : td.tokens_response.shape[-1]] - ).any() + ).any(), (generate, from_text) @pytest.mark.parametrize( "from_text, tokens, attention_mask", @@ -1060,7 +1129,9 @@ def test_from_vllm_logprobs(self, from_text, tokens, attention_mask): from vllm import LLM model = LLM(model="facebook/opt-125m") - m_generate = from_vllm(model, from_text=from_text, generate=True) + m_generate = from_vllm( + model, from_text=from_text, generate=True, return_log_probs=True + ) m_logprobs = from_vllm(model, from_text=from_text, generate=False) self._check_lps( m_generate, m_logprobs, tokens, attention_mask, 
from_text, has_logits=False @@ -1091,7 +1162,6 @@ def _check_lps( text_response=td_generate.text_response, ) td_logprobs = model_logprobs(tdin_logprobs) - print(td_generate.log_probs / td_logprobs.log_probs) torch.testing.assert_close( td_generate.log_probs, td_logprobs.log_probs, rtol=1e-2, atol=1e-2 ) diff --git a/torchrl/envs/custom/llm.py b/torchrl/envs/custom/llm.py index 07a0880ba5b..e3e35219c65 100644 --- a/torchrl/envs/custom/llm.py +++ b/torchrl/envs/custom/llm.py @@ -143,7 +143,11 @@ def __init__( # self.action_key = unravel_key(action_key) if str2str: self.full_observation_spec_unbatched = Composite( - {self.str_key: NonTensor(example_data="a string", batched=True, shape=())} + { + self.str_key: NonTensor( + example_data="a string", batched=True, shape=() + ) + } ) self.full_action_spec_unbatched = Composite( {action_key: NonTensor(example_data="a string", batched=True, shape=())} diff --git a/torchrl/modules/llm/transformers_policy.py b/torchrl/modules/llm/transformers_policy.py index cc8a2237ea7..db68d00529f 100644 --- a/torchrl/modules/llm/transformers_policy.py +++ b/torchrl/modules/llm/transformers_policy.py @@ -57,16 +57,18 @@ def log_probs_from_scores(td: TensorDictBase) -> TensorDictBase: """ # TODO: how do we avoid getting these? + tokens_out = td["tokens_out", "sequences"] + seq_len = tokens_out.shape[1] + del td["tokens_out", "past_key_values"] scores = dict(td["tokens_out", "scores"].items()) scores = torch.stack( [scores[str(k)] for k in range(len(scores))], 1 ) # shape (B, seq-len, vocab_size) logits = scores - scores.logsumexp(dim=-1, keepdim=True) - td["logits"] = scores + td["logits"] = scores[..., -seq_len:, :] del td["tokens_out", "scores"] - seq_len = scores.shape[1] - tokens = td["tokens_out", "sequences"][..., -seq_len:] # shape (B, seq-len) + tokens = tokens_out[..., -seq_len:] # shape (B, seq-len) log_probs = logits.gather(-1, tokens.unsqueeze(-1)) td["log_probs"] = log_probs return td @@ -112,15 +114,77 @@ def from_hf_transformers( tokenizer: transformers.tokenization_utils.PreTrainedTokenizer | None = None, from_text: bool = False, device: torch.device | None = None, - # Keys: - text_key: NestedKey = "text", - token_key: NestedKey = "tokens", - attention_mask_key: NestedKey = "attention_mask", kwargs: dict | None = None, tokenizer_kwargs: dict | None = None, ) -> TensorDictModuleBase: + """Creates a TensorDictModule from a Hugging Face Transformers model. + + This allows for a consistent interface across various LLM engines. + This function facilitates text generation and log probability computation. + + Args: + model (transformers.modeling_utils.PreTrainedModel): The Hugging Face model to wrap. + generate (bool, optional): Whether to generate text. Defaults to `True`. + return_log_probs (bool, optional): Whether to return log probabilities. Defaults to `True`. + tokenizer (transformers.tokenization_utils.PreTrainedTokenizer, optional): The tokenizer to use. Defaults to `None`. + from_text (bool, optional): Whether the input is text. Defaults to `False`. + device (torch.device, optional): The device to use for computation. Defaults to `None`. + kwargs (dict, optional): Additional arguments for the model's generate method. Defaults to `None`. + tokenizer_kwargs (dict, optional): Additional arguments for the tokenizer. Defaults to `None`. + + Returns: + TensorDictModuleBase: A configured TensorDictModule for the specified model. + + Input Keys: + + - If `from_text` is `True`: + + - "text": The input text to be tokenized. 
+ + - If `from_text` is `False`: + + - "tokens": The input token sequences. + - "attention_mask": The attention mask for the tokens. + + Output Keys: + + - "tokens_response": The generated token sequences. + - "log_probs": The log probabilities of the generated tokens (if `return_log_probs` is `True`). + - "logits": The logits of the generated tokens (if applicable). + - "text_response": The generated text (if `from_text` is `True` and `generate` is `True`). + + Example: + >>> from tensordict.tensorclass import NonTensorStack + >>> from transformers import AutoTokenizer, GPT2LMHeadModel, GPT2Config + >>> + >>> from torchrl.data import LLMData + >>> from torchrl.modules import from_hf_transformers + >>> + >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") + >>> model = GPT2LMHeadModel(GPT2Config()) + >>> tokenizer.pad_token = tokenizer.eos_token + >>> + >>> module = from_hf_transformers( + ... model, + ... tokenizer=tokenizer, + ... from_text=True, + ... generate=True + ... ) + >>> input_data = LLMData(text=NonTensorStack("Hello, world!"), batch_size=1) + >>> output_data = module(input_data) + >>> print(output_data.text_response) + [' heritageillon rooft rooft Pear Tes grantingalde 58ocrocrocrocrcubecubecubecubecubecubecube'] + + .. seealso:: :func:`~torchrl.modules.from_vllm` for a similar interface using the vLLM library. + + """ # TODO: Seq should have a return_log_prob and be of ProbabilisticTDSequential type for instance checks + # Keys: + text_key: NestedKey = "text" + token_key: NestedKey = "tokens" + attention_mask_key: NestedKey = "attention_mask" + module_dict = {} if device: module_dict["clear_device"] = _maybe_clear_device @@ -137,13 +201,59 @@ def from_hf_transformers( if tokenizer_kwargs.setdefault("padding_side", "left") != "left": raise RuntimeError - module_dict["encode"] = Mod( - tokenizer, - in_keys=[text_key], - out_keys=["tokens_in"], - method_kwargs=tokenizer_kwargs, - strict=True, - # We don't need the text after this + if generate: + module_dict["encode"] = Mod( + tokenizer, + in_keys=[text_key], + out_keys=["tokens_in"], + method_kwargs=tokenizer_kwargs, + strict=True, + # We don't need the text after this + inplace=False, + ) + else: + module_dict["encode"] = Mod( + # TODO: make this work with many strings + # Tokenize both strings, and only the first + lambda x, y: ( + tokenizer([_x + _y for _x, _y in zip(x, y)], **tokenizer_kwargs), + tokenizer(x, **tokenizer_kwargs), + ), + in_keys=[text_key, "text_response"], + out_keys=["tokens_in", "tokens_response"], + strict=True, + inplace=False, + ) + + def select(x, y): + return x.apply(lambda _x, _y: _x[..., _y.shape[-1] :], y) + + module_dict["stack_response"] = Mod( + # Remove the init from the total tokens to get only the response tokens + select, + in_keys=["tokens_in", "tokens_response"], + out_keys=["tokens_response"], + strict=True, + ) + elif not generate: + + def stack_for_logprobs(tokens, tokens_response, attention_mask=None): + tokens = torch.cat([tokens, tokens_response], -1) + if attention_mask is not None: + attention_mask = torch.cat( + [attention_mask, attention_mask.new_ones(tokens_response.shape)], -1 + ) + return tokens, tokens_response, attention_mask + + module_dict["stack_response"] = Mod( + stack_for_logprobs, + in_keys=["tokens", "tokens_response", "attention_mask"], + out_keys=[ + ("tokens_in", "input_ids"), + ("tokens_response", "input_ids"), + ("tokens_in", "attention_mask"), + ], + strict=False, inplace=False, ) else: @@ -190,6 +300,17 @@ def from_hf_transformers( out_to_in_map=True, 
strict=True, ) + + # Keep only the new tokens + def remove_input_seq(tokens_in, tokens_out): + return tokens_out[..., tokens_in.shape[-1] :] + + module_dict["remove_input_seq"] = Mod( + remove_input_seq, + in_keys=[("tokens_in", "input_ids"), ("tokens_out", "sequences")], + out_keys=[("tokens_out", "sequences")], + strict=True, + ) if return_log_probs: module_dict["extract_log_probs"] = WrapModule( log_probs_from_scores, @@ -205,6 +326,7 @@ def from_hf_transformers( ) if device: module_dict["to_source_device"] = _maybe_set_device + module_dict["rebuild"] = Mod( lambda *x: x, in_keys=[ @@ -223,7 +345,8 @@ def from_hf_transformers( "log_probs", "logits", ], - strict=True, + # There may not be log_probs and logits + strict=False, inplace=False, ) else: @@ -240,8 +363,10 @@ def from_hf_transformers( kwargs = {} if not kwargs.setdefault("return_dict", True): raise RuntimeError - if not return_log_probs: - raise RuntimeError + if return_log_probs not in (True, None): + raise RuntimeError( + "return_log_probs should be True or None when not generating." + ) module_dict["get_logprobs"] = Mod( model, method_kwargs=kwargs, @@ -263,8 +388,8 @@ def from_hf_transformers( if from_text: module_dict["rebuild"] = Mod( lambda *x: x, - in_keys=["log_probs", "logits", ("tokens_in", "attention_mask")], - out_keys=["log_probs", "logits", "attention_mask"], + in_keys=["log_probs", "logits", "tokens_response"], + out_keys=["log_probs", "logits", "tokens_response"], inplace=False, ) else: diff --git a/torchrl/modules/llm/vllm_policy.py b/torchrl/modules/llm/vllm_policy.py index b5d2cf6bb99..24abdb343a3 100644 --- a/torchrl/modules/llm/vllm_policy.py +++ b/torchrl/modules/llm/vllm_policy.py @@ -50,13 +50,68 @@ def from_vllm( from_text: bool = False, device: torch.device | None = None, generate: bool = True, - # Keys - text_key: NestedKey = "text", - token_key: NestedKey = "tokens", - attention_mask_key: NestedKey = "attention_mask", generate_kwargs: dict | None = None, tokenizer_kwargs: dict | None = None, ) -> TensorDictModuleBase: + """Creates a TensorDictModule from a vLLM model. + + This function provides a consistent interface across various LLM engines. + + It supports text generation and log probability computation, similar to the Hugging Face Transformers interface. + + Args: + model (LLM): The vLLM model to wrap. + return_log_probs (bool, optional): Whether to return log probabilities. Defaults to `False`. + tokenizer (transformers.tokenization_utils.PreTrainedTokenizer, optional): The tokenizer to use. Defaults to `None`. + from_text (bool, optional): Whether the input is text. Defaults to `False`. + device (torch.device, optional): The device to use for computation. Defaults to `None`. + generate (bool, optional): Whether to generate text. Defaults to `True`. + generate_kwargs (dict, optional): Additional arguments for the model's generate method. Defaults to `None`. + tokenizer_kwargs (dict, optional): Additional arguments for the tokenizer. Defaults to `None`. + + Returns: + TensorDictModuleBase: A configured TensorDictModule for the specified model. + + Input Keys: + + - If `from_text` is `True`: + + - "text": The input text to be tokenized. + + - If `from_text` is False: + + - "tokens": The input token sequences. + - "attention_mask": The attention mask for the tokens. + + Output Keys: + + - "tokens_response": The generated token sequences. + - "log_probs": The log probabilities of the generated tokens (if `return_log_probs` is True). 
+ - "text_response": The generated text (if `from_text` is True and `generate` is True). + + Example: + >>> from vllm import LLM + >>> from transformers import AutoTokenizer + >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") + >>> model = LLM(model="facebook/opt-125m") + >>> module = from_vllm( + ... model, + ... tokenizer=tokenizer, + ... from_text=True, + ... generate=True + ... ) + >>> input_data = LLMData(text=NonTensorStack("Hello, world!"), batch_size=1) + >>> output_data = module(input_data) + >>> print(output_data.text_response) + + .. seealso:: :func:`~torchrl.modules.from_hf_transformers` for a similar interface using the Hugging Face + Transformers library. + + """ + text_key: NestedKey = ("text",) + token_key: NestedKey = ("tokens",) + attention_mask_key: NestedKey = ("attention_mask",) + # TODO: Seq should have a return_log_prob and be of ProbabilisticTDSequential type for instance checks if tokenizer is None: tokenizer = model.get_tokenizer() @@ -171,7 +226,11 @@ def to_list(tokens, attention_mask): ) if generate_kwargs is None: - generate_kwargs = {"detokenize": False, "prompt_logprobs": 1, "logprobs": 1} + generate_kwargs = { + "detokenize": False, + "prompt_logprobs": not generate, + "logprobs": return_log_probs, + } if not generate: generate_kwargs["max_tokens"] = 1 sampling_params = SamplingParams(**generate_kwargs) @@ -189,13 +248,14 @@ def to_list(tokens, attention_mask): ) def get_output_tokens_and_log_probs(td): - td["tokens_out"] = RequestOutput_tc.from_request_output(td["tokens_out"]) + td["tokens_out"] = _RequestOutput_tc.from_request_output(td["tokens_out"]) if generate: # When not generate, we don't want to overwrite this td["tokens_response"] = td["tokens_out"].outputs.token_ids - td["log_probs"] = td["tokens_out"].outputs.logprobs - else: - td["prompt_logprobs"] = td["tokens_out"].prompt_logprobs + if return_log_probs: + td["log_probs"] = td["tokens_out"].outputs.logprobs.unsqueeze(-1) + elif not generate: + td["prompt_logprobs"] = td["tokens_out"].prompt_logprobs.unsqueeze(-1) return td module_dict["get_output_tokens_and_log_probs"] = get_output_tokens_and_log_probs @@ -204,7 +264,7 @@ def get_output_tokens_and_log_probs(td): def translate_lps(tokens_response, x): # we disregard the tokens from the prompt to focus on those of the response - return x[..., -tokens_response.shape[-1] :] + return x[..., -tokens_response.shape[-1] :, :] module_dict["translate_lps"] = Mod( translate_lps, @@ -253,7 +313,7 @@ def translate_lps(tokens_response, x): return Seq(module_dict, inplace=True) -class RequestOutput_tc(TensorClass["nocast"]): +class _RequestOutput_tc(TensorClass["nocast"]): request_id: str prompt: str prompt_token_ids: str @@ -276,7 +336,8 @@ def get_logprob(output): ) return torch.tensor(t) - output.logprobs = get_logprob(output) + if output.logprobs: + output.logprobs = get_logprob(output) output.token_ids = torch.tensor(output.token_ids) return output @@ -337,10 +398,8 @@ def from_request_output(cls, requests): m = from_vllm(llm, from_text=True) td = m(LLMData(text=NonTensorStack("a text"), batch_size=1)) - print("result", td) td = m(LLMData(text=NonTensorData("a text"), batch_size=())) - print("result", td) td = m(LLMData(text=NonTensorStack("a text"), batch_size=1)) m = from_vllm(llm, from_text=True, generate=False) @@ -351,10 +410,7 @@ def from_request_output(cls, requests): batch_size=(1,), ) td_lp = m(td_lp) - print("td_lp", td_lp) - print(td.log_probs / td_lp.log_probs) # torch.testing.assert_close(td.log_probs, td_lp.log_probs) m = from_vllm(llm, 
from_text=True, generate=True) td = m(LLMData(text=NonTensorStack("a text", "another text here"), batch_size=2)) - print(td) From 19cc931ffe0ace90a4a6a112f8c6cd8aa5e9ced0 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 11 Mar 2025 10:50:53 +0000 Subject: [PATCH 3/7] Update [ghstack-poisoned] --- torchrl/modules/llm/transformers_policy.py | 13 +++++++------ torchrl/modules/llm/vllm_policy.py | 22 ++++++++++++++++------ 2 files changed, 23 insertions(+), 12 deletions(-) diff --git a/torchrl/modules/llm/transformers_policy.py b/torchrl/modules/llm/transformers_policy.py index db68d00529f..c506481746e 100644 --- a/torchrl/modules/llm/transformers_policy.py +++ b/torchrl/modules/llm/transformers_policy.py @@ -2,12 +2,10 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. - -# TODO: lazy imports +from __future__ import annotations import torch -import transformers from tensordict import NestedKey, TensorDictBase from tensordict.nn import ( TensorDictModule as Mod, @@ -17,7 +15,6 @@ ) from tensordict.tensorclass import NonTensorData, NonTensorStack from torchrl.data.llm import LLMData -from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel def _maybe_clear_device(td): @@ -107,11 +104,12 @@ def log_probs_from_logits(td: TensorDictBase) -> TensorDictBase: def from_hf_transformers( - model: transformers.modeling_utils.PreTrainedModel, + model: transformers.modeling_utils.PreTrainedModel, # noqa *, generate: bool = True, return_log_probs: bool = True, - tokenizer: transformers.tokenization_utils.PreTrainedTokenizer | None = None, + tokenizer: transformers.tokenization_utils.PreTrainedTokenizer + | None = None, # noqa from_text: bool = False, device: torch.device | None = None, kwargs: dict | None = None, @@ -404,6 +402,9 @@ def remove_input_seq(tokens_in, tokens_out): if __name__ == "__main__": + import transformers + from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel + max_seq_length = 50000 tokenizer = AutoTokenizer.from_pretrained("gpt2") diff --git a/torchrl/modules/llm/vllm_policy.py b/torchrl/modules/llm/vllm_policy.py index 24abdb343a3..fdba71cb63e 100644 --- a/torchrl/modules/llm/vllm_policy.py +++ b/torchrl/modules/llm/vllm_policy.py @@ -2,11 +2,12 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+from __future__ import annotations + import collections +import importlib.util import torch -import transformers -import vllm from tensordict import ( from_dataclass, maybe_dense_stack, @@ -22,9 +23,15 @@ ) from torchrl.data import LLMData -from vllm import LLM, SamplingParams -CompletionOutput_tc = from_dataclass(vllm.outputs.CompletionOutput) +_has_vllm = importlib.util.find_spec("vllm") + +if _has_vllm: + import vllm + + CompletionOutput_tc = from_dataclass(vllm.outputs.CompletionOutput) +else: + CompletionOutput_tc = None def _maybe_clear_device(td): @@ -43,10 +50,11 @@ def _maybe_set_device(td): def from_vllm( - model: LLM, + model: vllm.LLM, # noqa *, return_log_probs: bool = False, - tokenizer: transformers.tokenization_utils.PreTrainedTokenizer | None = None, + tokenizer: transformers.tokenization_utils.PreTrainedTokenizer # noqa + | None = None, # noqa from_text: bool = False, device: torch.device | None = None, generate: bool = True, @@ -386,6 +394,8 @@ def from_request_output(cls, requests): if __name__ == "__main__": + from vllm import LLM, SamplingParams + prompts = [ "Hello, my name is", "The president of the United States is", From cb64cd51c41b83920cc0db623d5681ebe6fb2225 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 11 Mar 2025 11:22:02 +0000 Subject: [PATCH 4/7] Update [ghstack-poisoned] --- torchrl/modules/llm/vllm_policy.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/torchrl/modules/llm/vllm_policy.py b/torchrl/modules/llm/vllm_policy.py index fdba71cb63e..32bc4f77bde 100644 --- a/torchrl/modules/llm/vllm_policy.py +++ b/torchrl/modules/llm/vllm_policy.py @@ -116,6 +116,11 @@ def from_vllm( Transformers library. """ + try: + from vllm import SamplingParams + except ImportError: + raise ImportError("Please install `vllm` to use `from_vllm`.") + text_key: NestedKey = ("text",) token_key: NestedKey = ("tokens",) attention_mask_key: NestedKey = ("attention_mask",) From 01f172a097620a3275828ea8765156d1d8256e58 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 11 Mar 2025 11:50:56 +0000 Subject: [PATCH 5/7] Update [ghstack-poisoned] --- torchrl/modules/llm/vllm_policy.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchrl/modules/llm/vllm_policy.py b/torchrl/modules/llm/vllm_policy.py index 32bc4f77bde..daab91c76d0 100644 --- a/torchrl/modules/llm/vllm_policy.py +++ b/torchrl/modules/llm/vllm_policy.py @@ -21,6 +21,7 @@ TensorDictModuleBase, TensorDictSequential as Seq, ) +from tensordict.utils import _zip_strict from torchrl.data import LLMData @@ -367,8 +368,8 @@ def get_logprob(output): self.prompt_logprobs = torch.tensor( [ v[tid].logprob if v is not None else 0.0 - for v, tid in zip( - self.prompt_logprobs, self.prompt_token_ids, strict=True + for v, tid in _zip_strict( + self.prompt_logprobs, self.prompt_token_ids ) ] ) From 3fb911b9adb699c7cde2946081691c834fb1b8a1 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 11 Mar 2025 12:20:14 +0000 Subject: [PATCH 6/7] Update [ghstack-poisoned] --- .github/unittest/linux_libs/scripts_llm/environment.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/unittest/linux_libs/scripts_llm/environment.yml b/.github/unittest/linux_libs/scripts_llm/environment.yml index b0897796779..29f7dd98ac0 100644 --- a/.github/unittest/linux_libs/scripts_llm/environment.yml +++ b/.github/unittest/linux_libs/scripts_llm/environment.yml @@ -17,6 +17,6 @@ dependencies: - pyyaml - scipy - hydra-core - - transformers + - transformers<4.42.0 - datasets - vllm From 
4e44ff8c6e308e24c50340f993bc51795a2e71da Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 11 Mar 2025 13:17:31 +0000 Subject: [PATCH 7/7] Update [ghstack-poisoned] --- .../unittest/linux_libs/scripts_llm/environment.yml | 2 +- examples/rlhf/models/actor_critic.py | 2 +- torchrl/modules/tensordict_module/common.py | 13 +++++++++++-- 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/.github/unittest/linux_libs/scripts_llm/environment.yml b/.github/unittest/linux_libs/scripts_llm/environment.yml index 29f7dd98ac0..b0897796779 100644 --- a/.github/unittest/linux_libs/scripts_llm/environment.yml +++ b/.github/unittest/linux_libs/scripts_llm/environment.yml @@ -17,6 +17,6 @@ dependencies: - pyyaml - scipy - hydra-core - - transformers<4.42.0 + - transformers - datasets - vllm diff --git a/examples/rlhf/models/actor_critic.py b/examples/rlhf/models/actor_critic.py index b5be188fbd9..93c4d285b3e 100644 --- a/examples/rlhf/models/actor_critic.py +++ b/examples/rlhf/models/actor_critic.py @@ -34,4 +34,4 @@ def init_actor_critic(model_cfg, sys_cfg): critic = model.get_value_operator() critic_head = model.get_value_head() - return actor, VmapModule(critic), critic_head, base_model + return actor, VmapModule(critic, mock=True), critic_head, base_model diff --git a/torchrl/modules/tensordict_module/common.py b/torchrl/modules/tensordict_module/common.py index 2bd09e81e81..81d96fb7cec 100644 --- a/torchrl/modules/tensordict_module/common.py +++ b/torchrl/modules/tensordict_module/common.py @@ -436,7 +436,7 @@ class VmapModule(TensorDictModuleBase): >>> assert (sample_in_td["x"][:, 0] == sample_in_td["y"]).all() """ - def __init__(self, module: TensorDictModuleBase, vmap_dim=None): + def __init__(self, module: TensorDictModuleBase, vmap_dim=None, mock: bool = False): if not _has_functorch: raise ImportError("VmapModule requires torch>=2.0.") super().__init__() @@ -444,6 +444,7 @@ def __init__(self, module: TensorDictModuleBase, vmap_dim=None): self.out_keys = module.out_keys self.module = module self.vmap_dim = vmap_dim + self.mock = mock if torch.__version__ >= "2.0": self._vmap = torch.vmap else: @@ -451,6 +452,9 @@ def __init__(self, module: TensorDictModuleBase, vmap_dim=None): self._vmap = functorch.vmap + def mock_(self, value: bool = True): + self.mock = value + def forward(self, tensordict): # TODO: there is a risk of segfault if input is not a tensordict. # We should investigate (possibly prevent it c++ side?) @@ -458,7 +462,12 @@ def forward(self, tensordict): if vmap_dim is None: ndim = tensordict.ndim vmap_dim = ndim - 1 - td = self._vmap(self.module, (vmap_dim,), (vmap_dim,))(tensordict) + if self.mock: + td = torch.stack( + [self.module(_td) for _td in tensordict.unbind(vmap_dim)], vmap_dim + ) + else: + td = self._vmap(self.module, (vmap_dim,), (vmap_dim,))(tensordict) return tensordict.update(td)
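
Note for reviewers: a minimal end-to-end sketch of the `from_vllm` wrapper this series introduces, mirroring the docstring examples and `TestLLMActor._check_lps` above. It assumes this branch plus `vllm` and `transformers` are installed and that `from_vllm` and `LLMData` are importable as referenced in the docstrings; the model name is the one used in the tests.

    import torch
    from tensordict.tensorclass import NonTensorStack
    from vllm import LLM

    from torchrl.data import LLMData
    from torchrl.modules import from_vllm

    model = LLM(model="facebook/opt-125m")

    # Generation: text in -> tokens_response / text_response out, with log-probs.
    gen = from_vllm(model, from_text=True, generate=True, return_log_probs=True)
    td = gen(LLMData(text=NonTensorStack("Hello, world!"), batch_size=1))
    print(td.text_response)
    print(td.log_probs.shape)  # one log-prob per response token

    # Scoring: recompute log-probs of an existing response with generate=False.
    scorer = from_vllm(model, from_text=True, generate=False)
    td_lp = scorer(
        LLMData(
            text=NonTensorStack("Hello, world!"),
            text_response=td.text_response,
            batch_size=1,
        )
    )
    # As in TestLLMActor._check_lps, the two estimates should roughly agree:
    torch.testing.assert_close(td.log_probs, td_lp.log_probs, rtol=1e-2, atol=1e-2)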
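
And a short sketch of the `VmapModule(mock=True)` escape hatch added in PATCH 7/7, which swaps `torch.vmap` for an explicit unbind/apply/stack loop over the last batch dimension. Only the `mock` keyword, the `mock_()` toggle, and the module path come from the patch; the toy `nn.Linear` module, keys, and shapes below are illustrative.

    import torch
    from torch import nn
    from tensordict import TensorDict
    from tensordict.nn import TensorDictModule
    from torchrl.modules.tensordict_module.common import VmapModule

    inner = TensorDictModule(nn.Linear(2, 1), in_keys=["x"], out_keys=["y"])
    vm = VmapModule(inner, mock=True)  # or toggle later with vm.mock_()
    td = TensorDict({"x": torch.randn(4, 3, 2)}, batch_size=[4, 3])
    out = vm(td)  # unbinds the last batch dim, applies `inner` per slice, stacks back
    assert out["y"].shape == (4, 3, 1)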