diff --git a/README.md b/README.md
index ab820c19697..db5d4b3316a 100644
--- a/README.md
+++ b/README.md
@@ -23,6 +23,57 @@
**TorchRL** is an open-source Reinforcement Learning (RL) library for PyTorch.
+## 🚀 What's New
+
+### LLM API - Complete Framework for Language Model Fine-tuning
+
+TorchRL now includes a comprehensive **LLM API** for post-training and fine-tuning of language models! This new framework provides everything you need for RLHF, supervised fine-tuning, and tool-augmented training:
+
+- 🤖 **Unified LLM Wrappers**: Seamless integration with Hugging Face models and vLLM inference engines
+- 💬 **Conversation Management**: Advanced `History` class for multi-turn dialogue with automatic chat template detection
+- 🛠️ **Tool Integration**: Built-in support for Python code execution, function calling, and custom tool transforms
+- 🎯 **Specialized Objectives**: GRPO (Group Relative Policy Optimization) and SFT loss functions optimized for language models
+- ⚡ **High-Performance Collectors**: Async data collection with distributed training support
+- 🔄 **Flexible Environments**: Transform-based architecture for reward computation, data loading, and conversation augmentation
+
+The LLM API follows TorchRL's modular design principles, allowing you to mix and match components for your specific use case. Check out the [complete documentation](https://pytorch.org/rl/main/reference/llms.html) and [GRPO implementation example](https://github.com/pytorch/rl/tree/main/sota-implementations/grpo) to get started!
+
+<details>
+  <summary>Quick LLM API Example</summary>
+
+```python
+from torchrl.envs.llm import ChatEnv
+from torchrl.envs.llm.transforms import PythonInterpreter
+from torchrl.modules.llm import TransformersWrapper
+from torchrl.objectives.llm import GRPOLoss
+from torchrl.collectors.llm import LLMCollector
+
+# Create environment with Python tool execution
+env = ChatEnv(
+    tokenizer=tokenizer,
+    system_prompt="You are an assistant that can execute Python code.",
+    batch_size=[1]
+).append_transform(PythonInterpreter())
+
+# Wrap your language model
+llm = TransformersWrapper(
+    model=model,
+    tokenizer=tokenizer,
+    input_mode="history"
+)
+
+# Set up GRPO training (GRPO is critic-free; KL and entropy coefficients are configurable)
+loss_fn = GRPOLoss(llm)
+collector = LLMCollector(env, llm, frames_per_batch=100)
+
+# Training loop
+for data in collector:
+    loss = loss_fn(data)
+    loss.backward()
+    optimizer.step()
+    optimizer.zero_grad()
+```
+
+</details>
+
## Key features
- 🐍 **Python-first**: Designed with Python as the primary language for ease of use and flexibility
@@ -516,6 +567,39 @@ And it is `functorch` and `torch.compile` compatible!
- various [recipes](https://github.com/pytorch/rl/blob/main/torchrl/trainers/helpers/models.py) to build models that
correspond to the environment being deployed.
+- **LLM API**: Complete framework for language model fine-tuning with unified wrappers for Hugging Face and vLLM backends,
+ conversation management with automatic chat template detection, tool integration (Python execution, function calling),
+ specialized objectives (GRPO, SFT), and high-performance async collectors. Perfect for RLHF, supervised fine-tuning,
+ and tool-augmented training scenarios.
+
+  <details>
+    <summary>Code</summary>
+
+  ```python
+  from tensordict import TensorDict
+  from torchrl.envs.llm import ChatEnv
+  from torchrl.modules.llm import TransformersWrapper
+  from torchrl.envs.llm.transforms import PythonInterpreter
+
+  # Create environment with tool execution
+  env = ChatEnv(
+      tokenizer=tokenizer,
+      system_prompt="You can execute Python code.",
+      batch_size=[1]
+  ).append_transform(PythonInterpreter())
+
+  # Wrap language model for training
+  llm = TransformersWrapper(
+      model=model,
+      tokenizer=tokenizer,
+      input_mode="history"
+  )
+
+  # Multi-turn conversation with tool use
+  obs = env.reset(TensorDict({"query": "Calculate 2+2"}, batch_size=[1]))
+  llm_output = llm(obs)  # Generates response
+  obs = env.step(llm_output)  # Environment processes response
+  ```
+
+  </details>
If you feel a feature is missing from the library, please submit an issue!
If you would like to contribute to new features, check our [call for contributions](https://github.com/pytorch/rl/issues/509) and our [contribution](https://github.com/pytorch/rl/blob/main/CONTRIBUTING.md) page.
@@ -792,6 +876,18 @@ A series of [State-of-the-Art implementations](https://github.com/pytorch/rl/blo
NA
+ <tr>
+  <td>LLM API (GRPO)
+  </td>
+  <td>NA
+  </td>
+  <td>+
+  </td>
+  <td>+
+  </td>
+  <td>NA
+  </td>
+ </tr>
** The number indicates expected speed-up compared to eager mode when executed on CPU. Numbers may vary depending on
@@ -800,6 +896,7 @@ A series of [State-of-the-Art implementations](https://github.com/pytorch/rl/blo
and many more to come!
[Code examples](examples/) displaying toy code snippets and training scripts are also available
+- [LLM API & GRPO](sota-implementations/grpo) - Complete language model fine-tuning pipeline
- [RLHF](examples/rlhf)
- [Memory-mapped replay buffers](examples/torchrl_features)
diff --git a/docs/source/_static/img/llm-data.svg b/docs/source/_static/img/llm-data.svg
new file mode 100644
index 00000000000..ee76e85e5ba
--- /dev/null
+++ b/docs/source/_static/img/llm-data.svg
@@ -0,0 +1,5 @@
+
diff --git a/docs/source/_static/img/llm-env.png b/docs/source/_static/img/llm-env.png
new file mode 100644
index 00000000000..df6c7401cef
Binary files /dev/null and b/docs/source/_static/img/llm-env.png differ
diff --git a/docs/source/reference/llms.rst b/docs/source/reference/llms.rst
index 402e18ffa97..c162cae7c0c 100644
--- a/docs/source/reference/llms.rst
+++ b/docs/source/reference/llms.rst
@@ -1,19 +1,491 @@
.. currentmodule:: torchrl
-LLM interface
+LLM Interface
=============
.. _ref_llms:
-TorchRL offers a set of tools for LLM post-training, as well as some examples for training or setup.
+TorchRL provides a comprehensive framework for LLM post-training and fine-tuning. The LLM API is built around five core concepts that work
+together to create a complete reinforcement learning pipeline for language models:
+
+1. **Data Representation** (`Data Structures`_): The foundation for handling conversations, text parsing, and LLM
+ output classes. This includes the :class:`~torchrl.data.llm.History` class for managing conversation context and structured output classes for
+ tokens, log-probabilities, and text.
+
+2. **LLM Wrapper API** (`Modules`_): Unified interfaces for different LLM backends, including :class:`~torchrl.modules.llm.TransformersWrapper` for
+ Hugging Face models and :class:`~torchrl.modules.llm.vLLMWrapper` for vLLM inference. These wrappers provide consistent input/output formats across
+ different backends and an integrated interface for loss computation, data storage, grading, weight synchronization, etc.
+
+3. **Environments** (`Environments`_): The orchestration layer that manages data loading, tool execution, reward computation, and formatting. This includes
+ :class:`~torchrl.envs.llm.ChatEnv` for conversation management, dataset environments, and various transforms for tool integration.
+
+4. **Objectives** (`Objectives`_): Specialized loss functions for LLM training, including :class:`~torchrl.objectives.llm.GRPOLoss` for Group Relative
+ Policy Optimization and :class:`~torchrl.objectives.llm.SFTLoss` for supervised fine-tuning.
+
+5. **Collectors** (`Collectors`_): Collectors gather data from the environment and store it in a format suitable for training. This includes
+   :class:`~torchrl.collectors.llm.LLMCollector` for single-node collection and :class:`~torchrl.collectors.llm.RayLLMCollector` for distributed
+   collection using Ray.
+
+These components work together to create a complete pipeline: environments load and format data, LLM wrappers handle inference, data structures maintain
+conversation context, and objectives compute training losses. The modular design allows you to mix and match components based on your specific use case.
+
+A complete example of how to use the LLM API can be found in the `sota-implementations/grpo/` directory. The training orchestration involves three main components:
+
+- The data collector: holds a reference to the environment and to the inference model or engine. It collects data, writes it to the buffer, and handles weight updates.
+- The replay buffer: stores the collected data and executes any pre- or post-processing steps. These may include:
+
+  - Advantage estimation with a Monte-Carlo based method (using the :class:`~torchrl.objectives.llm.MCAdvantage` transform);
+  - Grading of the outputs;
+  - Logging, etc.
+
+- The trainer: handles the training loop, including the optimization step, serialization, logging, and the initialization of weight updates.
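+
+The sketch below shows how these pieces fit together. It is a simplified, illustrative version of the GRPO scripts; names such as ``env_maker``,
+``inference_policy`` and ``train_policy`` are placeholders for objects built elsewhere, and the exact constructor arguments may differ.
+
+.. code-block:: python
+
+    from torchrl.collectors.llm import RayLLMCollector
+    from torchrl.data import LazyStackStorage, ReplayBuffer
+    from torchrl.objectives.llm import GRPOLoss, MCAdvantage
+
+    # Replay buffer: stores trajectories and computes Monte-Carlo advantages on write
+    rb = ReplayBuffer(
+        storage=LazyStackStorage(128),
+        transform=MCAdvantage(grpo_size=16),
+    )
+
+    # Collector: runs the inference policy in the environment and writes to the buffer
+    collector = RayLLMCollector(env_maker, policy=inference_policy, replay_buffer=rb, track_policy_version=True)
+
+    # Trainer loop: sample from the buffer, compute the loss, optimize, and update weights
+    loss_fn = GRPOLoss(train_policy)
+    for _ in collector:  # the collector writes directly to the buffer
+        batch = rb.sample(4)
+        loss = loss_fn(batch)
+        # ... backward pass, optimizer step and weight synchronization go here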
+
+.. warning:: The LLM API is still under development and may change in the future. Feedback, issues and PRs are welcome!
+
+Data Structures
+---------------
+
+The data representation layer provides the foundation for handling conversations and LLM outputs in a structured way.
+
+History Class
+~~~~~~~~~~~~~
+
+The :class:`~torchrl.data.llm.History` class is a TensorClass version of the chat format usually found in transformers
+(see the `Hugging Face chat documentation <https://huggingface.co/docs/transformers/main/en/chat_templating>`_).
+It provides a comprehensive API for managing conversation data with features including:
+
+- **Text parsing and formatting**: Convert between text and structured conversation format using :meth:`~torchrl.data.llm.History.from_text`
+  and :meth:`~torchrl.data.llm.History.apply_chat_template`
+- **Dynamic conversation building**: Append and extend conversations with :meth:`~torchrl.data.llm.History.append` and
+  :meth:`~torchrl.data.llm.History.extend` methods
+- **Multi-model support**: Automatic template detection for various model families (Qwen, DialoGPT, Falcon, DeepSeek, etc.)
+- **Assistant token masking**: Identify which tokens were generated by the assistant for reinforcement learning applications
+- **Tool calling support**: Handle function calls and tool responses in conversations
+- **Batch operations**: Efficient tensor operations for processing multiple conversations simultaneously.
+
+.. currentmodule:: torchrl.data.llm
+
+.. autosummary::
+ :toctree: generated/
+ :template: rl_template.rst
+
+ History
+ ContentBase
+ LLMData
+
+Supported Model Families
+^^^^^^^^^^^^^^^^^^^^^^^^
+
+We currently support the following model families for string to History parsing or assistant token masking:
+
+- **Qwen family** (e.g., `Qwen/Qwen2.5-0.5B`): Custom template with full tool calling support
+- **DialoGPT family** (e.g., `microsoft/DialoGPT-medium`): Custom template for conversation format
+- **Falcon family** (e.g., `tiiuae/falcon-7b-instruct`): Custom template for instruction format
+- **DeepSeek family** (e.g., `deepseek-ai/deepseek-coder-6.7b-base`): Custom template with native format
+
+Other models are supported, but you will need to provide a custom template for them.
+LLAMA, Mistral, OPT, GPT, MPT, BLOOM, Pythia, Phi, etc. will use the default `chatml_format` template.
+
+Usage
+^^^^^
+
+.. code-block:: python
+
+ >>> from torchrl.data.llm.history import History
+ >>> from transformers import AutoTokenizer
+ >>>
+ >>> # Create a conversation history
+ >>> history = History.from_chats([[
+ ... {"role": "user", "content": "Hello"},
+ ... {"role": "assistant", "content": "Hi there!"},
+ ... {"role": "user", "content": "How are you?"},
+ ... {"role": "assistant", "content": "I'm doing well, thanks!"}
+ ... ]])
+ >>>
+ >>> # Load any supported tokenizer
+ >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
+ >>>
+ >>> # Apply chat template with assistant token masking
+ >>> result = history.apply_chat_template(
+ ...     tokenizer=tokenizer,
+ ...     chat_template_name="qwen",
+ ...     add_generation_prompt=False,
+ ...     return_dict=True,
+ ...     return_assistant_tokens_mask=True,
+ ... )
+ >>>
+ >>> # The result contains an assistant_masks tensor
+ >>> assistant_masks = result["assistant_masks"]
+ >>> print(f"Assistant tokens: {assistant_masks.sum().item()}")
+
+Adding Custom Templates
+^^^^^^^^^^^^^^^^^^^^^^^
+
+You can add custom chat templates for new model families using the :func:`torchrl.data.llm.history.add_chat_template` function.
+
+.. autofunction:: torchrl.data.llm.history.add_chat_template
+
+Usage Examples
+^^^^^^^^^^^^^^
+
+Adding a Llama Template
+"""""""""""""""""""""""
+
+.. code-block:: python
+
+ >>> from torchrl.data.llm.history import add_chat_template, History
+ >>> from tensordict import lazy_stack
+ >>> from transformers import AutoTokenizer
+ >>>
+ >>> # Define the Llama chat template
+ >>> llama_template = '''
+ ... {% for message in messages %}
+ ... {%- if message['role'] == 'user' %}
+ ... {{ '[INST] ' + message['content'] + ' [/INST]' }}
+ ... {%- elif message['role'] == 'assistant' %}
+ ... {% generation %}{{ message['content'] + ' </s>' }}{% endgeneration %}
+ ... {%- endif %}
+ ... {% endfor %}
+ ... {%- if add_generation_prompt %}
+ ... {% generation %}{{ ' ' }}{% endgeneration %}
+ ... {%- endif %}
+ ... '''
+ >>>
+ >>> # Define the inverse parser for Llama format
+ >>> def parse_llama_text(text: str) -> History:
+ ...     import re
+ ...     pattern = r'\[INST\]\s*(.*?)\s*\[/INST\]\s*(.*?)</s>'
+ ...     matches = re.findall(pattern, text, re.DOTALL)
+ ...     messages = []
+ ...     for user_content, assistant_content in matches:
+ ...         messages.append(History(role="user", content=user_content.strip()))
+ ...         messages.append(History(role="assistant", content=assistant_content.strip()))
+ ...     return lazy_stack(messages)
+ >>>
+ >>> # Add the template with auto-detection
+ >>> add_chat_template(
+ ... template_name="llama",
+ ... template=llama_template,
+ ... inverse_parser=parse_llama_text,
+ ... model_family_keywords=["llama", "meta-llama"]
+ ... )
+ >>>
+ >>> # Now you can use it with auto-detection
+ >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
+ >>> history = History.from_chats([[
+ ... {"role": "user", "content": "Hello"},
+ ... {"role": "assistant", "content": "Hi there!"}
+ ... ]])
+ >>>
+ >>> # Auto-detection will use the llama template
+ >>> result = history.apply_chat_template(
+ ... tokenizer=tokenizer,
+ ... add_generation_prompt=False,
+ ... return_dict=True,
+ ... return_assistant_tokens_mask=True,
+ ... )
+
+Testing Your Custom Templates
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+When adding custom templates, you should test them to ensure they work correctly. Here are the recommended tests:
+
+Assistant Token Masking Test
+""""""""""""""""""""""""""""
+
+Test that your template supports assistant token masking:
+
+.. code-block:: python
+
+    import pytest
+    from torchrl.data.llm.history import History, add_chat_template
+    from transformers import AutoTokenizer
+
+    def test_my_model_assistant_masking():
+        """Test that your model supports assistant token masking."""
+        # Add your template first
+        add_chat_template(
+            template_name="my_model",
+            template="your_template_here",
+            model_family_keywords=["my_model"]
+        )
+
+        tokenizer = AutoTokenizer.from_pretrained("your/model/name")
+        history = History.from_chats([[
+            {'role': 'user', 'content': 'Hello'},
+            {'role': 'assistant', 'content': 'Hi there!'}
+        ]])
+
+        result = history.apply_chat_template(
+            tokenizer=tokenizer,
+            chat_template_name="my_model",
+            add_generation_prompt=False,
+            return_dict=True,
+            return_assistant_tokens_mask=True,
+        )
+
+        # Verify assistant mask is present
+        assert 'assistant_masks' in result
+        assert result['assistant_masks'].shape[0] == 1, "Should have batch dimension of 1"
+        assert result['assistant_masks'].shape[1] > 0, "Should have sequence length > 0"
+
+        # Verify some assistant tokens are masked
+        assistant_token_count = result['assistant_masks'].sum().item()
+        assert assistant_token_count > 0, "Should have assistant tokens masked"
+        print(f"✓ {assistant_token_count} assistant tokens masked")
+
+Template Equivalence Test
+"""""""""""""""""""""""""
+
+Test that your custom template produces the same output as the model's default template (except for masking):
+
+.. code-block:: python
+
+    def test_my_model_template_equivalence():
+        """Test that your template matches the model's default template."""
+        tokenizer = AutoTokenizer.from_pretrained("your/model/name")
+        history = History.from_chats([[
+            {'role': 'user', 'content': 'Hello'},
+            {'role': 'assistant', 'content': 'Hi there!'},
+            {'role': 'user', 'content': 'How are you?'},
+            {'role': 'assistant', 'content': 'I\'m good, thanks!'},
+        ]])
+
+        # Get output with model's default template
+        try:
+            default_out = history.apply_chat_template(
+                tokenizer=tokenizer,
+                add_generation_prompt=False,
+                chat_template=tokenizer.chat_template,
+                tokenize=False,
+            )
+        except Exception as e:
+            default_out = None
+            print(f"[WARN] Could not get default template: {e}")
+
+        # Get output with your custom template
+        custom_out = history.apply_chat_template(
+            tokenizer=tokenizer,
+            add_generation_prompt=False,
+            chat_template_name="my_model",
+            tokenize=False,
+        )
+
+        if default_out is not None:
+            # Normalize whitespace for comparison
+            import re
+            def norm(s):
+                return re.sub(r"\s+", " ", s.strip())
+
+            assert norm(default_out) == norm(custom_out), (
+                f"Custom template does not match default!\n"
+                f"Default: {default_out}\nCustom: {custom_out}"
+            )
+            print("✓ Template equivalence verified")
+        else:
+            print("[INFO] Skipped equivalence check (no default template available)")
+
+Inverse Parsing Test
+""""""""""""""""""""
+
+If you provided an inverse parser, test that it works correctly:
+
+.. code-block:: python
+
+    def test_my_model_inverse_parsing():
+        """Test that your inverse parser works correctly."""
+        tokenizer = AutoTokenizer.from_pretrained("your/model/name")
+        history = History.from_chats([[
+            {'role': 'user', 'content': 'Hello'},
+            {'role': 'assistant', 'content': 'Hi there!'}
+        ]])
+
+        # Format using your template
+        formatted = history.apply_chat_template(
+            tokenizer=tokenizer,
+            chat_template_name="my_model",
+            add_generation_prompt=False,
+            tokenize=False,
+        )
+
+        # Parse back using your inverse parser
+        parsed = History.from_text(formatted, chat_template_name="my_model")
+
+        # Verify the parsing worked
+        assert parsed.role == history.role
+        assert parsed.content == history.content
+        print("✓ Inverse parsing verified")
+
+LLM Wrapper API
+~~~~~~~~~~~~~~~
+
+The LLM wrapper API provides unified interfaces for different LLM backends, ensuring consistent input/output formats across training and inference pipelines. The main wrappers are :class:`~torchrl.modules.llm.TransformersWrapper` for Hugging Face models and :class:`~torchrl.modules.llm.vLLMWrapper` for vLLM inference.
+
+**Data Structure Classes**
+
+The wrappers use structured :class:`~tensordict.TensorClass` objects to represent different aspects of LLM data:
+
+- :class:`~torchrl.modules.llm.policies.Text`: Contains text data with `prompt`, `response`, and `full` fields
+- :class:`~torchrl.modules.llm.policies.ChatHistory`: Contains :class:`~torchrl.data.llm.History` objects with `prompt`, `response`, and `full` fields
+- :class:`~torchrl.modules.llm.policies.Tokens`: Contains tokenized data with `prompt`, `response`, and `full` fields
+- :class:`~torchrl.modules.llm.policies.LogProbs`: Contains log probabilities with `prompt`, `response`, and `full` fields
+- :class:`~torchrl.modules.llm.policies.Masks`: Contains attention and assistant masks
+
+**API Flow**
+
+The wrappers operate in two distinct modes:
+
+**Generation Mode** (``generate=True``):
+
+- **Input**: Reads from `prompt` fields (e.g., `history.prompt`, `text.prompt`, `tokens.prompt`)
+- **Output**: Writes to both `response` and `full` fields
+
+  - `response`: Contains only the newly generated content
+  - `full`: Contains the complete sequence (prompt + response)
+
+**Log-Probability Mode** (``generate=False``):
+
+- **Input**: Reads from `full` fields (e.g., `history.full`, `text.full`, `tokens.full`)
+- **Output**: Writes log probabilities to the corresponding `full` fields
+
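+A minimal sketch of the two modes (``model``, ``tokenizer`` and ``data`` are assumed to exist, with ``data`` a TensorDict holding a ``ChatHistory``
+under the ``"history"`` key; key names follow the conventions described above):
+
+.. code-block:: python
+
+    from torchrl.modules.llm import TransformersWrapper
+
+    # Generation mode: reads ("history", "prompt"), writes ("history", "response") and ("history", "full")
+    gen_wrapper = TransformersWrapper(model=model, tokenizer=tokenizer, input_mode="history", generate=True)
+    out = gen_wrapper(data)
+
+    # Log-probability mode: reads ("history", "full"), writes log-probabilities under ("log_probs", "full")
+    lp_wrapper = TransformersWrapper(model=model, tokenizer=tokenizer, input_mode="history", generate=False)
+    out = lp_wrapper(out)
+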
+**LLM-Environment Interaction Loop**
+
+.. figure:: /_static/img/llm-env.png
+ :alt: LLM-Environment interaction loop
+ :align: center
+ :width: 80%
+
+ LLM-Environment interaction: the LLM generates a response, the environment updates the conversation, and transforms can inject new messages or tools.
+
+In a typical RL or tool-augmented setting, the LLM and environment interact in a loop:
+
+1. **LLM Generation**: The LLM wrapper receives a `prompt` (the current conversation history), generates a `response`, and outputs a `full` field
+ containing the concatenation of the prompt and response.
+2. **Environment Step**: The environment takes the `full` field and makes it the next `prompt` for the LLM. This ensures that the conversation
+ context grows with each turn. See :ref:`ref_env_llm_step` for more details.
+3. **Transforms**: Before the next LLM step, transforms can modify the conversation—for example, by inserting a new user message, a tool call,
+ or a reward annotation.
+4. **Repeat**: This process repeats for as many turns as needed, enabling multi-turn dialogue, tool use, and RL training.
+
+This design allows for flexible augmentation of the conversation at each step, supporting advanced RL and tool-use scenarios.
+
+A typical pseudocode loop:
+
+.. code-block:: python
+
+    # Get the first prompt out of an initial query
+    obs = env.reset(TensorDict({"query": "Hello!"}, batch_size=env.batch_size, device=env.device))
+    while not done:
+        # LLM generates a response given the current prompt
+        llm_output = llm(obs)
+        # Environment steps: creates a ("next", "history") field with the new prompt (from the previous `"full"` field)
+        obs = env.step(llm_output)
+
+**Integration with History**
+
+When using `input_mode="history"`, the wrapper integrates seamlessly with the :class:`~torchrl.data.llm.History` class:
+
+- **Input**: Takes a :class:`~torchrl.modules.llm.policies.ChatHistory` object containing a History in the `prompt` field
+- **Generation**: Applies chat templates to convert History to tokens, generates response, then parses the full text back into a History object
+- **Output**: Returns a ChatHistory with:
+ - `prompt`: Original conversation history
+ - `response`: New History object containing only the assistant's response
+ - `full`: Complete conversation history with the new response appended
+
+This design allows for natural conversation flow where each generation step extends the conversation history, making it ideal for multi-turn
+dialogue systems.
+
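+For illustration, a minimal sketch of a single generation step in history mode (constructor details are illustrative; a ``wrapper`` built with
+``input_mode="history"``, as above, is assumed to exist):
+
+.. code-block:: python
+
+    from tensordict import TensorDict
+    from torchrl.data.llm import History
+    from torchrl.modules.llm.policies import ChatHistory
+
+    # Build the prompt as a History and wrap it in a ChatHistory
+    prompt = History.from_chats([[
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "Hello!"},
+    ]])
+    data = TensorDict({"history": ChatHistory(prompt=prompt, batch_size=(1,))}, batch_size=(1,))
+
+    # One generation step extends the conversation
+    out = wrapper(data)
+    out["history"].response  # History containing only the new assistant message(s)
+    out["history"].full      # prompt + response, ready to become the next prompt
+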
+
+Prompt vs. Response and padding
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. figure:: /_static/img/llm-data.svg
+ :alt: LLM output data format (Tokens, Masks, Padded vs. Sparse)
+ :align: center
+ :width: 80%
+
+ Structure of LLM outputs: padded vs. sparse representations for Tokens, LogProbs, and Masks.
+
+The diagram above illustrates the structure of the main output classes used in TorchRL's LLM API:
+
+- **Tokens** (and by extension, **LogProbs**):
+
+  - *Padded* format: All sequences in a batch are padded to the same length (with a special pad token), making them suitable for tensor operations. The prompt and response are concatenated to form `tokens.full`, and masks indicate valid vs. padded positions.
+  - *Sparse* format: Each sequence retains its original length (no padding), represented as lists of tensors. This is more memory-efficient for variable-length data.
+
+- **Masks**: Two main masks are shown:
+
+  - `mask.attention_mask_all` marks valid (non-pad) tokens.
+  - `mask.assistant_mask_all` marks which tokens were generated by the assistant (useful for RLHF and SFT training).
+
+- **Text**: Not shown in detail, as it is simply the decoded string representation of the prompt, response, or full sequence.
+
+This format ensures that all LLM outputs (Tokens, LogProbs, Masks, Text) are consistent and easy to manipulate, regardless of whether you use padded or sparse batching.
+
+In general, we recommend working with unpadded data, as it is more memory-efficient and easier to manipulate.
+For instance, when sampling several padded elements from the buffer, it can be unclear how to re-pad them into a cohesive
+batch, whereas unpadded elements can be combined without this ambiguity.
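+
+The toy example below contrasts the two layouts using plain PyTorch; the library's ``as_padded_tensor`` and ``as_nested_tensor`` helpers,
+listed with the transforms later on this page, serve a similar purpose.
+
+.. code-block:: python
+
+    import torch
+    from torch.nn.utils.rnn import pad_sequence
+
+    # Two responses of different lengths
+    responses = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
+
+    # Sparse / unpadded layout: keep variable-length sequences as-is (list or nested tensor)
+    sparse = torch.nested.nested_tensor(responses)
+
+    # Padded layout: right-pad to a common length and keep a mask of valid positions
+    padded = pad_sequence(responses, batch_first=True, padding_value=0)  # shape [2, 3]
+    attention_mask = pad_sequence(
+        [torch.ones_like(r, dtype=torch.bool) for r in responses], batch_first=True
+    )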
+
+Modules
+-------
+
+The LLM wrapper API provides unified interfaces for different LLM backends, ensuring consistent input/output formats across training and inference pipelines.
+
+Wrappers
+~~~~~~~~
+
+The main goal of these primitives is to:
+
+- Unify the input/output data format across training and inference pipelines
+- Unify the input/output data format across backends (to be able to use different backends across losses and collectors)
+- Provide appropriate tooling to construct these objects in typical RL settings (resource allocation, async execution, weight update, etc.)
+
+.. currentmodule:: torchrl.modules.llm
+
+.. autosummary::
+ :toctree: generated/
+ :template: rl_template.rst
+
+ LLMWrapperBase
+ TransformersWrapper
+ vLLMWrapper
+ ChatHistory
+ Text
+ LogProbs
+ Masks
+ Tokens
+
+Utils
+^^^^^
+
+.. currentmodule:: torchrl.modules.llm
+
+.. autosummary::
+ :toctree: generated/
+ :template: rl_template.rst
+
+ LLMOnDevice
+ make_vllm_worker
+ stateless_init_process_group
+ vLLMWorker
Collectors
----------
-TorchRL offers specialized collector classes (:class:`~torchrl.collectors.llm.LLMCollector` and :class:`~torchrl.collectors.llm.RayLLMCollector`) that are tailored for LLM
-use cases. We also provide dedicated updaters for some inference engines.
+.. _Collectors:
+
+TorchRL offers specialized collector classes (:class:`~torchrl.collectors.llm.LLMCollector` and :class:`~torchrl.collectors.llm.RayLLMCollector`)
+that are tailored for LLM use cases. We also provide dedicated updaters for some inference engines.
+
+See :ref:`ref_collectors` for more details on the collector API. In brief, the idea of a collector is to isolate the inference part of the pipeline
+in a dedicated class.
+A collector usually takes a policy and an environment as input, and alternates between running one and the other.
+In "classical" settings, the policy is similar to the policy being trained (with some optional extra exploration). In the context of LLM fine-tuning,
+the policy will usually be a specialized inference engine, such as a vLLM server.
+Collectors are defined by the following parameters and features:
+
+- **Sync/Async**: Whether the collector should run in sync or async mode.
+  In sync mode, the collector alternates between the inference step and the optimization/training step.
+  In async mode, the collector runs the inference step in parallel with the optimization/training step.
+  A replay buffer can be passed to the collector so that the collector writes to it directly; otherwise, the collector
+  can be iterated over to collect data (see the sketch below).
+- **Steps**: A collector is built with a total step budget, as well as the number of steps to include in each batch
+  yielded during collection.
+- **Weight Updater**: Weight updaters are the classes that update the policy weights. Isolating the weight update
+  in a dedicated class makes it easy to implement different weight update strategies depending on the policy specification.
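+
+A minimal construction might look as follows (a sketch: ``env`` and ``policy`` are assumed to be an LLM environment and an LLM wrapper defined
+elsewhere, and parameter names such as ``dialog_turns_per_batch`` follow the GRPO example configs and may differ in your version):
+
+.. code-block:: python
+
+    from torchrl.collectors.llm import LLMCollector
+    from torchrl.data import LazyStackStorage, ReplayBuffer
+
+    rb = ReplayBuffer(storage=LazyStackStorage(128))
+    collector = LLMCollector(
+        env,
+        policy=policy,
+        dialog_turns_per_batch=32,   # steps per yielded batch (assumed parameter name)
+        total_dialog_turns=1_000,    # total step budget (assumed parameter name)
+        replay_buffer=rb,            # optional: write collected data directly to the buffer
+        track_policy_version=True,   # optional: track the policy version via a PolicyVersion transform
+    )
+    for _ in collector:              # with a buffer attached, iteration only triggers collection
+        batch = rb.sample(32)
+        # ... optimization step goes here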
+
+Policy Version Tracking
+~~~~~~~~~~~~~~~~~~~~~~~
-LLM Collectors allow to track the version of the policy, which is useful for some use cases.
+LLM Collectors also allow tracking the version of the policy, which is useful for some use cases.
This is done by adding a :class:`~torchrl.envs.llm.transforms.PolicyVersion` transform to the environment, which is
then incremented by the collector after each weight update. To do this, one either provides the stateful version of the
transform, or a boolean to the collector constructor.
@@ -43,64 +515,70 @@ transform, or a boolean to the collector constructor.
LLMCollector
RayLLMCollector
+Environments
+------------
-Data structures
----------------
+The environment layer orchestrates data loading, tool execution, reward computation, and formatting. When fine-tuning an LLM using TorchRL, the environment is a
+crucial component of the inference pipeline, alongside the policy and collector.
-To handle text-based data structures (such as conversations etc.), we offer a few data structures dedicated to carrying
-data for LLM post-training.
+ChatEnv
+~~~~~~~
-.. currentmodule:: torchrl.data.llm
+:class:`~torchrl.envs.llm.ChatEnv` serves as a blank canvas for LLM environments - it's a basic tool designed to be extended with transforms that add
+specific functionality. The base ChatEnv provides the fundamental structure for managing conversation state using the
+:class:`~torchrl.data.llm.History` format, but it's intentionally minimal to allow maximum flexibility.
-.. autosummary::
- :toctree: generated/
- :template: rl_template.rst
+Core Functionality
+^^^^^^^^^^^^^^^^^^
- History
- ContentBase
- LLMData
+ChatEnv operates in three main modes:
+- **History mode**: Uses :class:`~torchrl.data.llm.History` objects for conversation management
+- **Text mode**: Uses simple text strings for input/output
+- **Tokens mode**: Uses tokenized data for input/output
-Environments
-------------
+The environment maintains conversation state by:
+- **Reset**: Initializes a new conversation with an optional system prompt
+- **Step**: Takes the LLM's response and updates the conversation history, preparing the next prompt
-When fine-tuning an LLM using TorchRL, the environment is a crucial component of the inference pipeline, alongside the
-policy and collector. Environments manage operations that are not handled by the LLM itself, such as interacting with
-tools, loading prompts from datasets, computing rewards (when necessary), and formatting data.
+Transform-Based Architecture
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-Therefore, the fundamental structure of an LLM post-training pipeline is:
+Transforms are the main way to extend ChatEnv with specific capabilities:
-- A policy that wraps the LLM and the LLM only
-- An environment that handles the world around the LLM:
- - Loading data (through :class:`~torchrl.envs.llm.transforms.DataLoadingPrimer`)
- - Formatting data (through :class:`~torchrl.envs.llm.transforms.TemplateTransform`)
- - Executing tools (through :class:`~torchrl.envs.llm.transforms.PythonInterpreter` or :class:`~torchrl.envs.llm.transforms.MCPToolTransform`)
- - Computing rewards online, if needed (through :class:`~torchrl.envs.llm.transforms.KLRewardTransform`)
-- A data collector that takes the policy (the LLM) and the environment, and handles the inference part of the pipeline:
- - Running reset, step and gathering actions;
- - Yielding the data in a consistent format - or populating a buffer;
- - Updating the policy weights (through :class:`~torchrl.collectors.WeightUpdaterBase` classes)
-- A replay buffer that stores the data collected using the collector
-- A loss that takes the LLM's output and returns a loss (through :class:`~torchrl.objectives.llm.GRPOLoss` for example)
+- **Reward computation**: :class:`~torchrl.envs.llm.transforms.KLRewardTransform` for KL divergence rewards
+- **Tool execution**: :class:`~torchrl.envs.llm.transforms.PythonInterpreter` for Python code
+ execution, :class:`~torchrl.envs.llm.transforms.MCPToolTransform` for general tool calling.
+- **Data loading**: :class:`~torchrl.envs.llm.transforms.DataLoadingPrimer` for loading prompts from datasets
+- **Thinking prompts**: :class:`~torchrl.envs.llm.transforms.AddThinkingPrompt` for chain-of-thought reasoning
+- **Policy tracking**: :class:`~torchrl.envs.llm.transforms.PolicyVersion` for version control
+- **Step counting**: Built-in step tracking and reset management using :class:`~torchrl.envs.transforms.StepCounter`.
-These elements are presented in the GRPO scripts in the `sota-implementations/llm` directory.
+Integration with LLM Wrappers
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-The design of environments in TorchRL allows for flexibility and modularity. By framing tasks as environments, users can
-easily extend or modify existing environments using transforms. This approach enables the isolation of individual
-components within specific :class:`~torchrl.envs.EnvBase` or :class:`~torchrl.envs.Transform` subclasses, making it
-simpler to augment or alter the environment logic.
+.. _ref_env_llm_step:
-Available Environment Classes and Utilities
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+ChatEnv is designed to work seamlessly with both :class:`~torchrl.modules.llm.TransformersWrapper` and :class:`~torchrl.modules.llm.vLLMWrapper`.
+The environment handles the conversation state management while the wrapper handles the actual LLM inference, creating a clean separation of concerns.
-TorchRL provides various environment classes and utilities for working with LLMs, including:
+On each call to `step`, the environment:
+
+- Takes the LLM's output, specifically the `full` field, which contains the entire conversation so far, including the new response (e.g., `history.full`, `text.full`, `tokens.full`).
+- Sets this `full` field as the new `prompt` for the next LLM step (e.g., `td["next", "history"].prompt`, `td["next", "text"].prompt`, `td["next", "tokens"].prompt`).
+- Optionally, applies transforms to insert new user messages, tool calls, or other modifications to the conversation before the next LLM step to refine the prompt.
+
+This mechanism enables seamless multi-turn interactions and supports complex workflows such as tool use and reward shaping.
+
+Task-Specific Environments
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+We provide a few task-specific environments, such as :class:`~torchrl.envs.llm.GSM8KEnv` for the GSM8K dataset,
+:class:`~torchrl.envs.llm.IFEvalEnv` for the IFEval dataset, and :class:`~torchrl.envs.llm.MLGymEnv` for MLGym integration.
+
+These environments wrap a :class:`~torchrl.envs.llm.ChatEnv` and add a :class:`~torchrl.envs.llm.transforms.DataLoadingPrimer` transform
+(plus an optional reward parsing transform) in a :class:`~torchrl.envs.TransformedEnv` class.
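+
+For instance, a GSM8K environment with a KL reward term can be assembled as follows (a sketch mirroring the ``make_env`` helper of the GRPO scripts;
+``tokenizer`` and ``ref_model`` are assumed to be a Hugging Face tokenizer and a wrapped reference policy):
+
+.. code-block:: python
+
+    from torchrl.envs.llm import GSM8KEnv
+    from torchrl.envs.llm.transforms import KLRewardTransform
+
+    env = GSM8KEnv(repeats=16, tokenizer=tokenizer, num_envs=4)
+    # Optionally penalize divergence from a reference model through the reward
+    env = env.append_transform(
+        KLRewardTransform(actor=ref_model, coef=0.1, add_to_reward=True)
+    )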
-- Various environment classes (:class:`~torchrl.envs.llm.ChatEnv`, :class:`~torchrl.envs.llm.DatasetChatEnv`,
- :class:`~torchrl.envs.llm.GSM8KEnv`, etc.)
-- Utility functions (:class:`~torchrl.envs.make_gsm8k_env`, :class:`~torchrl.envs.make_mlgym`, etc.)
-- Transforms and other supporting classes (:class:`~torchrl.envs.KLRewardTransform`,
- :class:`~torchrl.envs.TemplateTransform`, :class:`~torchrl.envs.Tokenizer`, etc.)
-These components can be used to create customized environments tailored to specific use cases and requirements.
.. currentmodule:: torchrl.envs.llm
@@ -192,73 +670,173 @@ Similarly, environments that load data from a dataset are just special instances
augmented with a :class:`~torchrl.envs.llm.transforms.DataLoadingPrimer` transforms (and some dedicated reward parsing
transforms).
-.. currentmodule:: torchrl.envs.llm.transforms
+Designing Reward Transforms
+^^^^^^^^^^^^^^^^^^^^^^^^^^^
-.. autosummary::
- :toctree: generated/
- :template: rl_template.rst
+When designing reward transforms for LLM environments, several key considerations must be
+addressed to ensure proper integration with the training pipeline.
+The examples of :class:`~torchrl.envs.llm.GSM8KRewardParser` and
+:class:`~torchrl.envs.llm.IfEvalScorer` provide excellent templates for reward transform design.
- DataLoadingPrimer
- KLRewardTransform
- RetrieveLogProb
- MCPToolTransform
- BrowserTransform
- PythonInterpreter
- PolicyVersion
- TemplateTransform
- Tokenizer
- as_nested_tensor
- as_padded_tensor
+**Reward Shape Requirements**
-Modules
--------
+The reward tensor must have the same number of dimensions as the logits, which is typically
+two more dimensions than the environment batch size:
-The :ref:`~torchrl.modules.llm` section provides a set of wrappers and utility functions for popular training and
-inference backends. The main goal of these primitives is to:
+- **Sparse rewards**: Shape ``(*bsz, 1, 1)`` - single reward per sequence
+- **Dense rewards**: Shape ``(*bsz, num_tokens, 1)`` - per-token rewards
-- Unify the input / output data format across training and inference pipelines;
-- Unify the input / output data format across backends (to be able to use different backends across losses and
- collectors, for instance)
-- Give appropriate tooling to construct these objects in typical RL settings (resource allocation, async execution,
- weight update, etc.)
+This shape requirement ensures compatibility with the loss computation pipeline.
+For example, in the GSM8K reward parser:
-Wrappers
-~~~~~~~~
+.. code-block:: python
-.. currentmodule:: torchrl.modules.llm
+ # Rewards need to have shape broadcastable to [batch x tokens x 1]
+ tds = tds.apply(lambda t: t.unsqueeze(-1).unsqueeze(-1))
-.. autosummary::
- :toctree: generated/
- :template: rl_template.rst
+**Done State Management**
- TransformersWrapper
- vLLMWrapper
+It is crucial to properly manage the done state to prevent endless generation. Common strategies include:
-Utils
-~~~~~
+1. **Completion-based termination**: Set done when the response is complete (e.g., ``History.complete=True``)
+2. **Content-based termination**: Set done when specific content is detected (e.g., ``<answer>`` blocks)
+3. **Step-based termination**: Use :class:`~torchrl.envs.transforms.StepCounter` for predetermined step limits
-.. currentmodule:: torchrl.modules.llm
+Example from IFEvalScorer:
+
+.. code-block:: python
+
+    if self.set_done_if_answer and bool(answer_blocks):
+        next_tensordict.set("done", torch.ones(...))
+        next_tensordict.set("terminated", torch.ones(...))
+
+**Input Mode Handling**
+
+Reward transforms must handle different input modes correctly:
+
+- **History mode**: Extract text from ``("history", "full")`` or ``("history", "response")``
+- **Text mode**: Use text directly from ``("text", "full")`` or ``("text", "response")``
+- **Tokens mode**: Decode tokens from ``("tokens", "full")`` or ``("tokens", "response")``
+
+The GSM8K reward parser demonstrates this pattern:
+
+.. code-block:: python
+
+    if input_mode == "history":
+        responses = lazy_stack([r[..., -1] for r in responses.unbind(0)])
+        if hasattr(responses, "content"):
+            text_completion = responses.content
+    elif input_mode == "text":
+        text_completion = responses
+    elif input_mode == "tokens":
+        text_completion = self.tokenizer.decode(responses.flatten(0, 1).tolist())
+
+**Specification Management**
+
+Accurate specification of reward and observation specs is essential for proper environment initialization. Both GSM8K and IFEval provide good examples:
+
+.. code-block:: python
+
+    def transform_reward_spec(self, reward_spec: Composite) -> Composite:
+        shape = reward_spec.shape + (1, 1)
+        reward_spec.update(
+            Composite(
+                reward_answer=Unbounded(shape),
+                reward_think=Unbounded(shape),
+                reward_right=Unbounded(shape),
+                reward_contained=Unbounded(shape),
+                reward=Unbounded(shape),
+                success=Unbounded(shape, dtype=torch.bool),
+            )
+        )
+        return reward_spec
+
+**Batch Processing Considerations**
+
+For efficient processing, handle batched data appropriately:
+
+1. **Flatten batch dimensions**: Use ``tensordict.view(-1)`` for processing
+2. **Reshape results**: Restore original batch structure after processing
+3. **Handle variable-length sequences**: Use proper padding and masking
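+
+A typical pattern inside a reward transform might look like this (a sketch; ``next_tensordict`` is the data seen by the transform and
+``compute_reward`` is a hypothetical per-sample scoring function):
+
+.. code-block:: python
+
+    import torch
+
+    # Flatten batch dimensions for per-sample processing
+    flat = next_tensordict.view(-1)
+    rewards = [compute_reward(sample) for sample in flat.unbind(0)]
+
+    # Restore the original batch structure and broadcast to (*batch, 1, 1)
+    reward = torch.stack(rewards).view(next_tensordict.batch_size).unsqueeze(-1).unsqueeze(-1)
+    next_tensordict.set("reward", reward)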
+
+**Reward Aggregation Strategies**
+
+Consider different reward aggregation approaches:
+
+1. **Simple aggregation**: Sum or average multiple reward components
+2. **Weighted aggregation**: Apply different weights to different components
+3. **Conditional rewards**: Base rewards on specific conditions or thresholds
+
+The IFEvalScorer demonstrates a sophisticated aggregation strategy:
+
+.. code-block:: python
+
+    def default_reward_aggregator(self, score: IFEvalScoreData, ...):
+        # Format score (max 1.0)
+        format_score = (format_components * weights).sum(dim=-1, keepdim=True)
+
+        # Structure score (max 1.0)
+        structure_score = think_score + answer_score
+
+        # Completion bonus (max 0.2)
+        completion_bonus = float(complete) * 0.2
+
+        return format_score + structure_score + completion_bonus
+
+**Post-Processing in Replay Buffers**
+
+Rewards can also be computed after the fact by appending transforms to the replay buffer. However, done state capture must remain in the environment transform since it needs to occur on-the-fly during data collection.
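+
+For example, a scoring transform can be appended to the buffer so that rewards are computed when items are written or sampled (a sketch;
+``MyRewardParser`` is a hypothetical transform):
+
+.. code-block:: python
+
+    from torchrl.data import LazyStackStorage, ReplayBuffer
+
+    rb = ReplayBuffer(storage=LazyStackStorage(128))
+    rb.append_transform(MyRewardParser(tokenizer=tokenizer))  # hypothetical reward-parsing transform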
+
+**Error Handling and Robustness**
+
+Implement robust error handling for parsing failures:
+
+.. code-block:: python
+
+    try:
+        cot, potential_answer = self.extract_tags(compl)
+    except ET.ParseError:
+        cot, potential_answer = ("", "")
+
+**Performance Considerations**
+
+1. **Avoid redundant computations**: Cache parsed results when possible
+2. **Use efficient text processing**: Leverage regex or XML parsing as appropriate
+3. **Minimize memory allocations**: Reuse tensors and avoid unnecessary copies
+
+By following these design principles, reward transforms can be effectively integrated into the LLM training pipeline while maintaining performance and reliability.
+
+.. currentmodule:: torchrl.envs.llm.transforms
.. autosummary::
:toctree: generated/
:template: rl_template.rst
- CategoricalSequential
- LLMOnDevice
- make_vllm_worker
- stateless_init_process_group
- vLLMWorker
+ AddThinkingPrompt
+ BrowserTransform
+ DataLoadingPrimer
+ KLComputation
+ KLRewardTransform
+ MCPToolTransform
+ PolicyVersion
+ PythonInterpreter
+ RetrieveKL
+ RetrieveLogProb
+ TemplateTransform
+ Tokenizer
+ as_nested_tensor
+ as_padded_tensor
Objectives
----------
-LLM post training require some appropriate versions of the losses implemented in TorchRL.
+LLM post-training requires specialized loss functions that are adapted to the unique characteristics of language models.
GRPO
~~~~
The :class:`~torchrl.objectives.llm.GRPOLoss` class is a thin wrapper around the :class:`~torchrl.objectives.PPOLoss` class
-that codes the LLM-specific functionnalities.
+that implements the LLM-specific functionalities.
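+
+A typical construction, following the GRPO training scripts (a sketch; ``train_model`` is the wrapped policy being trained and the coefficient
+values are illustrative):
+
+.. code-block:: python
+
+    from torchrl.objectives.llm import GRPOLoss
+
+    loss_fn = GRPOLoss(
+        train_model,
+        kl_to_ref_coeff=0.1,          # KL penalty towards the reference model (when included in the loss)
+        kl_to_inference_coeff=0.1,    # KL penalty towards the inference policy
+        entropy_coeff=0.01,
+        masking_strategy="sft",       # prompt/response masking; "rlhf" uses assistant-token masking for multi-turn reasoning
+    )
+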
.. currentmodule:: torchrl.objectives.llm
@@ -270,9 +848,8 @@ that codes the LLM-specific functionnalities.
GRPOLossOutput
MCAdvantage
-
SFT
-~~~
+^^^
.. currentmodule:: torchrl.objectives.llm
diff --git a/sota-implementations/expert-iteration/ei_utils.py b/sota-implementations/expert-iteration/ei_utils.py
index 179ec4d8aa2..ec061b5f318 100644
--- a/sota-implementations/expert-iteration/ei_utils.py
+++ b/sota-implementations/expert-iteration/ei_utils.py
@@ -104,7 +104,7 @@ def get_train_model(
param.data = param.data.to(model_dtype)
if chat_template_name is not None:
- from torchrl.data.llm.chat import _CHAT_TEMPLATES
+ from torchrl.data.llm.history import _CHAT_TEMPLATES
chat_template = _CHAT_TEMPLATES[chat_template_name]
train_tokenizer.chat_template = chat_template
diff --git a/sota-implementations/expert-iteration/expert-iteration-async.py b/sota-implementations/expert-iteration/expert-iteration-async.py
index e8506a85c99..5cb62b319ba 100644
--- a/sota-implementations/expert-iteration/expert-iteration-async.py
+++ b/sota-implementations/expert-iteration/expert-iteration-async.py
@@ -15,7 +15,7 @@
from torchrl import torchrl_logger
from torchrl.collectors.llm.weight_update.vllm import vLLMUpdater
-from torchrl.data.llm.chat import History
+from torchrl.data.llm.history import History
from torchrl.record.loggers.wandb import WandbLogger
try:
diff --git a/sota-implementations/expert-iteration/expert-iteration-sync.py b/sota-implementations/expert-iteration/expert-iteration-sync.py
index 556074f3469..34565228754 100644
--- a/sota-implementations/expert-iteration/expert-iteration-sync.py
+++ b/sota-implementations/expert-iteration/expert-iteration-sync.py
@@ -15,7 +15,7 @@
from torchrl import torchrl_logger
from torchrl.collectors.llm.weight_update.vllm import vLLMUpdater
-from torchrl.data.llm.chat import History
+from torchrl.data.llm.history import History
from torchrl.record.loggers.wandb import WandbLogger
try:
diff --git a/sota-implementations/grpo/config/grpo_gsm8k.yaml b/sota-implementations/grpo/config/grpo_gsm8k.yaml
index fb8d10e148b..3178f077f04 100644
--- a/sota-implementations/grpo/config/grpo_gsm8k.yaml
+++ b/sota-implementations/grpo/config/grpo_gsm8k.yaml
@@ -13,6 +13,14 @@ env:
num_envs: 8 # Reduced from 8 to save memory
# Number of times to repeat the same prompt for GRPO. This does not affect the GPU memory usage.
repeats: 16
+ # Whether to use the reasoning prompt
+ reasoning: false
+ # Maximum number of dialog turns per episode.
+ max_steps: 2
+ # Whether to group repeated samples together. When true, all the answers to a single prompt are written to the buffer
+ # together, whereas false mixes answers to multiple prompts in the buffer at a given time.
+ # Batches are usually bigger with group_repeats=false.
+ group_repeats: false
# Base model configuration
model:
@@ -37,13 +45,13 @@ train:
# Number of gradient accumulation steps. Higher values will use less GPU memory (comparing with bigger batches and lower gradient_accumulation_steps),
# but will make the optimization step slower.
- gradient_accumulation_steps: 1
+ gradient_accumulation_steps: 4
# Fields used by both scripts but with different semantics
checkpoint_frequency: 100 # Save checkpoint every N steps/batches
# Batch size for optimization. Higher values will use more GPU memory.
- optim_batch_size: 1
+ optim_batch_size: 4
# Whether to include the KL coefficient in the loss function. Alternatively, the KL ref-to-train will be added to the reward.
kl_coef_in_loss: true
@@ -56,6 +64,10 @@ train:
# Fields used only by grpo-async.py / grpo-sync.py
logging_frequency: 10 # Log metrics every N steps
+ # Whether to empty the replay buffer at the end of training epochs (sync only). Guarantees that data
+ # is used only once.
+ empty_replay_buffer: true
+
# Training model configuration
train_model:
gradient_checkpointing: true # Enabled for memory efficiency
diff --git a/sota-implementations/grpo/config/grpo_ifeval.yaml b/sota-implementations/grpo/config/grpo_ifeval.yaml
index 5916dc45168..bb772373e06 100644
--- a/sota-implementations/grpo/config/grpo_ifeval.yaml
+++ b/sota-implementations/grpo/config/grpo_ifeval.yaml
@@ -13,6 +13,10 @@ env:
num_envs: 4
# Number of times to repeat the same prompt for GRPO. This does not affect the GPU memory usage.
repeats: 16
+ # Whether to use the reasoning prompt
+ reasoning: false
+ # Maximum number of dialog turns per episode.
+ max_steps: 2
# Base model configuration
model:
@@ -43,19 +47,23 @@ train:
checkpoint_frequency: 100 # Save checkpoint every N steps/batches
# Batch size for optimization. Higher values will use more GPU memory.
- optim_batch_size: 2
+ optim_batch_size: 4
# Whether to include the KL coefficient in the loss function. Alternatively, the KL ref-to-train will be added to the reward.
- kl_coef_in_loss: false
+ kl_coef_in_loss: false
# KL coefficients for the KL divergence to the reference and inference policies
- kl_to_ref_coeff: 1e-1
- kl_to_inference_coeff: 1e-1
- entropy_coeff: 0.01
+ kl_to_ref_coeff: 1.0
+ kl_to_inference_coeff: 0.0
+ entropy_coeff: 0.001
# Fields used only by grpo-async.py / grpo-sync.py
logging_frequency: 1 # Log metrics every N steps - here at each optimization step
+ # Whether to empty the replay buffer at the end of training epochs (sync only). Guarantees that data
+ # is used only once.
+ empty_replay_buffer: true
+
# Training model configuration
train_model:
gradient_checkpointing: true # Enabled for memory efficiency
diff --git a/sota-implementations/grpo/config/mode/async.yaml b/sota-implementations/grpo/config/mode/async.yaml
index c72a0592849..8dff97800ad 100644
--- a/sota-implementations/grpo/config/mode/async.yaml
+++ b/sota-implementations/grpo/config/mode/async.yaml
@@ -5,7 +5,7 @@ train:
# Number of epochs to train for, every time a batch is collected. Per se, not directly used in async - aside from computing the total number of steps.
epochs: 1
- # The buffer size can be controlled in async mode
+ # The buffer size is overwritten in async mode.
buffer_size: 128
# Update policy weights every N steps - can be set to any positive integer in async mode
weight_update_frequency: 10
diff --git a/sota-implementations/grpo/config/mode/sync.yaml b/sota-implementations/grpo/config/mode/sync.yaml
index 743e850fcc2..8773a176728 100644
--- a/sota-implementations/grpo/config/mode/sync.yaml
+++ b/sota-implementations/grpo/config/mode/sync.yaml
@@ -5,7 +5,7 @@ train:
# Number of epochs to train for, every time a batch is collected.
epochs: 1
- # Leave buffer_size empty to use dialog_turns_per_batch in sync mode
+ # Override the buffer size in sync mode. If not set, the buffer size will be the number of repeats * num_envs
buffer_size:
# Update policy weights every N steps - must be left empty in sync mode
weight_update_frequency:
diff --git a/sota-implementations/grpo/grpo-async.py b/sota-implementations/grpo/grpo-async.py
index 04bdb6a95bd..a76691838da 100644
--- a/sota-implementations/grpo/grpo-async.py
+++ b/sota-implementations/grpo/grpo-async.py
@@ -15,7 +15,7 @@
from torchrl import torchrl_logger
from torchrl.collectors.llm.weight_update.vllm import vLLMUpdater
-from torchrl.data.llm.chat import History
+from torchrl.data.llm.history import History
from torchrl.record.loggers.wandb import WandbLogger
try:
@@ -30,9 +30,9 @@
from grpo_utils import (
compute_device_allocation,
get_inference_model,
- get_ref_model,
- get_tokenizer,
get_train_model,
+ log_training_metrics,
+ make_env,
make_weight_updater,
)
from omegaconf import DictConfig
@@ -49,8 +49,6 @@
from torchrl.collectors.llm import RayLLMCollector
from torchrl.data import LazyStackStorage, ReplayBuffer
from torchrl.data.replay_buffers.ray_buffer import RayReplayBuffer
-from torchrl.envs.llm import GSM8KEnv, KLRewardTransform
-from torchrl.envs.llm.datasets.ifeval import IFEvalEnv
from torchrl.objectives.llm.grpo import GRPOLoss, MCAdvantage
@@ -72,51 +70,6 @@ def setup_environment() -> None:
torch.cuda.set_device("cuda:0")
-def make_env(cfg: DictConfig, devices: list[int] | None = None):
- """Create the environment with proper device allocation.
-
- Args:
- cfg: The configuration object
-
- Returns:
- The configured environment
- """
- # Create reference model with proper device allocation
- # For the collector actor, we want inference_model devices first, then ref_model devices
- train_tokenizer = get_tokenizer(cfg)
-
- # Create a new config with adjusted device assignments
- ref_cfg = DictConfig(dict(cfg))
- ref_model = get_ref_model(ref_cfg, train_tokenizer, devices=devices)
-
- # Setup environment
- if cfg.env.dataset == "gsm8k":
- env = GSM8KEnv(
- repeats=cfg.env.repeats,
- tokenizer=train_tokenizer,
- num_envs=cfg.env.num_envs,
- )
- else: # ifeval
- env = IFEvalEnv(
- repeats=cfg.env.repeats,
- tokenizer=train_tokenizer,
- num_envs=cfg.env.num_envs,
- )
-
- # Pass device directly to KLRewardTransform - Since, for Ray, the local device is always 0
- # we can just use 0 here.
- device = torch.device("cuda:0")
- env = env.append_transform(
- KLRewardTransform(
- actor=ref_model,
- coef=cfg.train.kl_to_ref_coeff,
- add_to_reward=not cfg.train.kl_coef_in_loss,
- device=device,
- )
- )
- return env
-
-
def train(
replay_buffer: ReplayBuffer,
cfg: DictConfig,
@@ -145,8 +98,13 @@ def train(
kl_to_ref_coeff=cfg.train.kl_to_ref_coeff if cfg.train.kl_coef_in_loss else 0.0,
kl_to_inference_coeff=cfg.train.kl_to_inference_coeff,
entropy_coeff=cfg.train.entropy_coeff,
+ # use prompt/response masking for regular training, and assistant masking for reasoning
+ masking_strategy="rlhf" if cfg.env.reasoning else "sft",
device=train_device,
)
+ if cfg.env.reasoning:
+ # TODO: this is clunky, we should find a way to do this more naturally
+ loss_fn.set_keys(sample_log_prob=("next", "log_probs", "full"))
if cfg.model.compile:
loss_fn = torch.compile(loss_fn)
@@ -214,12 +172,14 @@ def train(
torchrl_logger.info(f"Total steps: {total_steps}")
pbar = tqdm.tqdm(total=total_steps)
- metrics = {} # Initialize metrics dict
grad_norm = 0.0 # Initialize grad_norm
data_read_count = 0
start_time = time.time()
for step in range(total_steps):
+ if not collector.is_running():
+ torchrl_logger.info("Collector stopped, stopping training")
+ break
pbar.update(1)
pbar.set_description(f"Step {step}, writes: {replay_buffer.write_count}")
@@ -228,7 +188,7 @@ def train(
batch = replay_buffer.sample(cfg.train.optim_batch_size).to(train_device)
# For logging purposes, we get the last element of the history
# and convert it to a string
- history: History = batch.view(-1)[0]["next", "history"]
+ history: History = batch.view(-1)[0]["history", "full"]
history_str: list[str] | str = history.apply_chat_template(
tokenizer=train_tokenizer
)
@@ -279,70 +239,19 @@ def train(
# Update metrics
if (step % cfg.train.logging_frequency) == 0:
- with torch.no_grad():
- rb_content = replay_buffer[:]
- batch_policy_version = batch["next", "policy_version"].view(-1).min()
- batch_policy_age = collector.policy_version - batch_policy_version
- metrics = {
- "reward from buffer": float(
- torch.cat(
- rb_content.get(("next", "reward"), as_list=True)
- ).mean()
- ),
- "kl_penalty (inference to ref) from buffer": float(
- torch.cat(
- rb_content.get(("next", "kl_penalty"), as_list=True)
- ).mean()
- ),
- "seq_length from buffer": float(
- torch.tensor(
- [
- t.numel()
- for t in rb_content.get("tokens_response", as_list=True)
- ],
- dtype=torch.float,
- ).mean()
- ),
- "ESS, from loss": float(loss.ESS),
- "loss_objective, from loss": float(loss.loss_objective),
- "clip_fraction, from loss": float(loss.clip_fraction),
- "kl_approx (train to inference), from loss": float(loss.kl_approx),
- "kl_to_inference (train to inference - differentiable), from loss": float(
- loss.kl_to_inference.mean()
- ),
- "kl_to_ref, from loss": float(loss.kl_to_ref.mean()),
- "loss_kl_to_inference, from loss": float(
- loss.loss_kl_to_inference.mean()
- ),
- "loss_kl_to_ref, from loss": float(loss.loss_kl_to_ref.mean()),
- "entropy loss, from loss": float(loss.loss_entropy.mean()),
- "grad_norm": float(grad_norm)
- if step % cfg.train.gradient_accumulation_steps == 0
- else metrics.get("grad_norm", 0.0),
- "write_count, from buffer": int(replay_buffer.write_count),
- # how many gradient steps per write
- "gradient_step_throughput (gradient step per write)": float(
- step / replay_buffer.write_count
- ),
- # how many optim steps per write
- "optim_step_throughput (optim step per write)": float(
- (step // cfg.train.gradient_accumulation_steps)
- / replay_buffer.write_count
- ),
- "data_read_count (total)": data_read_count,
- "current_policy_version (collector)": collector.policy_version,
- # FIXME: Assume batch is a single trajectory
- # FIXME: The addition of the transform after the env instantiation + _shuttle creation
- # is messed up - we need the next data
- "batch_policy_version (sampled batch)": batch_policy_version,
- "batch_policy_age (sampled batch)": batch_policy_age,
- "throughput (steps per second)": float(
- step / (time.time() - start_time)
- ),
- }
- for name, value in metrics.items():
- wandb_logger.log_scalar(name, value)
- wandb_logger.log_str("history", history_str, step=step)
+ log_training_metrics(
+ wandb_logger=wandb_logger,
+ replay_buffer=replay_buffer,
+ batch=batch,
+ loss=loss,
+ grad_norm=grad_norm,
+ global_step=step,
+ data_read_count=data_read_count,
+ collector=collector,
+ start_time=start_time,
+ gradient_accumulation_steps=cfg.train.gradient_accumulation_steps,
+ history_str=history_str,
+ )
# Update policy weights
if step % cfg.train.weight_update_frequency == 0:
@@ -437,20 +346,25 @@ def main(cfg):
train_handler_config = dict(cfg.ray.train_handler_config)
inference_policy = get_inference_model(
- cfg, devices=device_config["inference_model_devices"]
+ cfg,
+ devices=device_config["inference_model_devices"],
)
torchrl_logger.info(f"Inference policy: {inference_policy}")
torchrl_logger.info(f"Starting replay buffer with {replay_buffer_config=}")
+ if cfg.train.optim_batch_size % cfg.train.gradient_accumulation_steps != 0:
+ raise ValueError(
+ "optim_batch_size must be divisible by gradient_accumulation_steps"
+ )
rb = RayReplayBuffer(
storage=partial(
LazyStackStorage,
cfg.train.buffer_size
if cfg.train.buffer_size
- else cfg.train.dialog_turns_per_batch,
+ else cfg.env.repeats * cfg.env.num_envs,
),
transform_factory=partial(MCAdvantage, grpo_size=cfg.env.repeats),
- batch_size=cfg.train.optim_batch_size,
+ batch_size=cfg.train.optim_batch_size // cfg.train.gradient_accumulation_steps,
remote_config=replay_buffer_config,
)
torchrl_logger.info(f"Replay buffer: {rb}")
@@ -474,6 +388,8 @@ def main(cfg):
weight_updater=None, # We'll create this after getting the remote LLM
track_policy_version=True,
remote_config=collector_config,
+ yield_only_last_steps=cfg.env.reasoning,
+ verbose=False,
)
# Ensure collector is initialized by calling a method that will block until ready
ray.get(collector._collector.is_initialized.remote())
diff --git a/sota-implementations/grpo/grpo-sync.py b/sota-implementations/grpo/grpo-sync.py
index 24ca3f3a367..2b5363f4c91 100644
--- a/sota-implementations/grpo/grpo-sync.py
+++ b/sota-implementations/grpo/grpo-sync.py
@@ -14,7 +14,7 @@
from torchrl import torchrl_logger
from torchrl.collectors.llm.weight_update.vllm import vLLMUpdater
-from torchrl.data.llm.chat import History
+from torchrl.data.llm.history import History
from torchrl.record.loggers.wandb import WandbLogger
try:
@@ -31,9 +31,9 @@
from grpo_utils import (
compute_device_allocation,
get_inference_model,
- get_ref_model,
- get_tokenizer,
get_train_model,
+ log_training_metrics,
+ make_env,
make_weight_updater,
)
from omegaconf import DictConfig
@@ -50,8 +50,6 @@
from torchrl.collectors.llm import RayLLMCollector
from torchrl.data import LazyStackStorage, ReplayBuffer, SamplerWithoutReplacement
from torchrl.data.replay_buffers.ray_buffer import RayReplayBuffer
-from torchrl.envs.llm import GSM8KEnv, KLRewardTransform
-from torchrl.envs.llm.datasets.ifeval import IFEvalEnv
from torchrl.objectives.llm.grpo import GRPOLoss, MCAdvantage
@@ -73,51 +71,6 @@ def setup_environment() -> None:
torch.cuda.set_device("cuda:0")
-def make_env(cfg: DictConfig, devices: list[int] | None = None):
- """Create the environment with proper device allocation.
-
- Args:
- cfg: The configuration object
-
- Returns:
- The configured environment
- """
- # Create reference model with proper device allocation
- # For the collector actor, we want inference_model devices first, then ref_model devices
- train_tokenizer = get_tokenizer(cfg)
-
- # Create a new config with adjusted device assignments
- ref_cfg = DictConfig(dict(cfg))
- ref_model = get_ref_model(ref_cfg, train_tokenizer, devices=devices)
-
- # Setup environment
- if cfg.env.dataset == "gsm8k":
- env = GSM8KEnv(
- repeats=cfg.env.repeats,
- tokenizer=train_tokenizer,
- num_envs=cfg.env.num_envs,
- )
- else: # ifeval
- env = IFEvalEnv(
- repeats=cfg.env.repeats,
- tokenizer=train_tokenizer,
- num_envs=cfg.env.num_envs,
- )
-
- # Pass device directly to KLRewardTransform - Since, for Ray, the local device is always 0
- # we can just use 0 here.
- device = torch.device("cuda:0")
- env = env.append_transform(
- KLRewardTransform(
- actor=ref_model,
- coef=cfg.train.kl_to_ref_coeff,
- add_to_reward=not cfg.train.kl_coef_in_loss,
- device=device,
- )
- )
- return env
-
-
def train(
replay_buffer: ReplayBuffer,
cfg: DictConfig,
@@ -146,8 +99,13 @@ def train(
kl_to_ref_coeff=cfg.train.kl_to_ref_coeff if cfg.train.kl_coef_in_loss else 0.0,
kl_to_inference_coeff=cfg.train.kl_to_inference_coeff,
entropy_coeff=cfg.train.entropy_coeff,
+ # use prompt/response masking for regular training, and assistant masking for reasoning
+ masking_strategy="rlhf" if cfg.env.reasoning else "sft",
device=train_device,
)
+ if cfg.env.reasoning:
+ # TODO: this is clunky, we should find a way to do this more naturally
+ loss_fn.set_keys(sample_log_prob=("next", "log_probs", "full"))
if cfg.model.compile:
loss_fn = torch.compile(loss_fn)
@@ -205,13 +163,20 @@ def train(
# Training loop
torchrl_logger.info("Starting training loop.")
pbar = tqdm.tqdm(collector)
- metrics = {} # Initialize metrics dict
grad_norm = 0.0 # Initialize grad_norm
data_read_count = 0
global_step = 0
start_time = time.time()
for data in pbar:
+        # Wait for the replay buffer to be filled: when reasoning, we collect full
+        # trajectories, so the buffer may not be filled straight away.
+ if not len(replay_buffer):
+ torchrl_logger.info("Waiting for replay buffer to be filled")
+ continue
+ else:
+ torchrl_logger.info(f"Replay buffer filled: {len(replay_buffer)}")
+
pbar.update(1)
# data is None as the collector directly writes to the replay buffer
@@ -228,7 +193,7 @@ def train(
)
# For logging purposes, we get the last element of the history
# and convert it to a string
- history: History = batch.view(-1)[0]["next", "history"]
+ history: History = batch.view(-1)[0]["next", "history"].prompt
history_str: list[str] | str = history.apply_chat_template(
tokenizer=train_tokenizer
)
@@ -289,80 +254,19 @@ def train(
# Update metrics
if (global_step % cfg.train.logging_frequency) == 0:
- with torch.no_grad():
- rb_content = replay_buffer[:]
- batch_policy_version = (
- batch["next", "policy_version"].view(-1).min()
- )
- batch_policy_age = (
- collector.policy_version - batch_policy_version
- )
- metrics = {
- "reward from buffer": float(
- torch.cat(
- rb_content.get(("next", "reward"), as_list=True)
- ).mean()
- ),
- "kl_penalty (inference to ref) from buffer": float(
- torch.cat(
- rb_content.get(("next", "kl_penalty"), as_list=True)
- ).mean()
- ),
- "seq_length from buffer": float(
- torch.tensor(
- [
- t.numel()
- for t in rb_content.get(
- "tokens_response", as_list=True
- )
- ],
- dtype=torch.float,
- ).mean()
- ),
- "ESS, from loss": float(loss.ESS),
- "loss_objective, from loss": float(loss.loss_objective),
- "clip_fraction, from loss": float(loss.clip_fraction),
- "kl_approx (train to inference), from loss": float(
- loss.kl_approx
- ),
- "kl_to_inference (train to inference - differentiable), from loss": float(
- loss.kl_to_inference.mean()
- ),
- "kl_to_ref, from loss": float(loss.kl_to_ref.mean()),
- "loss_kl_to_inference, from loss": float(
- loss.loss_kl_to_inference.mean()
- ),
- "loss_kl_to_ref, from loss": float(
- loss.loss_kl_to_ref.mean()
- ),
- "entropy loss, from loss": float(loss.loss_entropy.mean()),
- "grad_norm": float(grad_norm)
- if global_step % cfg.train.gradient_accumulation_steps == 0
- else metrics.get("grad_norm", 0.0),
- "write_count, from buffer": int(replay_buffer.write_count),
- # how many gradient steps per write
- "gradient_step_throughput (gradient step per write)": float(
- global_step / replay_buffer.write_count
- ),
- # how many optim steps per write
- "optim_step_throughput (optim step per write)": float(
- (global_step // cfg.train.gradient_accumulation_steps)
- / replay_buffer.write_count
- ),
- "data_read_count (total)": data_read_count,
- "current_policy_version (collector)": collector.policy_version,
- # FIXME: Assume batch is a single trajectory
- # FIXME: The addition of the transform after the env instantiation + _shuttle creation
- # is messed up - we need the next data
- "batch_policy_version (sampled batch)": batch_policy_version,
- "batch_policy_age (sampled batch)": batch_policy_age,
- "throughput (steps per second)": float(
- global_step / (time.time() - start_time)
- ),
- }
- for name, value in metrics.items():
- wandb_logger.log_scalar(name, value)
- wandb_logger.log_str("history", history_str, step=global_step)
+ log_training_metrics(
+ wandb_logger=wandb_logger,
+ replay_buffer=replay_buffer,
+ batch=batch,
+ loss=loss,
+ grad_norm=grad_norm,
+ global_step=global_step,
+ data_read_count=data_read_count,
+ collector=collector,
+ start_time=start_time,
+ gradient_accumulation_steps=cfg.train.gradient_accumulation_steps,
+ history_str=history_str,
+ )
# Checkpointing disabled to prevent disk space issues
# if (global_step + 1) % cfg.train.checkpoint_frequency == 0:
@@ -392,6 +296,9 @@ def train(
wandb_logger.log_scalar(f"timeit/{key}", val)
timeit.reset()
+ if cfg.train.empty_replay_buffer:
+ replay_buffer.empty(empty_write_count=False)
+
pbar.close()
collector.shutdown()
@@ -463,14 +370,23 @@ def main(cfg):
"buffer_size must be equal to dialog_turns_per_batch in sync settings."
)
+ if cfg.train.optim_batch_size % cfg.train.gradient_accumulation_steps != 0:
+ raise ValueError(
+ "optim_batch_size must be divisible by gradient_accumulation_steps"
+ )
+
rb = RayReplayBuffer(
storage=partial(
LazyStackStorage,
- cfg.train.dialog_turns_per_batch,
+ # Since we cache the values in the queue until we have "repeats" samples,
+            # the buffer can be bigger than dialog_turns_per_batch (at most repeats * num_envs)
+ cfg.train.buffer_size
+ if cfg.train.buffer_size
+ else cfg.env.repeats * cfg.env.num_envs,
),
sampler=SamplerWithoutReplacement,
- transform_factory=partial(MCAdvantage, grpo_size=cfg.env.repeats),
- batch_size=cfg.train.optim_batch_size,
+ transform_factory=partial(MCAdvantage, grpo_size=cfg.env.repeats, verbose=True),
+ batch_size=cfg.train.optim_batch_size // cfg.train.gradient_accumulation_steps,
remote_config=replay_buffer_config,
)
torchrl_logger.info(f"Replay buffer: {rb}")
@@ -494,7 +410,8 @@ def main(cfg):
track_policy_version=True,
remote_config=collector_config,
sync_iter=cfg.train.sync_iter,
- verbose=True,
+ verbose=False,
+ yield_only_last_steps=cfg.env.reasoning,
)
# Ensure collector is initialized by calling a method that will block until ready
ray.get(collector._collector.is_initialized.remote())
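The sync loop above now skips iterations until the buffer holds data, since trajectories arrive late when reasoning is enabled. A self-contained toy version of that pattern, with a plain list standing in for `RayReplayBuffer` and a generator standing in for the collector:

```python
# Toy stand-ins: a list plays the role of the replay buffer, a generator the collector.
buffer: list[dict] = []


def collector_iter():
    # The real collector yields None and writes to the buffer as a side effect;
    # here the buffer only gets filled on the second iteration.
    yield None
    buffer.append({"reward": 1.0})
    yield None


for _ in collector_iter():
    if not len(buffer):
        print("Waiting for replay buffer to be filled")
        continue
    print(f"Replay buffer filled: {len(buffer)}")
```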
diff --git a/sota-implementations/grpo/grpo_utils.py b/sota-implementations/grpo/grpo_utils.py
index 6a99dde7cf0..fe73d9ca8c1 100644
--- a/sota-implementations/grpo/grpo_utils.py
+++ b/sota-implementations/grpo/grpo_utils.py
@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
+import time
from typing import Any, Callable, Literal
import torch
@@ -12,6 +13,8 @@
from torchrl._utils import logger as torchrl_logger
from torchrl.collectors.llm.weight_update.vllm import vLLMUpdater
+from torchrl.envs.llm import AddThinkingPrompt, GSM8KEnv, KLRewardTransform, RetrieveKL
+from torchrl.envs.llm.datasets.ifeval import IFEvalEnv
from torchrl.modules.llm import TransformersWrapper, vLLMWrapper
from transformers.models.auto.modeling_auto import AutoModelForCausalLM
from transformers.tokenization_utils import PreTrainedTokenizer
@@ -93,9 +96,11 @@ def get_train_model(
policy_training = TransformersWrapper(
train_model,
tokenizer=train_tokenizer,
- from_text=False,
+ input_mode="tokens" if not cfg.env.reasoning else "history",
generate=False,
return_log_probs=True,
+ pad_output=False,
+ device=torch.device("cuda:0"),
)
# Ensure model stays in eval mode after wrapping
policy_training.model.eval()
@@ -104,7 +109,10 @@ def get_train_model(
def get_inference_model(
- cfg: DictConfig, devices: list[int] | None = None, make_ray_worker: bool = True
+ cfg: DictConfig,
+ devices: list[int] | None = None,
+ make_ray_worker: bool = True,
+ tokenizer: PreTrainedTokenizer | None = None,
) -> vLLMWrapper:
"""Creates the vLLM-based inference model for fast generation.
@@ -116,7 +124,9 @@ def get_inference_model(
cfg (DictConfig): The hydra configuration object containing model settings.
Expected to have inference_model section with vLLM-specific parameters
like gpu_memory_utilization and generation settings.
- make_ray_worker (bool, optional): Whether to make a ray worker. Default: True
+ devices (list[int], optional): The devices to use for the inference model. Default: `None`.
+ make_ray_worker (bool, optional): Whether to make a ray worker. Default: `True`.
+ tokenizer (PreTrainedTokenizer, optional): The tokenizer to use with the inference model. Default: `None`.
Returns:
vLLMWrapper: The wrapped vLLM model ready for inference.
@@ -149,10 +159,20 @@ def get_inference_model(
enforce_eager=cfg.inference_model.enforce_eager,
)
assert inference_server is not None
+ if tokenizer is None:
+ from transformers import AutoTokenizer
+
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ if tokenizer.pad_token == tokenizer.eos_token:
+ tokenizer.pad_token = "PAD"
+ tokenizer.padding_side = "left"
policy = vLLMWrapper(
inference_server,
- from_text=True,
- return_log_probs=True,
+ input_mode="history",
+ chat_template_name="qwen",
+ return_log_probs=not cfg.env.reasoning,
+ tokenizer=tokenizer,
+ pad_output=False,
generate_kwargs={
"max_tokens": cfg.inference_model.max_tokens,
"include_stop_str_in_output": cfg.inference_model.include_stop_str_in_output,
@@ -164,7 +184,9 @@ def get_inference_model(
def get_ref_model(
- cfg: DictConfig, tokenizer: PreTrainedTokenizer, devices: list[int] | None = None
+ cfg: DictConfig,
+ tokenizer: PreTrainedTokenizer,
+ devices: list[int] | None = None,
) -> TransformersWrapper:
"""Creates the reference model for KL penalty computation.
@@ -218,10 +240,12 @@ def get_ref_model(
TensorDict.from_module(ref_model).data.to_module(ref_model)
ref_model = TransformersWrapper(
ref_model,
+ input_mode="tokens" if not cfg.env.reasoning else "history",
tokenizer=tokenizer,
- from_text=False,
generate=False,
return_log_probs=True,
+ pad_output=False,
+ device=torch.device("cuda:0"),
)
return ref_model
@@ -473,3 +497,183 @@ def compute_device_allocation(cfg):
"ray_num_gpus": ray_num_gpus,
"cuda_visible_devices": cuda_visible_devices,
}
+
+
+def make_env(cfg: DictConfig, devices: list[int] | None = None):
+ """Create the environment with proper device allocation.
+
+ Args:
+        cfg: The configuration object
+        devices (list[int], optional): Devices to use for the reference model and environment. Default: `None`.
+
+ Returns:
+ The configured environment
+ """
+ # Create reference model with proper device allocation
+ # For the collector actor, we want inference_model devices first, then ref_model devices
+ train_tokenizer = get_tokenizer(cfg)
+
+ # Create a new config with adjusted device assignments
+ ref_cfg = DictConfig(dict(cfg))
+ ref_model = get_ref_model(
+ ref_cfg,
+ train_tokenizer,
+ devices=devices,
+ )
+
+ # Setup environment
+ max_steps = cfg.env.max_steps if cfg.env.reasoning else 1
+ if cfg.env.dataset == "gsm8k":
+ # Reward scale is 0.0 to 100
+ reward_threshold = 20
+ env = GSM8KEnv(
+ repeats=cfg.env.repeats,
+ tokenizer=train_tokenizer,
+ num_envs=cfg.env.num_envs,
+ max_steps=max_steps,
+ group_repeats=cfg.env.group_repeats,
+ device=torch.device("cuda:0") if devices is not None else None,
+ )
+ elif cfg.env.dataset == "ifeval": # ifeval
+ # Reward scale is 0.0 to 2.2
+ reward_threshold = 1.0
+ env = IFEvalEnv(
+ repeats=cfg.env.repeats,
+ tokenizer=train_tokenizer,
+ num_envs=cfg.env.num_envs,
+ max_steps=max_steps,
+ group_repeats=cfg.env.group_repeats,
+ device=torch.device("cuda:0") if devices is not None else None,
+ )
+ else:
+ raise NotImplementedError(f"Dataset {cfg.env.dataset} not implemented")
+ if cfg.env.reasoning:
+ env = env.append_transform(
+ AddThinkingPrompt(
+                cond=lambda td, reward_threshold=reward_threshold, max_steps=max_steps: td[
+ "reward"
+ ]
+ <= reward_threshold
+ and td["step_count"] < max_steps,
+ role="assistant",
+ edit_last_turn=False,
+ zero_reward=False,
+ undo_done=True,
+ random_prompt=True,
+ ),
+ )
+ env = env.append_transform(
+ # RetrieveKL will be lazily initialized in the collector.
+ # We use RetrieveKL instead of KLRewardTransform because the assistant response may change when
+ # adding the thinking prompt, requiring a second pass in vllm to compute the log-probs.
+ RetrieveKL(
+ ref_model=ref_model,
+ add_to_reward=not cfg.train.kl_coef_in_loss,
+ coeff=cfg.train.kl_to_ref_coeff,
+ )
+ )
+ else:
+        # Pass the device directly to KLRewardTransform. Since, for Ray, the local
+        # device is always 0, we can just use cuda:0 here.
+ device = torch.device("cuda:0")
+ env = env.append_transform(
+ KLRewardTransform(
+ ref_model=ref_model,
+ coef=cfg.train.kl_to_ref_coeff,
+ add_to_reward=not cfg.train.kl_coef_in_loss,
+ device=device,
+ )
+ )
+ return env
+
+
+def log_training_metrics(
+ wandb_logger,
+ replay_buffer,
+ batch,
+ loss,
+ grad_norm,
+ global_step,
+ data_read_count,
+ collector,
+ start_time,
+ gradient_accumulation_steps,
+ history_str=None,
+):
+ """Log training metrics to wandb.
+
+ Args:
+ wandb_logger: The wandb logger instance
+ replay_buffer: The replay buffer containing collected data
+ batch: The current training batch
+ loss: The computed loss object
+ grad_norm: The gradient norm value
+ global_step: Current global training step
+ data_read_count: Total data read count
+ collector: The collector instance
+ start_time: Training start time
+ gradient_accumulation_steps: Number of gradient accumulation steps
+ history_str: Optional history string for logging
+ """
+ with torch.no_grad():
+ rb_content = replay_buffer[:]
+ step_count = rb_content.get(("next", "step_count")).view(-1).float().mean()
+ batch_policy_version = batch["next", "policy_version"].view(-1).min()
+ batch_policy_age = collector.policy_version - batch_policy_version
+
+ metrics = {
+ "step_count from buffer": float(step_count),
+ "reward from buffer": float(
+ torch.cat(rb_content.get(("next", "reward"), as_list=True)).mean()
+ ),
+ "kl_penalty (inference to ref) from buffer": float(
+ torch.cat(rb_content.get(("next", "kl_penalty"), as_list=True)).mean()
+ ),
+ "seq_length from buffer": float(
+ torch.tensor(
+ [
+ t.numel()
+ for t in rb_content.get(("tokens", "response"), as_list=True)
+ ],
+ dtype=torch.float,
+ ).mean()
+ ),
+ "ESS, from loss": float(loss.ESS),
+ "loss_objective, from loss": float(loss.loss_objective),
+ "clip_fraction, from loss": float(loss.clip_fraction),
+ "kl_approx (train to inference), from loss": float(loss.kl_approx),
+ "kl_to_inference (train to inference - differentiable), from loss": float(
+ loss.kl_to_inference.mean()
+ ),
+ "kl_to_ref, from loss": float(loss.kl_to_ref.mean()),
+ "loss_kl_to_inference, from loss": float(loss.loss_kl_to_inference.mean()),
+ "loss_kl_to_ref, from loss": float(loss.loss_kl_to_ref.mean()),
+ "entropy loss, from loss": float(loss.loss_entropy.mean()),
+ "grad_norm": float(grad_norm)
+ if global_step % gradient_accumulation_steps == 0
+ else 0.0,
+ "write_count, from buffer": int(replay_buffer.write_count),
+ # how many gradient steps per write
+ "gradient_step_throughput (gradient step per write)": float(
+ global_step / replay_buffer.write_count
+ ),
+ # how many optim steps per write
+ "optim_step_throughput (optim step per write)": float(
+ (global_step // gradient_accumulation_steps) / replay_buffer.write_count
+ ),
+ "data_read_count (total)": data_read_count,
+ "current_policy_version (collector)": collector.policy_version,
+ # FIXME: Assume batch is a single trajectory
+ # FIXME: The addition of the transform after the env instantiation + _shuttle creation
+ # is messed up - we need the next data
+ "batch_policy_version (sampled batch)": batch_policy_version,
+ "batch_policy_age (sampled batch)": batch_policy_age,
+ "throughput (steps per second)": float(
+ global_step / (time.time() - start_time)
+ ),
+ }
+
+ for name, value in metrics.items():
+ wandb_logger.log_scalar(name, value, step=global_step)
+
+ if history_str is not None:
+ wandb_logger.log_str("history", history_str, step=global_step)
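One detail worth calling out in `log_training_metrics`: the gradient norm is only meaningful on optimizer-step boundaries, so off-boundary steps log a placeholder of 0.0. A small standalone sketch of that rule (the step and norm values are hypothetical):

```python
gradient_accumulation_steps = 4  # hypothetical value


def grad_norm_to_log(global_step: int, grad_norm: float) -> float:
    # Report the norm only when an optimizer step actually happened,
    # otherwise log a placeholder of 0.0.
    return float(grad_norm) if global_step % gradient_accumulation_steps == 0 else 0.0


print([grad_norm_to_log(step, 2.5) for step in range(6)])
# [2.5, 0.0, 0.0, 0.0, 2.5, 0.0]
```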
diff --git a/test/llm/mocking_classes_llm.py b/test/llm/mocking_classes_llm.py
index 60ddc4b8dd9..59865bb3232 100644
--- a/test/llm/mocking_classes_llm.py
+++ b/test/llm/mocking_classes_llm.py
@@ -28,7 +28,7 @@ def __next__(self):
return {"text": self.generate_random_string()}
else:
return {
- "text": [self.generate_random_string() for _ in range(self.batch_size)]
+ "query": [self.generate_random_string() for _ in range(self.batch_size)]
}
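The mocked dataloader now yields batches keyed by `"query"`, which appears to be the data key `ChatEnv` reads at reset (see `base_env.data_key` in the environment tests below). A simplified stand-in illustrating the shape of those batches:

```python
import random
import string


class DummyQueryLoader:
    # Simplified stand-in for the mocked loader: batches are dicts keyed by "query",
    # one random string per environment in the batch.
    def __init__(self, batch_size: int = 4):
        self.batch_size = batch_size

    def __iter__(self):
        return self

    def __next__(self):
        return {
            "query": [
                "".join(random.choices(string.ascii_lowercase, k=10))
                for _ in range(self.batch_size)
            ]
        }


print(next(iter(DummyQueryLoader(2))))
```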
diff --git a/test/llm/test_collectors.py b/test/llm/test_collectors.py
index 8e4047b34f0..01318eecbf5 100644
--- a/test/llm/test_collectors.py
+++ b/test/llm/test_collectors.py
@@ -12,20 +12,26 @@
import pytest
import torch
from mocking_classes_llm import DummyStrDataLoader
+from tensordict import set_list_to_stack
from torchrl import logger as torchrl_logger
from torchrl.collectors.llm import LLMCollector
from torchrl.collectors.llm.weight_update.vllm import vLLMUpdater
from torchrl.data import LazyStackStorage, ReplayBuffer
from torchrl.envs import AsyncEnvPool, StepCounter
-from torchrl.envs.llm import LLMEnv
+from torchrl.envs.llm.chat import ChatEnv
from torchrl.modules.llm import TransformersWrapper, vLLMWrapper
-
from torchrl.modules.llm.backends.vllm import make_vllm_worker
_has_transformers = importlib.util.find_spec("transformers") is not None
_has_vllm = importlib.util.find_spec("vllm") is not None
+@pytest.fixture(scope="module", autouse=True)
+def set_list_to_stack_fixture():
+ with set_list_to_stack(True):
+ yield
+
+
@pytest.mark.skipif(not _has_transformers, reason="missing transformers dependencies")
@pytest.mark.skipif(not _has_vllm, reason="missing vllm dependencies")
class TestLLMCollector:
@@ -90,7 +96,7 @@ def test_llm_collector_with_transformers(
policy = TransformersWrapper(
model,
tokenizer=tokenizer,
- from_text=True,
+ input_mode="history",
generate=True,
return_log_probs=True,
)
@@ -100,12 +106,11 @@ def _run_collector_test(self, total_steps, rb, queue, policy, tokenizer):
bsz = 4
dataloader = DummyStrDataLoader(bsz)
- env = LLMEnv.from_dataloader(
+ env = ChatEnv.from_dataloader(
dataloader=dataloader,
- from_text=True,
batch_size=bsz,
group_repeats=True,
- eos_token_id=tokenizer.eos_token_id,
+ input_mode="history",
)
queue = None
if rb:
@@ -138,13 +143,15 @@ def _run_collector_test(self, total_steps, rb, queue, policy, tokenizer):
assert sample.shape == (4,)
assert not sample._has_exclusive_keys
# Should match length
- assert len(sample["text"]) == 4
+ assert len(sample["text", "prompt"]) == 4
# assert len(sample["text"][0]) == 10, sample["text"][0]
# Should be non-empty
- assert sample["text_response"] is not None
+ assert sample["text", "response"] is not None
for i in range(4):
# Check that there are more chars in the next step
- assert len(sample["text"][i]) < len(sample["next", "text"][i])
+ assert len(sample["history", "prompt"][i]) < len(
+ sample["next", "history", "prompt"][i]
+ )
else:
stack = torch.cat(stack)
assert not stack._has_exclusive_keys
@@ -152,7 +159,9 @@ def _run_collector_test(self, total_steps, rb, queue, policy, tokenizer):
stack = stack.view(-1)
for i in range(stack.numel()):
# Check that there are more chars in the next step
- assert len(stack["text"][i]) < len(stack["next", "text"][i])
+ assert len(stack["history", "prompt"][i]) < len(
+ stack["next", "history", "prompt"][i]
+ )
assert collector._frames >= total_steps
@pytest.mark.slow
@@ -164,11 +173,11 @@ def test_llm_collector_start(self, vllm_instance):
bsz = 4
dataloader = DummyStrDataLoader(bsz)
- env = LLMEnv.from_dataloader(
+ env = ChatEnv.from_dataloader(
dataloader=dataloader,
- from_text=True,
batch_size=bsz,
group_repeats=True,
+ input_mode="history",
)
rb = ReplayBuffer(storage=LazyStackStorage(max_size=total_steps * 2))
@@ -191,7 +200,9 @@ def test_llm_collector_start(self, vllm_instance):
assert sample.ndim == 1
for i in range(10):
# Check that there are more chars in the next step
- assert len(sample["text"][i]) < len(sample["next", "text"][i])
+ assert len(sample["history", "prompt"][i]) < len(
+ sample["next", "history", "prompt"][i]
+ )
assert not sample._has_exclusive_keys, sample
j += 1
if rb.write_count >= total_steps:
@@ -201,25 +212,33 @@ def test_llm_collector_start(self, vllm_instance):
collector.async_shutdown(timeout=10)
@pytest.mark.slow
- @pytest.mark.parametrize("rb", [False, True])
- @pytest.mark.parametrize("yield_only_last_steps", [False, True])
+ @pytest.mark.parametrize("rb", [False, True], ids=["rb_false", "rb_true"])
+ @pytest.mark.parametrize(
+ "yield_only_last_steps",
+ [False, True],
+ ids=["yield_only_last_steps_false", "yield_only_last_steps_true"],
+ )
+ @pytest.mark.parametrize(
+ "dialog_turns_per_batch",
+ [4, None],
+ ids=["dialog_turns_per_batch_4", "dialog_turns_per_batch_none"],
+ )
def test_llm_collector_completed(
- self, vllm_instance_opt, rb, yield_only_last_steps
+ self, vllm_instance_opt, rb, yield_only_last_steps, dialog_turns_per_batch
):
torch.manual_seed(0)
policy = vLLMWrapper(vllm_instance_opt)
- tokenizer = vllm_instance_opt.get_tokenizer()
+ vllm_instance_opt.get_tokenizer()
bsz = 4
total_steps = 20
max_steps = 20
dataloader = DummyStrDataLoader(bsz)
- env = LLMEnv.from_dataloader(
+ env = ChatEnv.from_dataloader(
dataloader=dataloader,
- from_text=True,
+ input_mode="history",
batch_size=bsz,
group_repeats=True,
- eos_token_id=tokenizer.eos_token_id,
)
# To make sure the env breaks at some point
env = env.append_transform(StepCounter(max_steps=max_steps))
@@ -228,15 +247,23 @@ def test_llm_collector_completed(
rb = ReplayBuffer(storage=LazyStackStorage(max_size=total_steps * 2))
else:
rb = None
+
+ kwargs = (
+ {"dialog_turns_per_batch": dialog_turns_per_batch}
+ if dialog_turns_per_batch is not None
+ else {}
+ )
collector = LLMCollector(
env=env,
policy_factory=lambda: policy,
- dialog_turns_per_batch=env.batch_size[0],
replay_buffer=rb,
total_dialog_turns=total_steps,
yield_completed_trajectories=True,
yield_only_last_steps=yield_only_last_steps,
+ **kwargs,
)
+ if not dialog_turns_per_batch:
+ assert collector.dialog_turns_per_batch == 1
assert collector.yield_completed_trajectories
assert collector.yield_only_last_steps is yield_only_last_steps
@@ -250,51 +277,43 @@ def test_llm_collector_completed(
for i in range(data.numel()):
if data[i]["next", "step_count"] == max_steps:
continue
- if data[i]["text_response"]:
- # Check that there are more chars in the next step
- assert len(data["text"][i]) < len(data["next", "text"][i]), (
- i,
- data[i]["next", "step_count"],
- data[i]["next", "done"],
- data[i]["text_response"],
- )
- else:
- assert len(data["text"][i]) == len(data["next", "text"][i]), (
- i,
- data[i]["next", "step_count"],
- data[i]["next", "done"],
- data[i]["text_response"],
- )
-
+ # Check that there are more chars in the next step
+ assert len(data["history", "prompt"][i]) < len(
+ data["next", "history", "prompt"][i]
+ ), (
+ i,
+ data[i]["next", "step_count"],
+ data[i]["next", "done"],
+ data[i]["text_response"],
+ )
+
+ expected_shape = (
+ collector.dialog_turns_per_batch
+ if collector.dialog_turns_per_batch
+ else 1
+ )
+            # Since we only yield completed trajectories, we either get every step of each
+            # trajectory (so the number of elements can exceed dialog_turns_per_batch) or
+            # only the last steps, in which case the count equals dialog_turns_per_batch exactly.
if yield_only_last_steps:
- assert data.shape == (1,)
+ assert data.numel() == expected_shape, (data.shape, expected_shape)
else:
- has_found_one_with_more_steps |= data.numel() > 1
+ assert data.numel() >= expected_shape, (data.shape, expected_shape)
+ has_found_one_with_more_steps |= data.numel() > 1
else:
assert data is None
sample = rb.sample(5)
for i in range(sample.numel()):
if sample[i]["next", "step_count"] == max_steps:
continue
- if sample[i]["text_response"]:
- # Check that there are more chars in the next step
- assert len(sample["text"][i]) < len(
- sample["next", "text"][i]
- ), (
- i,
- sample[i]["next", "step_count"],
- sample[i]["next", "done"],
- sample[i]["text_response"],
- )
- else:
- assert len(sample["text"][i]) == len(
- sample["next", "text"][i]
- ), (
- i,
- sample[i]["next", "step_count"],
- sample[i]["next", "done"],
- sample[i]["text_response"],
- )
+ # Check that there are more chars in the next step
+ assert len(sample["history", "prompt"][i]) < len(
+ sample["next", "history", "prompt"][i]
+ ), (
+ i,
+ sample[i]["next", "step_count"],
+ sample[i]["next", "done"],
+ sample[i]["text_response"],
+ )
assert sample.ndim == 1
assert sample.shape == (5,)
@@ -313,19 +332,18 @@ def test_llm_collector_completed_async(
):
torch.manual_seed(0)
policy = vLLMWrapper(vllm_instance_opt)
- tokenizer = vllm_instance_opt.get_tokenizer()
+ vllm_instance_opt.get_tokenizer()
bsz = 4
total_steps = 20
max_steps = 20
dataloader = DummyStrDataLoader(bsz)
def env_maker():
- env = LLMEnv.from_dataloader(
+ env = ChatEnv.from_dataloader(
dataloader=dataloader,
from_text=True,
batch_size=(),
group_repeats=True,
- eos_token_id=tokenizer.eos_token_id,
)
# To make sure the env breaks at some point
env = env.append_transform(StepCounter(max_steps=max_steps))
@@ -359,21 +377,15 @@ def env_maker():
for i in range(data.numel()):
if data[i]["next", "step_count"] == max_steps:
continue
- if data[i]["text_response"]:
- # Check that there are more chars in the next step
- assert len(data["text"][i]) < len(data["next", "text"][i]), (
- i,
- data[i]["next", "step_count"],
- data[i]["next", "done"],
- data[i]["text_response"],
- )
- else:
- assert len(data["text"][i]) == len(data["next", "text"][i]), (
- i,
- data[i]["next", "step_count"],
- data[i]["next", "done"],
- data[i]["text_response"],
- )
+ # Check that there are more chars in the next step
+ assert len(data["history", "prompt"][i]) < len(
+ data["next", "history", "prompt"][i]
+ ), (
+ i,
+ data[i]["next", "step_count"],
+ data[i]["next", "done"],
+ data[i]["text_response"],
+ )
if yield_only_last_steps:
assert data.shape == (1,)
@@ -385,25 +397,15 @@ def env_maker():
for i in range(sample.numel()):
if sample[i]["next", "step_count"] == max_steps:
continue
- if sample[i]["text_response"]:
- # Check that there are more chars in the next step
- assert len(sample["text"][i]) < len(
- sample["next", "text"][i]
- ), (
- i,
- sample[i]["next", "step_count"],
- sample[i]["next", "done"],
- sample[i]["text_response"],
- )
- else:
- assert len(sample["text"][i]) == len(
- sample["next", "text"][i]
- ), (
- i,
- sample[i]["next", "step_count"],
- sample[i]["next", "done"],
- sample[i]["text_response"],
- )
+ # Check that there are more chars in the next step
+ assert len(sample["history", "prompt"][i]) < len(
+ sample["next", "history", "prompt"][i]
+ ), (
+ i,
+ sample[i]["next", "step_count"],
+ sample[i]["next", "done"],
+ sample[i]["text_response"],
+ )
assert sample.ndim == 1
assert sample.shape == (5,)
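The rewritten collector tests all assert the same core invariant: after every environment step, the next prompt history strictly extends the previous one by the new assistant turn (and any tool output). A toy illustration of that check with plain tuples rather than TorchRL `History` objects:

```python
# Plain-Python analogue of the assertion len(prompt_t) < len(prompt_{t+1}).
prompt = [("system", "You are helpful."), ("user", "Hi")]
next_prompt = prompt + [("assistant", "Hello!"), ("user", "Tell me more")]
assert len(prompt) < len(next_prompt)
```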
diff --git a/test/llm/test_data.py b/test/llm/test_data.py
index 082bdc1bc16..9e599b5c6f2 100644
--- a/test/llm/test_data.py
+++ b/test/llm/test_data.py
@@ -21,7 +21,7 @@
ReplayBuffer,
SamplerWithoutReplacement,
)
-from torchrl.data.llm.chat import ContentBase
+from torchrl.data.llm.history import ContentBase
from torchrl.data.llm.topk import TopKRewardSelector
_has_transformers = importlib.util.find_spec("transformers")
@@ -324,8 +324,12 @@ def test_content_base(self):
The result is""",
]
- @pytest.mark.parametrize("test_case", TEST_CASES)
- def test_history_assistant_mask(self, test_case):
+ @pytest.mark.parametrize(
+ "test_case",
+ TEST_CASES,
+ ids=["case_1", "case_2", "case_3", "case_4", "case_5", "case_6"],
+ )
+ def test_history_assistant_mask_qwen(self, test_case):
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")
@@ -355,13 +359,108 @@ def test_history_assistant_mask(self, test_case):
assert type(decoded) is str
assert last_item.endswith(decoded), (decoded, last_item)
+ LLAMA_TEST_CASES = [
+ # Case 1: All messages complete
+ """<|begin_of_text|><|header_start|>system<|header_end|>
+
+I am a helpful assistant.<|eot|><|header_start|>user<|header_end|>
+
+What is the capital of France?<|eot|><|header_start|>assistant<|header_end|>
+
+The capital of France is Paris.<|eot|>""",
+ # Case 2: Last message incomplete
+ """<|begin_of_text|><|header_start|>system<|header_end|>
+
+I am a helpful assistant.<|eot|><|header_start|>user<|header_end|>
+
+What is the capital of France?<|eot|><|header_start|>assistant<|header_end|>
+
+The capital of France is""",
+ # Case 3: Multiple messages with mix of endings
+ """<|begin_of_text|><|header_start|>system<|header_end|>
+
+I am a helpful assistant.<|eot|><|header_start|>user<|header_end|>
+
+Tell me about Python.<|eot|><|header_start|>assistant<|header_end|>
+
+Python is a programming language.<|eot|><|header_start|>user<|header_end|>
+
+Can you elaborate?<|eot|><|header_start|>assistant<|header_end|>
+
+Python is known for its simplicity""",
+ # Case 4: Single incomplete message
+ """<|header_start|>assistant<|header_end|>
+
+Let me help you with that""",
+ # # Case 5: Empty content but complete -- not supported by LLAMA 4
+ # """<|begin_of_text|><|header_start|>system<|header_end|>
+ # <|eot|><|header_start|>user<|header_end|>
+ # <|eot|>""",
+ # Case 6: Message with tool calls
+ """<|begin_of_text|><|header_start|>system<|header_end|>
+
+I am an assistant that can use tools.<|eot|><|header_start|>user<|header_end|>
+
+<|eot|><|header_start|>assistant<|header_end|>
+
+Let me help you with that.
+
+{"name": "calculator", "arguments": {"expression": "2+2"}}
+ <|eot|><|header_start|>user<|header_end|>
+
+4<|eot|><|header_start|>assistant<|header_end|>
+
+The result is""",
+ ]
+
+ @pytest.mark.parametrize(
+ "test_case",
+ LLAMA_TEST_CASES,
+ ids=["case_1", "case_2", "case_3", "case_4", "case_6"],
+ )
+ def test_history_assistant_mask_llama(self, test_case):
+ from transformers import AutoTokenizer
+
+ try:
+ tokenizer = AutoTokenizer.from_pretrained(
+ "meta-llama/Llama-4-Scout-17B-16E-Instruct"
+ )
+ except Exception:
+ pytest.skip("Could not load Llama tokenizer")
+
+ history = History.from_text(test_case, chat_template_name="llama")
+ proc = history.apply_chat_template(
+ tokenizer=tokenizer,
+ chat_template_name="llama",
+ add_generation_prompt=False,
+ return_dict=True,
+ return_assistant_tokens_mask=True,
+ )
+ role_assistant = torch.tensor([r == "assistant" for r in history.role])
+ last_item: str = history[role_assistant].apply_chat_template(
+ tokenizer=tokenizer,
+ chat_template_name="llama",
+ add_generation_prompt=False,
+ )
+
+ if "assistant" in history.role:
+ assert proc["assistant_masks"].any()
+ else:
+ assert not proc["assistant_masks"].any()
+ if last_item:
+ decoded = tokenizer.decode(
+ proc["input_ids"][proc["assistant_masks"].bool()]
+ )
+ assert type(decoded) is str
+ assert last_item.endswith(decoded), (decoded, last_item)
+
def test_history_completion(self):
"""Test the History class's handling of complete and incomplete messages."""
for i, test_case in enumerate(self.TEST_CASES):
history = History.from_text(test_case, chat_template_name="qwen")
- # Print details about each message
+            # Log details about each message via torchrl_logger
for j, (role, content, is_complete) in enumerate(
zip(history.role, history.content, history.is_complete)
):
@@ -418,6 +517,455 @@ def test_history_completion(self):
], "Case 5 should have last message incomplete"
assert history[2].role == "tool"
+ @pytest.mark.parametrize(
+ "model_name, expected_template",
+ [
+ ("Qwen/Qwen2.5-0.5B", "qwen"),
+ ("microsoft/phi-2", "chatml_format"),
+ ("mosaicml/mpt-7b-instruct", "chatml_format"),
+ ("facebook/opt-125m", "chatml_format"),
+ ("gpt2", "chatml_format"),
+ ("EleutherAI/pythia-70m", "chatml_format"),
+ ("bigscience/bloom-560m", "chatml_format"),
+ ("deepseek-ai/deepseek-coder-6.7b-base", "deepseek"),
+ ],
+ )
+ def test_assistant_mask_model_families(self, model_name, expected_template):
+ """Test assistant token masking support across different model families."""
+ from transformers import AutoTokenizer
+
+ torchrl_logger.info(f"\nTesting {model_name} with {expected_template} template")
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+
+ # Create a simple history
+ history = History.from_chats(
+ [
+ [
+ {"role": "user", "content": "Hello"},
+ {"role": "assistant", "content": "Hi there!"},
+ ]
+ ]
+ )
+
+ # Test with expected template
+ result = history.apply_chat_template(
+ tokenizer=tokenizer,
+ chat_template_name=expected_template,
+ add_generation_prompt=False,
+ return_dict=True,
+ return_assistant_tokens_mask=True,
+ )
+
+ # Verify assistant mask is present
+ assert (
+ "assistant_masks" in result
+ ), f"Model {model_name} should support assistant masking"
+ assert (
+ result["assistant_masks"].shape[0] == 1
+ ), "Should have batch dimension of 1"
+ assert result["assistant_masks"].shape[1] > 0, "Should have sequence length > 0"
+
+ # Verify some assistant tokens are masked
+ assistant_token_count = result["assistant_masks"].sum().item()
+ assert (
+ assistant_token_count > 0
+ ), f"Model {model_name} should have assistant tokens masked"
+ torchrl_logger.info(
+ f" ✓ {model_name}: {assistant_token_count} assistant tokens masked"
+ )
+
+ @pytest.mark.parametrize(
+ "template_name", ["qwen", "dialogpt", "falcon", "deepseek"]
+ )
+ def test_assistant_mask_with_custom_templates(self, template_name):
+ """Test that models with custom templates can still use assistant masking."""
+ from transformers import AutoTokenizer
+
+ # Test Qwen with its custom template
+ tokenizer = AutoTokenizer.from_pretrained(
+ "Qwen/Qwen2.5-0.5B", trust_remote_code=True
+ )
+
+ history = History.from_chats(
+ [
+ [
+ {"role": "user", "content": "Hello"},
+ {"role": "assistant", "content": "Hi there!"},
+ ]
+ ]
+ )
+
+        # Apply the parametrized custom template to the Qwen tokenizer
+ result = history.apply_chat_template(
+ tokenizer=tokenizer,
+ chat_template_name=template_name,
+ add_generation_prompt=False,
+ return_dict=True,
+ return_assistant_tokens_mask=True,
+ )
+
+ assert "assistant_masks" in result
+ assert result["assistant_masks"].sum().item() > 0
+
+ @pytest.mark.parametrize(
+ "model_name, template_name",
+ [
+ ("Qwen/Qwen2.5-0.5B", "qwen"),
+ ("microsoft/DialoGPT-medium", "dialogpt"),
+ ("tiiuae/falcon-7b-instruct", "falcon"),
+ ("deepseek-ai/deepseek-coder-6.7b-base", "deepseek"),
+ ],
+ )
+ def test_custom_template_equivalence(self, model_name, template_name):
+ """Test that our custom templates produce the same output as the model's default template (except for masking)."""
+ import re
+
+ import transformers
+
+ # Simple multi-turn chat for each model
+ def norm(s):
+ if isinstance(s, list):
+ return [re.sub(r"\s+", " ", x.strip()) for x in s]
+ elif isinstance(s, str):
+ return re.sub(r"\s+", " ", s.strip())
+ else:
+ return s
+
+ chat = [
+ {"role": "user", "content": "Hello"},
+ {"role": "assistant", "content": "Hi there!"},
+ {"role": "user", "content": "How are you?"},
+ {"role": "assistant", "content": "I'm good, thanks!"},
+ ]
+
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
+ model_name, trust_remote_code=True
+ )
+ history = History.from_chats([chat])
+
+ # Output with model's default template
+ try:
+ default_out = history.apply_chat_template(
+ tokenizer=tokenizer,
+ add_generation_prompt=False,
+ chat_template=tokenizer.chat_template, # Use model's default
+ chat_template_name=None,
+ tokenize=False,
+ )
+ except Exception as e:
+ default_out = None
+ torchrl_logger.info(
+ f"[WARN] Could not get default template for {model_name}: {e}"
+ )
+
+ # Output with our custom template
+ custom_out = history.apply_chat_template(
+ tokenizer=tokenizer,
+ add_generation_prompt=False,
+ chat_template_name=template_name,
+ chat_template=None,
+ tokenize=False,
+ )
+
+ if default_out is not None:
+ assert norm(default_out) == norm(custom_out), (
+ f"Custom template for {model_name} does not match default!\n"
+ f"Default: {default_out}\nCustom: {custom_out}"
+ )
+ else:
+ torchrl_logger.info(
+ f"[INFO] Skipped equivalence check for {model_name} (no default template available)"
+ )
+
+ def test_add_chat_template_parameters_used(self):
+ """Test that add_chat_template actually uses inverse_parser and model_family_keywords parameters with a real tokenizer."""
+ import re
+ import uuid
+
+ from torchrl.data.llm.history import add_chat_template, History
+ from transformers import AutoTokenizer
+
+ try:
+ # Track if the inverse parser is called
+ inverse_parser_called = {"called": False}
+
+ template_name = f"qwen_custom_{uuid.uuid4()}"
+
+ # Create a custom template (trivially different from Qwen)
+ custom_template = """
+ {% for message in messages %}
+ {%- if message['role'] == 'user' %}
+ [USER] {{ message['content'] }}
+ {%- elif message['role'] == 'assistant' %}
+ {% generation %}[ASSISTANT] {{ message['content'] }}{% endgeneration %}
+ {%- endif %}
+ {% endfor %}
+ """
+
+ # Custom inverse parser
+ def custom_inverse_parser(text: str) -> History:
+ inverse_parser_called["called"] = True
+ user_msgs = re.findall(
+ r"\[USER\] (.*?)(?=\[ASSISTANT\]|$)", text, re.DOTALL
+ )
+ assistant_msgs = re.findall(
+ r"\[ASSISTANT\] (.*?)(?=\[USER\]|$)", text, re.DOTALL
+ )
+ messages = []
+ for i, user_content in enumerate(user_msgs):
+ messages.append(History(role="user", content=user_content.strip()))
+ if i < len(assistant_msgs):
+ messages.append(
+ History(role="assistant", content=assistant_msgs[i].strip())
+ )
+ return lazy_stack(messages)
+
+ # Register the custom template and parser for Qwen
+ add_chat_template(
+ template_name=template_name,
+ template=custom_template,
+ inverse_parser=custom_inverse_parser,
+ model_family_keywords=["qwen"],
+ )
+
+ # Use a real Qwen tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(
+ "Qwen/Qwen2.5-3B", trust_remote_code=True
+ )
+ history = History.from_chats(
+ [
+ [
+ {"role": "user", "content": "Hello"},
+ {"role": "assistant", "content": "Hi there!"},
+ ]
+ ]
+ )
+
+ # This should trigger auto-detection using our custom template
+ result = history.apply_chat_template(
+ tokenizer=tokenizer,
+ add_generation_prompt=False,
+ tokenize=False,
+ )
+ # The result should use our custom format
+ if isinstance(result, list):
+ result_str = result[0]
+ else:
+ result_str = result
+ assert "[USER]" in result_str
+ assert "[ASSISTANT]" in result_str
+
+ # Test that inverse parser works
+ parsed = History.from_text(result, chat_template_name=template_name)
+ assert inverse_parser_called["called"], "Inverse parser was not called"
+ assert parsed.role == history.role
+ assert parsed.content == history.content
+ finally:
+ from torchrl.data.llm.history import (
+ _CHAT_TEMPLATES,
+ _CUSTOM_INVERSE_PARSERS,
+ _CUSTOM_MODEL_FAMILY_KEYWORDS,
+ )
+
+ if template_name in _CHAT_TEMPLATES:
+ del _CHAT_TEMPLATES[template_name]
+ if template_name in _CUSTOM_INVERSE_PARSERS:
+ del _CUSTOM_INVERSE_PARSERS[template_name]
+ if template_name in _CUSTOM_MODEL_FAMILY_KEYWORDS:
+ del _CUSTOM_MODEL_FAMILY_KEYWORDS[template_name]
+
+ chats_round_trip = [
+ [
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": "What is the capital of France?"},
+ {"role": "assistant", "content": "The capital of France is Paris."},
+ ],
+ [
+ {"role": "user", "content": "Tell me a joke."},
+ {
+ "role": "assistant",
+ "content": "Why did the chicken cross the road? To get to the other side!",
+ },
+ ],
+ [
+ {"role": "system", "content": "You are a coding assistant."},
+ {"role": "user", "content": "Write a Python function to add two numbers."},
+ {"role": "assistant", "content": "def add(a, b):\n return a + b"},
+ {"role": "user", "content": "What about subtraction?"},
+ {"role": "assistant", "content": "def subtract(a, b):\n return a - b"},
+ ],
+ ]
+
+ @pytest.mark.skipif(not _has_transformers, reason="requires transformers library")
+ @pytest.mark.parametrize(
+ "tokenizer_name",
+ [
+ "meta-llama/Llama-4-Scout-17B-16E-Instruct",
+ "Qwen/Qwen2.5-0.5B",
+ "microsoft/phi-2",
+ ],
+ )
+ @pytest.mark.parametrize(
+ "use_tokenizer_chat_template",
+ [False, True],
+ ids=["no_use_tokenizer_chat_template", "use_tokenizer_chat_template"],
+ )
+ @pytest.mark.parametrize("chat", chats_round_trip)
+ def test_history_round_trip(
+ self, tokenizer_name, use_tokenizer_chat_template, chat
+ ):
+ """Test round-trip conversion: History -> string -> History for various templates and tokenizers."""
+ import re
+
+ from transformers import AutoTokenizer
+
+ # Example chats
+
+ tokenizer = AutoTokenizer.from_pretrained(
+ tokenizer_name, trust_remote_code=True
+ )
+
+ history = History.from_chats(chat)
+ if use_tokenizer_chat_template:
+ if (
+ not hasattr(tokenizer, "chat_template")
+ or tokenizer.chat_template is None
+ ):
+ pytest.skip(f"Tokenizer {tokenizer_name} does not have a chat template")
+ chat_template = tokenizer.chat_template
+ chat_template_name = None
+ else:
+ chat_template = None
+ chat_template_name = None # Let History auto-detect
+
+ # Serialize
+ chat_str = history.apply_chat_template(
+ tokenizer=tokenizer,
+ add_generation_prompt=False,
+ chat_template=chat_template,
+ chat_template_name=chat_template_name,
+ return_dict=False,
+ )
+ # Parse back
+ parsed = History.from_text(
+ text=chat_str,
+ tokenizer=tokenizer,
+ chat_template=chat_template,
+ chat_template_name=chat_template_name,
+ )
+
+ # Normalize whitespace for comparison
+ def norm(x):
+ if isinstance(x, list):
+ return [re.sub(r"\\s+", " ", str(xx).strip()) for xx in x]
+ return re.sub(r"\\s+", " ", str(x).strip())
+ # Compare roles and content
+ assert norm(parsed.role) == norm(
+ history.role
+ ), f"Roles do not match!\nOriginal: {history.role}\nParsed: {parsed.role}"
+ assert norm(parsed.content) == norm(
+ history.content
+ ), f"Content does not match!\nOriginal: {history.content}\nParsed: {parsed.content}"
+
+ # All messages should be complete
+ assert all(
+ parsed.is_complete
+ ), f"All messages should be complete after round-trip. is_complete: {parsed.is_complete}"
+
+ @pytest.mark.skipif(not _has_transformers, reason="requires transformers library")
+ @pytest.mark.parametrize(
+ "tokenizer_name",
+ [
+ "Qwen/Qwen2.5-0.5B",
+ "microsoft/phi-2",
+ "meta-llama/Llama-4-Scout-17B-16E-Instruct",
+ ],
+ )
+ @pytest.mark.parametrize(
+ "use_tokenizer_chat_template",
+ [False, True],
+ ids=["no_use_tokenizer_chat_template", "use_tokenizer_chat_template"],
+ )
+ @pytest.mark.parametrize("chat", chats_round_trip)
+ def test_history_round_trip_incomplete(
+ self, tokenizer_name, use_tokenizer_chat_template, chat
+ ):
+ """Test that truncated strings are properly parsed with the last message marked as incomplete."""
+ if chat[0]["role"] != "system":
+ pytest.skip("Skipping test for non-system message")
+ import re
+
+ from transformers import AutoTokenizer
+
+ tokenizer = AutoTokenizer.from_pretrained(
+ tokenizer_name, trust_remote_code=True
+ )
+
+ history = History.from_chats(chat)
+
+ if use_tokenizer_chat_template:
+ if (
+ not hasattr(tokenizer, "chat_template")
+ or tokenizer.chat_template is None
+ ):
+ pytest.skip(f"Tokenizer {tokenizer_name} does not have a chat template")
+ chat_template = tokenizer.chat_template
+ chat_template_name = None
+ else:
+ chat_template = None
+ chat_template_name = None # Let History auto-detect
+
+ # Serialize
+ chat_str = history.apply_chat_template(
+ tokenizer=tokenizer,
+ add_generation_prompt=False,
+ chat_template=chat_template,
+ chat_template_name=chat_template_name,
+ return_dict=False,
+ )
+
+ # Truncate the last 10 characters to simulate incomplete response
+ truncated_chat_str = chat_str[:-10]
+
+ # Parse back the truncated string
+ parsed = History.from_text(
+ text=truncated_chat_str,
+ tokenizer=tokenizer,
+ chat_template=chat_template,
+ chat_template_name=chat_template_name,
+ )
+
+ # Normalize whitespace for comparison
+ def norm(x):
+ if isinstance(x, list):
+ return [re.sub(r"\\s+", " ", str(xx).strip()) for xx in x]
+ return re.sub(r"\\s+", " ", str(x).strip())
+
+ # Check that we have the same number of messages as the original
+ assert len(parsed.role) == len(
+ history.role
+ ), f"Number of messages should match original. Original: {len(history.role)}, Parsed: {len(parsed.role)}"
+ assert len(parsed.content) == len(
+ history.content
+ ), f"Number of content items should match original. Original: {len(history.content)}, Parsed: {len(parsed.content)}"
+ assert len(parsed.is_complete) == len(
+ history.is_complete
+ ), f"Number of completion flags should match original. Original: {len(history.is_complete)}, Parsed: {len(parsed.is_complete)}"
+
+ # Check that all messages except the last one are complete
+ if len(parsed.is_complete) > 0:
+ assert all(
+ parsed.is_complete[:-1]
+ ), f"All messages except the last should be complete. is_complete: {parsed.is_complete}"
+ assert not parsed.is_complete[
+ -1
+ ], f"Last message should be incomplete. is_complete: {parsed.is_complete}"
+
+ # Check that roles match the original (except potentially the last one if it was truncated mid-message)
+ assert norm(parsed.role[:-1]) == norm(
+ history.role[:-1]
+ ), f"All roles except the last should match original. Original: {history.role[:-1]}, Parsed: {parsed.role[:-1]}"
+
class TestTopK:
@pytest.mark.parametrize("per_token_reward", [True, False])
diff --git a/test/llm/test_envs.py b/test/llm/test_envs.py
index c0237ca73ff..03ade320da3 100644
--- a/test/llm/test_envs.py
+++ b/test/llm/test_envs.py
@@ -10,6 +10,7 @@
import random
import re
import time
+from functools import partial
import pytest
import torch
@@ -25,7 +26,7 @@
)
from torchrl._utils import logger as torchrl_logger
-from torchrl.data.llm.chat import History
+from torchrl.data.llm.history import History
from torchrl.envs import StepCounter
from torchrl.envs.llm import (
as_padded_tensor,
@@ -35,13 +36,15 @@
KLRewardTransform,
LLMEnv,
make_gsm8k_env,
+ RetrieveKL,
)
-from torchrl.modules.llm import TransformersWrapper
+from torchrl.modules.llm import TransformersWrapper, vLLMWrapper
from transformers import AutoTokenizer
_has_transformers = importlib.util.find_spec("transformers") is not None
_has_datasets = importlib.util.find_spec("datasets") is not None
+_has_vllm = importlib.util.find_spec("vllm") is not None
_has_ifeval = (
_has_datasets
and (importlib.util.find_spec("langdetect") is not None)
@@ -50,6 +53,18 @@
)
+@pytest.fixture(scope="module", autouse=True)
+def set_seed():
+ seed = 2
+ torch.manual_seed(seed)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+ yield
+
+
@pytest.fixture(scope="module", autouse=True)
def list_to_stack_fixture():
import tensordict
@@ -418,49 +433,91 @@ class TestChatEnv:
def tokenizer(self):
return AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")
- def test_chat_env(slef, tokenizer):
+ @pytest.mark.parametrize("input_mode", ["text", "tokens", "history"])
+ def test_chat_env(self, tokenizer, input_mode):
# Set list to stack for tensordict
set_list_to_stack(True).set()
# Initialize environment
env = ChatEnv(
batch_size=(1,),
tokenizer=tokenizer,
- apply_template=True,
system_prompt="I'm system, do what I want.",
+ input_mode=input_mode,
)
# Reset environment
- td_reset = env.reset(
- TensorDict(
- text=["I'm the user. I'm going to tell you a little about something."],
- batch_size=(1,),
- )
+ td_reset = TensorDict(
+ query=["I'm the user. I'm going to tell you a little about something."],
+ batch_size=(1,),
+ device=env.device,
)
+ td_reset = env.reset(td_reset)
# Check history after reset
- torchrl_logger.info(f'{td_reset["history"].content=}')
- assert len(td_reset["history"][0].content) == 2
- assert td_reset["history"][0, 0].content == "I'm system, do what I want."
- assert td_reset["history"][0, 1].content.startswith("I'm the user.")
- assert td_reset["history"][0].role == ["system", "user"]
+ if input_mode == "history":
+ torchrl_logger.info(f'{td_reset["history"].prompt.content=}')
+ assert len(td_reset["history"][0].prompt.content) == 2
+ assert (
+ td_reset["history"][0].prompt[0].content
+ == "I'm system, do what I want."
+ )
+ assert td_reset["history"][0].prompt[1].content.startswith("I'm the user.")
+ assert td_reset["history"][0].prompt.role == ["system", "user"]
+ elif input_mode == "tokens":
+ torchrl_logger.info(f'{td_reset["tokens"].prompt=}')
+ elif input_mode == "text":
+ torchrl_logger.info(f'{td_reset["text"].prompt=}')
# Check text after reset
expected_text = "<|im_start|>system\nI'm system, do what I want.<|im_end|>\n<|im_start|>user\nI'm the user. I'm going to tell you a little about something.<|im_end|>\n<|im_start|>assistant\n"
- assert td_reset["text"][0] == expected_text
+ if input_mode in ("text",):
+ assert td_reset["text"][0].prompt == expected_text
# Take step in environment
- td_action = td_reset.set(
- "text_response", ["This is the action from the assistant!<|im_end|>"]
- )
+ if input_mode == "history":
+ td_reset["history"].response = History(
+ content="This is the action from the assistant!", role="assistant"
+ ).view(1, 1)
+ td_reset["history"].full = td_reset["history"].prompt.extend(
+ td_reset["history"].response, dim=-1
+ )
+ td_action = td_reset
+ elif input_mode == "tokens":
+ td_reset["tokens"][0].response = tokenizer.encode(
+ "This is the action from the assistant!<|im_end|>"
+ )
+ td_action = td_reset
+ elif input_mode == "text":
+ td_reset["text"].response = [
+ "This is the action from the assistant!<|im_end|>"
+ ]
+ td_reset["text"].full = [
+ td_reset["text"][0].prompt
+ + "This is the action from the assistant!<|im_end|>"
+ ]
+ td_action = td_reset
td_next = env.step(td_action)
- # Check history after step
- assert len(td_next["next", "history"].content[0]) == 3
- assert td_next["next", "history"][0, 0].content == "I'm system, do what I want."
- assert td_next["next", "history"][0, 1].content.startswith("I'm the user.")
- assert (
- td_next["next", "history"][0, 2].content
- == "This is the action from the assistant!"
- )
- assert td_next["next", "history"][0].role == ["system", "user", "assistant"]
- # Check text after step
- expected_text = "<|im_start|>system\nI'm system, do what I want.<|im_end|>\n<|im_start|>user\nI'm the user. I'm going to tell you a little about something.<|im_end|>\n<|im_start|>assistant\nThis is the action from the assistant!<|im_end|>\n<|im_start|>assistant\n"
- assert td_next["next", "text"][0] == expected_text
+ if input_mode == "history":
+ # Check history after step
+ assert len(td_next["next", "history"][0].prompt.content) == 3
+ assert (
+ td_next["next", "history"][0].prompt[0].content
+ == "I'm system, do what I want."
+ )
+ assert (
+ td_next["next", "history"][0]
+ .prompt[1]
+ .content.startswith("I'm the user.")
+ )
+ assert (
+ td_next["next", "history"][0].prompt[2].content
+ == "This is the action from the assistant!"
+ )
+ assert td_next["next", "history"][0].prompt.role == [
+ "system",
+ "user",
+ "assistant",
+ ]
+ if input_mode in ("text",):
+ # Check text after step
+ expected_text = "<|im_start|>system\nI'm system, do what I want.<|im_end|>\n<|im_start|>user\nI'm the user. I'm going to tell you a little about something.<|im_end|>\n<|im_start|>assistant\nThis is the action from the assistant!<|im_end|>"
+ assert td_next["next", "text"][0].prompt == expected_text
@pytest.mark.skipif(not _has_datasets, reason="requires datasets")
@@ -506,45 +563,6 @@ def test_env_reward(self, n_envs):
assert ("next", "reward") in r
assert r["next", "reward"].shape == (n_envs, 3, 1, 1)
- @pytest.mark.skipif(not _has_transformers, reason="requires transformers library")
- @pytest.mark.parametrize("n_envs", [1, 4])
- def test_kl_bonus(self, n_envs, ref_model):
- torch.manual_seed(0)
- ref_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-
- with torch.device(ref_device):
- model, tokenizer = ref_model
- ref_model = TransformersWrapper(
- model,
- return_log_probs=True,
- generate=False,
- # In practice, we should have the tokens available
- from_text=False,
- tokenizer=tokenizer,
- )
- policy = TransformersWrapper(
- model,
- return_log_probs=True,
- generate=True,
- from_text=True,
- tokenizer=tokenizer,
- generate_kwargs={"max_new_tokens": 20},
- tokenizer_kwargs={"add_special_tokens": False},
- )
-
- env = make_gsm8k_env(num_envs=n_envs, tokenizer=tokenizer)
- env.append_transform(
- KLRewardTransform(
- actor=ref_model,
- coef=0.1,
- device=ref_device,
- )
- )
- r = env.rollout(3, policy)
- r = r.view(-1)
- for _r in r.unbind(0):
- assert _r["tokens_response"].shape + (1,) == _r["next", "reward"].shape
-
def test_gsm8kenv(self):
import transformers
@@ -553,34 +571,22 @@ def test_gsm8kenv(self):
# env.check_env_specs(break_when_any_done="both")
r = env.reset()
assert "history" in r
- assert r["history"].shape == (1, 2)
- assert "text" in r
+ assert r["history"].prompt.shape == (1, 2)
r = r.clone()
response = "First, calculate the total number of snakes in the breeding balls. There are 3 breeding balls with 8 snakes each, so 3 * 8 = 24 snakes. Next, calculate the number of snakes in the additional pairs. There are 6 pairs of snakes, and each pair has 2 snakes, so 6 * 2 = 12 snakes. Finally, add the number of snakes from the breeding balls and the additional pairs: 24 + 12 = 36 snakes. Mary saw a total of 36 snakes. <|im_end|>"
- r["text_response"] = [response]
+ text = (
+ r["history"]
+ .prompt[0]
+ .apply_chat_template(tokenizer=tokenizer, add_generation_prompt=True)
+ + response
+ )
+ history_full = History.from_text(text).unsqueeze(0)
+ assert history_full.shape[-1] == 3
+ r["history"].full = history_full
s = env.step(r)
assert s["next", "reward"] >= 10
assert s["next", "done"].all()
- def test_gsm8kenv_batch(self):
- import transformers
-
- tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")
- env = GSM8KEnv(tokenizer=tokenizer, apply_template=True, num_envs=4)
- # env.check_env_specs(break_when_any_done="both")
- r = env.reset()
- assert "history" in r
- assert r["history"].shape == (4, 2)
- assert "text" in r
- r = r.clone()
- response = "First, calculate the total number of snakes in the breeding balls. There are 3 breeding balls with 8 snakes each, so 3 * 8 = 24 snakes. Next, calculate the number of snakes in the additional pairs. There are 6 pairs of snakes, and each pair has 2 snakes, so 6 * 2 = 12 snakes. Finally, add the number of snakes from the breeding balls and the additional pairs: 24 + 12 = 36 snakes. Mary saw a total of 36 snakes. <|im_end|>"
- r["text_response"] = [response] * 4
- s = env.step(r)
- assert (s["next", "reward"] >= 10).all()
- assert s["next", "done"].all()
-
- env.rollout(10, break_when_any_done=False)
-
@pytest.mark.skipif(not _has_ifeval, reason="requires IFEval libs")
class TestIFEvalEnv:
@@ -592,13 +598,14 @@ def test_ifeval(self):
torch.manual_seed(0)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")
- env = IFEvalEnv(apply_template=True, tokenizer=tokenizer)
+ env = IFEvalEnv(apply_template=True, tokenizer=tokenizer, input_mode="history")
torchrl_logger.info(env.reset())
r = env.reset()
- r.set(
- "text_response",
- [
- """
+ r["history"].full = History.from_text(
+ r["history"]
+ .prompt[0]
+ .apply_chat_template(tokenizer=tokenizer, add_generation_prompt=True)
+ + """
The task requires crafting a riddle about a 'house' that's not traditionally considered one. The answer must be included, and the response should be at least 400 words with a title wrapped in double angular brackets. Let's start by brainstorming what could be considered a 'house' in a non-traditional sense. Ideas include natural shelters, abstract concepts, or objects that serve a similar purpose to a house.
One potential concept is a "womb," as it provides shelter and housing for a developing being. However, we need to ensure our riddle is engaging, meets the word count requirement, and includes the necessary elements like a title.
Let's construct a narrative around the chosen concept, ensuring it's detailed and follows the required structure.
@@ -640,8 +647,7 @@ def test_ifeval(self):
By embracing such metaphors, we're encouraged to look beyond the obvious and appreciate the myriad ways 'shelter' manifests in our lives. And so, the riddle serves not just as a puzzle to be solved but as a reflection on the profound connections that bind us to the very essence of existence.
<|im_end|>
"""
- ],
- )
+ ).unsqueeze(0)
td = env.step(r)
assert td["next", "ifeval_score"].all()
assert td.get(("next", "reward")) is not None
@@ -660,26 +666,32 @@ def test_python_interpreter_single_batch(self):
base_env = ChatEnv(
batch_size=(1,),
system_prompt="I'm the system, do as I say",
- apply_template=True,
tokenizer=tokenizer,
+ input_mode="history",
)
env = base_env.append_transform(PythonInterpreter())
- r = env.reset(TensorDict(text=["This is the user prompt"], batch_size=(1,)))
+ r = env.reset(
+ TensorDict(
+ {base_env.data_key: ["This is the user prompt"]}, batch_size=(1,)
+ )
+ )
rc = r.clone()
- h = r["history"]
+ h = r["history"].prompt
history_from_text = h.apply_chat_template(tokenizer=tokenizer)
assert history_from_text == [
"<|im_start|>system\nI'm the system, do as I say<|im_end|>\n<|im_start|>user\nThis is the user prompt<|im_end|>\n<|im_start|>assistant\n"
]
- r["text_response"] = [
- """Here is a python code to execute:
-```python
-print(1 + 1)
-```<|im_end|>\n
-"""
- ]
+ r["history"].full = h.extend(
+ History(
+ role="assistant",
+ content="Here is a python code to execute:\n```python\nprint(1 + 1)\n```",
+ ).view(1, 1),
+ dim=-1,
+ )
s = env.step(r)
- history_str = s["next", "history"].apply_chat_template(tokenizer=tokenizer)
+ history_str = s["next", "history"].prompt.apply_chat_template(
+ tokenizer=tokenizer, add_generation_prompt=True
+ )
assert history_str == [
"<|im_start|>system\n"
"I'm the system, do as I say<|im_end|>\n"
@@ -690,7 +702,7 @@ def test_python_interpreter_single_batch(self):
"```python\n"
"print(1 + 1)\n"
"```<|im_end|>\n"
- "<|im_start|>user\n"
+ " <|im_start|>user\n"
"\n"
"Code block 1 executed successfully:\n"
"2\n"
@@ -719,22 +731,35 @@ def test_python_interpreter_single_batch(self):
).all()
# Check what happens if there is no tool response
r = rc.clone()
- r["text_response"] = [
- """Here is a response without a python code to execute.<|im_end|>"""
- ]
+ r["history"].full = h.extend(
+ History(
+ role="assistant",
+ content="Here is a response without a python code to execute.",
+ ).view(1, 1),
+ dim=-1,
+ )
s = env.step(r)
- history_str = s["next", "history"].apply_chat_template(tokenizer=tokenizer)
+ history_str = s["next", "history"].prompt.apply_chat_template(
+ tokenizer=tokenizer, add_generation_prompt=True
+ )
assert history_str == [
"<|im_start|>system\n"
"I'm the system, do as I say<|im_end|>\n"
"<|im_start|>user\n"
"This is the user prompt<|im_end|>\n"
"<|im_start|>assistant\n"
+ "Here is a python code to execute:\n"
+ "```python\n"
+ "print(1 + 1)\n"
+ "```<|im_end|>\n"
+ " <|im_start|>assistant\n"
"Here is a response without a python code to execute.<|im_end|>\n"
- "<|im_start|>assistant\n"
+ " <|im_start|>assistant\n"
]
def test_python_interpreter_persistent(self):
from torchrl.envs.llm.transforms import PythonInterpreter
from transformers import AutoTokenizer
@@ -742,29 +767,35 @@ def test_python_interpreter_persistent(self):
env = ChatEnv(
batch_size=(1,),
system_prompt="I'm the system, do as I say",
- apply_template=True,
tokenizer=tokenizer,
+ input_mode="history",
)
env = env.append_transform(PythonInterpreter(persistent=True))
- r = env.reset(TensorDict(text=["This is the user prompt"], batch_size=(1,)))
- r["text_response"] = [
- """Here is a python code to execute:
-```python
-a=1
-```<|im_end|>\n
-"""
- ]
+ r = env.reset(
+ TensorDict({env.data_key: ["This is the user prompt"]}, batch_size=(1,))
+ )
+ r["history"].full = r["history"].prompt.extend(
+ History(
+ role="assistant",
+ content="Here is a python code to execute:\n```python\na=1\n```",
+ ).view(1, 1),
+ dim=-1,
+ )
s, s_ = env.step_and_maybe_reset(r)
- s_["text_response"] = [
- """Here is a python code to execute:
-```python
-a+=1
-assert a == 2
-```<|im_end|>\n
-"""
- ]
+ s_["history"].full = s_["history"].prompt.extend(
+ History(
+ role="assistant",
+ content="Here is a python code to execute:\n```python\na+=1\nassert a == 2\n```",
+ ).view(1, 1),
+ dim=-1,
+ inplace=False,
+ )
s, s_ = env.step_and_maybe_reset(s_)
- assert s_["history"].apply_chat_template(tokenizer=tokenizer) == [
+ response = s_["history"].prompt.apply_chat_template(
+ tokenizer=tokenizer, add_generation_prompt=True
+ )
+
+ assert response == [
"<|im_start|>system\n"
"I'm the system, do as I say<|im_end|>\n"
"<|im_start|>user\n"
@@ -774,7 +805,7 @@ def test_python_interpreter_persistent(self):
"```python\n"
"a=1\n"
"```<|im_end|>\n"
- "<|im_start|>user\n"
+ " <|im_start|>user\n"
"\n"
"Code block 1 executed successfully:\n"
"\n"
@@ -785,7 +816,7 @@ def test_python_interpreter_persistent(self):
"a+=1\n"
"assert a == 2\n"
"```<|im_end|>\n"
- "<|im_start|>user\n"
+ " <|im_start|>user\n"
"\n"
"Code block 1 executed successfully:\n"
"\n"
@@ -801,30 +832,33 @@ def test_python_interpreter_persistent_error(self):
env = ChatEnv(
batch_size=(1,),
system_prompt="I'm the system, do as I say",
- apply_template=True,
tokenizer=tokenizer,
+ input_mode="history",
)
env = env.append_transform(PythonInterpreter(persistent=True))
- r = env.reset(TensorDict(text=["This is the user prompt"], batch_size=(1,)))
- r["text_response"] = [
- """Here is a python code to execute:
-```python
-raise ValueError("This is an error")
-```<|im_end|>\n
-"""
- ]
+ r = env.reset(
+ TensorDict({env.data_key: ["This is the user prompt"]}, batch_size=(1,))
+ )
+ r["history"].full = r["history"].prompt.extend(
+ History(
+ role="assistant",
+ content="Here is a python code to execute:\n```python\nraise ValueError('This is an error')\n```",
+ ).view(1, 1),
+ dim=-1,
+ )
s, s_ = env.step_and_maybe_reset(r)
- s_["text_response"] = [
- """Here is a python code to execute:
-```python
-a=1
-assert a == 1
-```<|im_end|>\n
-"""
- ]
+ s_["history"].full = s_["history"].prompt.extend(
+ History(
+ role="assistant",
+ content="Here is a python code to execute:\n```python\na=1\nassert a == 1\n```",
+ ).view(1, 1),
+ dim=-1,
+ )
s, s_ = env.step_and_maybe_reset(s_)
assert re.match(
- s_["history"].apply_chat_template(tokenizer=tokenizer)[0],
+ s_["history"].prompt.apply_chat_template(
+ tokenizer=tokenizer, add_generation_prompt=True
+ )[0],
r"<|im_start|>system\n"
"I'm the system, do as I say<|im_end|>\n"
"<|im_start|>user\n"
@@ -834,7 +868,7 @@ def test_python_interpreter_persistent_error(self):
"```python\n"
'raise ValueError("This is an error")\n'
"```<|im_end|>\n"
- "<|im_start|>user\n"
+ " <|im_start|>user\n"
"\n"
"Code block 1 failed:\n"
"Error: This is an error\n"
@@ -853,7 +887,7 @@ def test_python_interpreter_persistent_error(self):
"a=1\n"
"assert a == 1\n"
"```<|im_end|>\n"
- "<|im_start|>user\n"
+ " <|im_start|>user\n"
"\n"
"Code block 1 executed successfully:\n"
"\n"
@@ -869,34 +903,35 @@ def test_python_interpreter_persistent_reset(self):
env = ChatEnv(
batch_size=(1,),
system_prompt="I'm the system, do as I say",
- apply_template=True,
tokenizer=tokenizer,
)
env = env.append_transform(PythonInterpreter(persistent=True))
- r = env.reset(TensorDict(text=["This is the user prompt"], batch_size=(1,)))
- r["text_response"] = [
- """Here is a python code to execute:
-```python
-a = [0]
-```<|im_end|>\n
-"""
- ]
+ r = env.reset(
+ TensorDict({env.data_key: ["This is the user prompt"]}, batch_size=(1,))
+ )
+ r["history"].full = r["history"].prompt.extend(
+ History(
+ role="assistant",
+ content="Here is a python code to execute:\n```python\na = [0]\n```",
+ ).view(1, 1),
+ dim=-1,
+ )
s, s_ = env.step_and_maybe_reset(r)
- r = env.reset(TensorDict(text=["This is the user prompt"], batch_size=(1,)))
- r["text_response"] = [
- """Here is a python code to execute:
-```python
-# check if a is still defined
-if "a" in globals():
- raise RuntimeError("a is still defined")
-else:
- print("a is not defined")
-```<|im_end|>\n
-"""
- ]
+ r = env.reset(
+ TensorDict({env.data_key: ["This is the user prompt"]}, batch_size=(1,))
+ )
+ r["history"].full = r["history"].prompt.extend(
+ History(
+ role="assistant",
+ content="Here is a python code to execute:\n```python\n# check if a is still defined\nif 'a' in globals():\n raise RuntimeError('a is still defined')\nelse:\n print('a is not defined')\n```",
+ ).view(1, 1),
+ dim=-1,
+ )
s, s_ = env.step_and_maybe_reset(r)
assert re.match(
- s_["history"].apply_chat_template(tokenizer=tokenizer)[0],
+ s_["history"].prompt.apply_chat_template(
+ tokenizer=tokenizer, add_generation_prompt=True
+ )[0],
"<|im_start|>system\n"
"I'm the system, do as I say<|im_end|>\n"
"<|im_start|>user\n"
@@ -910,7 +945,7 @@ def test_python_interpreter_persistent_reset(self):
"else:\n"
' print("a is not defined")\n'
"```<|im_end|>\n"
- "<|im_start|>user\n"
+ " <|im_start|>user\n"
"\n"
"Code block 1 executed successfully:\n"
"a is not defined\n"
@@ -956,41 +991,50 @@ def calculator(operation: str, a: float, b: float) -> dict:
# Create environment and transform
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")
- env = ChatEnv(
+ base_env = ChatEnv(
batch_size=(1,),
system_prompt="You are a helpful assistant that uses a calculator.",
- apply_template=True,
tokenizer=tokenizer,
)
transform = MCPToolTransform(tools, schemas)
- env = env.append_transform(transform)
+ env = base_env.append_transform(transform)
# Test single tool call
- td = TensorDict({"text": ["Let me calculate 2 + 3"]}, batch_size=(1,))
+ td = TensorDict(
+ {base_env.data_key: ["Let me calculate 2 + 3"]}, batch_size=(1,)
+ )
td = env.reset(td)
- td["text_response"] = [
- 'I will help you calculate 2 + 3:\ncalculator\n{"operation": "add", "a": 2, "b": 3} <|im_end|>'
- ]
+ td["history"].full = td["history"].prompt.extend(
+ History(
+ role="assistant",
+ content='I will help you calculate 2 + 3:\ncalculator\n{"operation": "add", "a": 2, "b": 3} <|im_end|>',
+ ).view(1, 1),
+ dim=-1,
+ )
result = env.step(td)
# Check that the tool was executed and returned correct result
- history = result["next", "history"]
+ history = result["next", "history"].prompt
assert len(history[0]) == 4 # system, user, assistant, tool response
assert history[0, -1].role == "tool"
assert "result': 5" in history[0, -1].content
# Test multiple tool calls in one response
- td = TensorDict({"text": ["Calculate 2 + 3 and 4 * 5"]}, batch_size=(1,))
+ td = TensorDict(
+ {base_env.data_key: ["Calculate 2 + 3 and 4 * 5"]}, batch_size=(1,)
+ )
td = env.reset(td)
- td["text_response"] = [
- "I will help you calculate both:\n"
- 'calculator\n{"operation": "add", "a": 2, "b": 3} \n'
- 'calculator\n{"operation": "multiply", "a": 4, "b": 5} <|im_end|>'
- ]
+ td["history"].full = td["history"].prompt.extend(
+ History(
+ role="assistant",
+ content='I will help you calculate both:\ncalculator\n{"operation": "add", "a": 2, "b": 3} \ncalculator\n{"operation": "multiply", "a": 4, "b": 5} <|im_end|>',
+ ).view(1, 1),
+ dim=-1,
+ )
result = env.step(td)
# Check that both tools were executed and returned correct results
- history = result["next", "history"]
+ history = result["next", "history"].prompt
assert (
len(history[0]) == 5
) # system, user, assistant, tool response 1, tool response 2
@@ -1000,30 +1044,38 @@ def calculator(operation: str, a: float, b: float) -> dict:
assert "result': 20" in history[0, -1].content # 4 * 5 = 20
# Test error handling
- td = TensorDict({"text": ["Calculate 2 ? 3"]}, batch_size=(1,))
+ td = TensorDict({base_env.data_key: ["Calculate 2 ? 3"]}, batch_size=(1,))
td = env.reset(td)
- td["text_response"] = [
- 'I will try to calculate:\ncalculator\n{"operation": "invalid", "a": 2, "b": 3} <|im_end|>'
- ]
+ td["history"].full = td["history"].prompt.extend(
+ History(
+ role="assistant",
+ content='I will try to calculate:\ncalculator\n{"operation": "invalid", "a": 2, "b": 3} <|im_end|>',
+ ).view(1, 1),
+ dim=-1,
+ )
result = env.step(td)
# Check that error was handled gracefully
- history = result["next", "history"]
+ history = result["next", "history"].prompt
assert len(history[0]) == 4
assert history[0, -1].role == "tool"
assert "failed" in history[0, -1].content
assert "Unknown operation: invalid" in history[0, -1].content
# Test invalid JSON
- td = TensorDict({"text": ["Calculate something"]}, batch_size=(1,))
+ td = TensorDict({base_env.data_key: ["Calculate something"]}, batch_size=(1,))
td = env.reset(td)
- td["text_response"] = [
- "Let me calculate:\ncalculator\ninvalid json <|im_end|>"
- ]
+ td["history"].full = td["history"].prompt.extend(
+ History(
+ role="assistant",
+ content="Let me calculate:\ncalculator\ninvalid json <|im_end|>",
+ ).view(1, 1),
+ dim=-1,
+ )
result = env.step(td)
# Check that JSON error was handled gracefully
- history = result["next", "history"]
+ history = result["next", "history"].prompt
assert len(history[0]) == 4
assert history[0, -1].role == "tool"
assert "failed" in history[0, -1].content
@@ -1066,7 +1118,6 @@ def make_env(cls):
env = ChatEnv(
batch_size=(1,),
system_prompt="I'm a calculator assistant",
- apply_template=True,
tokenizer=tokenizer,
)
tools = {"calculator": cls.delayed_calculator}
@@ -1086,20 +1137,30 @@ def test_async_mcp_tools(self):
try:
# Reset both environments
tdreset = TensorDict(
- text=[["Let me calculate 2 + 3"], ["Let me calculate 4 * 5"]],
+ query=[["Let me calculate 2 + 3"], ["Let me calculate 4 * 5"]],
batch_size=(2, 1),
)
td = env_pool.reset(tdreset)
# Send async steps to both environments
- td["text_response"] = [
+ td["history"].full = torch.stack(
[
- 'Let me calculate 2 + 3:\ncalculator\n{"operation": "add", "a": 2, "b": 3} <|im_end|>'
- ],
- [
- 'Let me calculate 4 * 5:\ncalculator\n{"operation": "multiply", "a": 4, "b": 5} <|im_end|>'
- ],
- ]
+ td[0]["history"].prompt.extend(
+ History(
+ role="assistant",
+ content='Let me calculate 2 + 3:\ncalculator\n{"operation": "add", "a": 2, "b": 3} <|im_end|>',
+ ).view(1, 1),
+ dim=-1,
+ ),
+ td[1]["history"].prompt.extend(
+ History(
+ role="assistant",
+ content='Let me calculate 4 * 5:\ncalculator\n{"operation": "multiply", "a": 4, "b": 5} <|im_end|>',
+ ).view(1, 1),
+ dim=-1,
+ ),
+ ]
+ )
env_pool.async_step_send(td)
# Get results as they complete
@@ -1116,7 +1177,7 @@ def test_async_mcp_tools(self):
all_results = torch.stack(list(results) + list(remaining))
# Verify results
- history = all_results["next", "history"]
+ history = all_results["next", "history"].prompt
assert len(history[0, 0]) == 4 # system, user, assistant, tool response
assert history[0, 0, -1].role == "tool"
assert any(
@@ -1174,14 +1235,18 @@ def test_thinking_prompt_wrong_answer(
)
)
reset = env.reset()
- assert reset[0]["history"][-1].content.startswith(
- "Natalia sold clips to 48 of her friends in April"
+ assert (
+ reset[0]["history"]
+ .prompt[-1]
+ .content.startswith("Natalia sold clips to 48 of her friends in April")
)
- policy_anser = (
+ policy_answer = (
"Let me solve this step by step. Natalia sold clips to 48 friends in April. Then she sold half as many in May. Half of 48 is 24. So in May she sold 24 clips. "
"To find the total, I need to add April and May: 48 + 24 = 72. Therefore, Natalia sold 72 clips altogether in April and May. \n322 clips <|im_end|>"
)
- reset["text_response"] = [policy_anser]
+ reset["history"].full = reset["history"].prompt.extend(
+ History(role="assistant", content=policy_answer).view(1, 1), dim=-1
+ )
s = env.step(reset)
if zero_reward:
assert (s["next", "reward"] == 0).all()
@@ -1192,13 +1257,13 @@ def test_thinking_prompt_wrong_answer(
else:
assert (s["next", "done"] != 0).all()
if edit_last_turn:
- assert s["next", "history"].shape == (1, 3)
+ assert s["next", "history"].prompt.shape == (1, 3)
else:
- assert s["next", "history"].shape == (1, 4)
+ assert s["next", "history"].prompt.shape == (1, 4)
if role == "assistant":
- assert s[0]["next", "history", "role"][-1] == "assistant"
+ assert s[0]["next", "history"].prompt[-1].role == "assistant"
else:
- assert s[0]["next", "history", "role"][-1] == "user"
+ assert s[0]["next", "history"].prompt[-1].role == "user"
@pytest.mark.skipif(not _has_transformers, reason="requires transformers")
@pytest.mark.skipif(not _has_datasets, reason="requires gsm8k")
@@ -1237,19 +1302,269 @@ def test_thinking_prompt_correct_answer(
)
)
reset = env.reset()
- assert reset[0]["history"][-1].content.startswith(
- "Natalia sold clips to 48 of her friends in April"
+ assert (
+ reset[0]["history"]
+ .prompt[-1]
+ .content.startswith("Natalia sold clips to 48 of her friends in April")
)
- policy_anser = (
+ policy_answer = (
"Let me solve this step by step. Natalia sold clips to 48 friends in April. Then she sold half as many in May. Half of 48 is 24. So in May she sold 24 clips. "
"To find the total, I need to add April and May: 48 + 24 = 72. Therefore, Natalia sold 72 clips altogether in April and May. \n72 <|im_end|>"
)
- reset["text_response"] = [policy_anser]
+ reset["history"].full = reset["history"].prompt.extend(
+ History(role="assistant", content=policy_answer).view(1, 1), dim=-1
+ )
s = env.step(reset)
assert (s["next", "reward"] != 0).all(), s["next", "reward"]
- assert s[0]["next", "history", "role"][-1] == "assistant"
+ assert s[0]["next", "history"].prompt[-1].role == "assistant"
assert s["next", "done"].all()
- assert len(s[0]["next", "history", "content"]) == 3
+ assert len(s[0]["next", "history"].prompt) == 3
+
+
+class TestChatEnvIntegration:
+ @pytest.fixture(scope="module")
+ def transformers_instance(self):
+ """Create transformers model and tokenizer for testing."""
+ if not _has_transformers:
+ pytest.skip("transformers not available")
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
+ tokenizer.pad_token = tokenizer.eos_token
+ return model, tokenizer
+
+ @pytest.fixture(scope="module")
+ def vllm_instance(self):
+ """Create vLLM model and tokenizer for testing."""
+ if not _has_vllm:
+ pytest.skip("vllm not available")
+
+ import vllm.envs as envs
+ from transformers import AutoTokenizer
+ from vllm import LLM
+
+ envs.VLLM_HOST_IP = "0.0.0.0" or "127.0.0.1"
+
+ try:
+ model = LLM("Qwen/Qwen2.5-0.5B")
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
+ tokenizer.pad_token = tokenizer.eos_token
+ return model, tokenizer
+ except Exception as e:
+ pytest.skip(f"Failed to load vLLM model: {e}")
+
+ @pytest.mark.skipif(not _has_vllm, reason="vllm not available")
+ @pytest.mark.skipif(not _has_datasets, reason="datasets not available")
+ @pytest.mark.parametrize("pad_output", [True, False], ids=["padded", "unpadded"])
+ @pytest.mark.parametrize(
+ "input_mode,compute_reward",
+ [["history", True], ["history", False], ["text", False], ["tokens", False]],
+ ids=[
+ "history_compute_reward",
+ "history_no_compute_reward",
+ "text_no_compute_reward",
+ "tokens_no_compute_reward",
+ ],
+ )
+ def test_chat_env_integration_ifeval(self, compute_reward, pad_output, input_mode):
+ """Test that the wrapper works correctly with the ChatEnv."""
+ import vllm.envs as envs
+ from torchrl.envs.llm import IFEvalEnv
+
+ envs.VLLM_HOST_IP = "0.0.0.0" or "127.0.0.1"
+
+ policy = vLLMWrapper(
+ model="Qwen/Qwen2.5-0.5B",
+ tokenizer="Qwen/Qwen2.5-0.5B",
+ input_mode=input_mode,
+ pad_output=pad_output,
+ generate=True,
+ )
+ env = IFEvalEnv(
+ max_steps=1,
+ compute_reward=compute_reward,
+ input_mode=input_mode,
+ tokenizer=policy.tokenizer,
+ )
+ r = env.reset()
+ prompt = None
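+        # history mode: the reset prompt is expected to hold two messages (typically system + user)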
+ if input_mode == "history":
+ assert r["history", "prompt"].shape == (1, 2)
+ elif input_mode == "text":
+ prompt = r["text", "prompt"][0]
+ r = policy(r)
+ if input_mode == "history":
+ assert r["history", "response"].shape == (1, 1)
+ assert r["history", "full"].shape == (1, 3)
+ elif input_mode == "text":
+ assert r["text", "full"][0].startswith(prompt)
+ r, r_ = env.step_and_maybe_reset(r)
+ if input_mode == "history":
+ assert r["next", "history", "prompt"].shape == (1, 3)
+ assert r_["history", "prompt"] is not None
+ assert r_.get(("history", "response"), as_list=True) is None
+ assert r_.get(("history", "full"), as_list=True) is None
+ assert r["next", "done"].all()
+ r = policy(r_)
+ r, r_ = env.step_and_maybe_reset(r)
+
+ @pytest.mark.skipif(not _has_vllm, reason="vllm not available")
+ @pytest.mark.skipif(not _has_datasets, reason="datasets not available")
+ @pytest.mark.parametrize(
+ "compute_reward", [False, True], ids=["no_compute_reward", "compute_reward"]
+ )
+ @pytest.mark.parametrize("pad_output", [True, False], ids=["padded", "unpadded"])
+ @pytest.mark.parametrize(
+ "input_mode", ["history", "text", "tokens"], ids=["history", "text", "tokens"]
+ )
+ def test_chat_env_integration_gsm8k(self, compute_reward, pad_output, input_mode):
+ """Test that the wrapper works correctly with the ChatEnv."""
+ import vllm.envs as envs
+ from torchrl.envs.llm import GSM8KEnv
+
+ envs.VLLM_HOST_IP = "0.0.0.0" or "127.0.0.1"
+
+ policy = vLLMWrapper(
+ model="Qwen/Qwen2.5-0.5B",
+ tokenizer="Qwen/Qwen2.5-0.5B",
+ input_mode=input_mode,
+ pad_output=pad_output,
+ generate=True,
+ )
+ env = GSM8KEnv(
+ max_steps=1,
+ compute_reward=compute_reward,
+ input_mode=input_mode,
+ tokenizer=policy.tokenizer,
+ )
+ r = env.reset()
+ prompt = None
+ if input_mode == "history":
+ assert r["history", "prompt"].shape == (1, 2)
+ elif input_mode == "text":
+ prompt = r["text", "prompt"][0]
+ r = policy(r)
+ if input_mode == "history":
+ assert r["history", "response"].shape == (1, 1)
+ assert r["history", "full"].shape == (1, 3)
+ elif input_mode == "text":
+ assert r["text", "full"][0].startswith(prompt)
+ r, r_ = env.step_and_maybe_reset(r)
+ if input_mode == "history":
+ assert r["next", "history", "prompt"].shape == (1, 3)
+ assert r_["history", "prompt"] is not None
+ assert r_.get(("history", "response"), as_list=True) is None
+ assert r_.get(("history", "full"), as_list=True) is None
+ assert r["next", "done"].all()
+ r = policy(r_)
+ r, r_ = env.step_and_maybe_reset(r)
+
+ @pytest.mark.parametrize("pad_output", [True, False], ids=["padded", "unpadded"])
+ @pytest.mark.parametrize("ref_input_mode", ["tokens"], ids=["tokens"])
+ @pytest.mark.parametrize(
+ "env_class", ["GSM8KEnv", "IFEvalEnv"], ids=["gsm8k", "ifeval"]
+ )
+ def test_chat_env_kl(
+ self,
+ transformers_instance,
+ vllm_instance,
+ pad_output,
+ ref_input_mode,
+ env_class,
+ ):
+ """Test that the wrapper works correctly with the ChatEnv."""
+ import vllm.envs as envs
+ from torchrl.envs.llm import GSM8KEnv, IFEvalEnv
+
+ envs.VLLM_HOST_IP = "0.0.0.0" or "127.0.0.1"
+
+ vllm_model, vllm_tokenizer = vllm_instance
+ tf_model, tf_tokenizer = transformers_instance
+
+ # a policy
+ policy = vLLMWrapper(
+ vllm_model,
+ tokenizer=vllm_tokenizer,
+ input_mode="history",
+ generate=True,
+ pad_output=pad_output,
+ )
+ ref_model = TransformersWrapper(
+ tf_model,
+ tokenizer=tf_tokenizer,
+ input_mode="tokens",
+ # TODO: check that generate=True causes an error
+ generate=False,
+ return_log_probs=True,
+ pad_output=pad_output,
+ )
+
+ if env_class == "GSM8KEnv":
+ env = GSM8KEnv(max_steps=10, num_envs=3, input_mode="history")
+ elif env_class == "IFEvalEnv":
+ env = IFEvalEnv(max_steps=10, num_envs=3, input_mode="history")
+ else:
+ raise ValueError(f"Invalid environment class: {env_class}")
+ env = env.append_transform(KLRewardTransform(ref_model))
+ r = env.rollout(1, policy)
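+        # without padding the per-env rewards are ragged, so they are retrieved as a list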
+ reward = r.get(("next", "reward"), as_list=not pad_output)
+ assert reward is not None
+ if pad_output:
+ assert reward.shape[0] == 3
+ assert reward.shape[1] == 1
+ assert reward.shape[2] > 1
+ assert reward.shape[3] == 1
+ else:
+ assert len(reward) == 3
+ for r in reward:
+ assert r.shape[0] == 1
+ assert r.shape[1] > 1
+ assert r.shape[2] == 1
+
+ @pytest.mark.parametrize(
+ "env_class", ["GSM8KEnv", "IFEvalEnv"], ids=["gsm8k", "ifeval"]
+ )
+ def test_retrievekl_transform(
+ self, transformers_instance, vllm_instance, env_class
+ ):
+ """Test that the RetrieveKL transform works correctly."""
+ from torchrl.collectors.llm.base import LLMCollector
+ from torchrl.envs.llm import GSM8KEnv, IFEvalEnv
+
+ model, tokenizer = transformers_instance
+ vllm_model, vllm_tokenizer = vllm_instance
+ ref_model = TransformersWrapper(
+ model,
+ tokenizer=tokenizer,
+ input_mode="history",
+ generate=False,
+ pad_output=True,
+ )
+ if env_class == "GSM8KEnv":
+ env = GSM8KEnv(max_steps=1, num_envs=3)
+ elif env_class == "IFEvalEnv":
+ env = IFEvalEnv(max_steps=1, num_envs=3)
+ else:
+ raise ValueError(f"Invalid environment class: {env_class}")
+ env = env.append_transform(RetrieveKL("from_collector", ref_model))
+ c = LLMCollector(
+ env,
+ policy_factory=partial(
+ vLLMWrapper,
+ vllm_model,
+ tokenizer=vllm_tokenizer,
+ input_mode="history",
+ generate=True,
+ pad_output=True,
+ ),
+ dialog_turns_per_batch=6,
+ )
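+        # a single collected batch is enough to check that full and prompt histories are exposed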
+ for d in c:
+ assert ("history", "full") in d
+ assert ("next", "history", "prompt") in d
+ break
+ return
if __name__ == "__main__":
diff --git a/test/llm/test_modules.py b/test/llm/test_modules.py
deleted file mode 100644
index faaa1ff3fa2..00000000000
--- a/test/llm/test_modules.py
+++ /dev/null
@@ -1,616 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the MIT license found in the
-# LICENSE file in the root directory of this source tree.
-from __future__ import annotations
-
-import argparse
-import importlib.util
-
-import pytest
-import torch
-
-from mocking_classes_llm import DummyStrDataLoader
-from tensordict import (
- lazy_stack,
- LazyStackedTensorDict,
- NonTensorStack,
- set_list_to_stack,
- TensorDict,
-)
-from torchrl.collectors.llm import LLMCollector
-from torchrl.data.llm import LLMData
-from torchrl.envs.llm import LLMEnv
-from torchrl.modules.llm import TransformersWrapper, vLLMWrapper
-from transformers import OPTForCausalLM
-
-_has_transformers = importlib.util.find_spec("transformers")
-_has_vllm = importlib.util.find_spec("vllm")
-
-
-@pytest.mark.skipif(not _has_transformers, reason="missing transformers dependencies")
-@pytest.mark.skipif(not _has_vllm, reason="missing vllm dependencies")
-class TestLLMActor:
- @pytest.fixture(scope="module")
- def vllm_instance(self):
- try:
- import vllm
- except ImportError:
- pytest.skip(reason="missing vllm")
-
- llm_model = vllm.LLM("facebook/opt-125m")
- tokenizer = llm_model.get_tokenizer()
- tokenizer.pad_token = tokenizer.eos_token
- return llm_model
-
- @pytest.fixture(scope="module")
- def transformers_instance(self):
- from transformers import AutoTokenizer
-
- # tokenizer = AutoTokenizer.from_pretrained("gpt2")
- # model = GPT2LMHeadModel(GPT2Config()).eval()
- # tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
- # model = OPTModel(OPTConfig("facebook/opt-125m"))
- tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
- model = OPTForCausalLM.from_pretrained("facebook/opt-125m")
-
- tokenizer.pad_token = tokenizer.eos_token
- tokenizer.padding_side = "left"
-
- return model, tokenizer
-
- @pytest.fixture(scope="module")
- def transformers_instance_pretrained(self):
- from transformers import AutoTokenizer, OPTForCausalLM
-
- # tokenizer = AutoTokenizer.from_pretrained("gpt2")
- # model = GPT2LMHeadModel(GPT2Config())
- # tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
- # model = OPTModel(OPTConfig("facebook/opt-125m"))
- tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
- model = OPTForCausalLM.from_pretrained("facebook/opt-125m")
-
- tokenizer.pad_token = tokenizer.eos_token
- tokenizer.padding_side = "left"
-
- return model, tokenizer
-
- @pytest.mark.parametrize(
- "from_text, generate, return_log_probs, tokens, attention_mask",
- [
- (True, True, True, None, None),
- (True, True, False, None, None),
- (True, False, None, None, None),
- (
- False,
- True,
- True,
- torch.randint(1024, (1, 10)),
- torch.ones(1, 10, dtype=torch.int64),
- ),
- (False, True, True, torch.randint(1024, (1, 10)), None),
- (
- False,
- True,
- False,
- torch.randint(1024, (1, 10)),
- torch.ones(1, 10, dtype=torch.int64),
- ),
- (False, True, False, torch.randint(1024, (1, 10)), None),
- ],
- )
- def test_transformers_wrapper(
- self,
- from_text,
- generate,
- return_log_probs,
- tokens,
- attention_mask,
- transformers_instance,
- ):
- torch.manual_seed(0)
-
- model, tokenizer = transformers_instance
-
- m = TransformersWrapper(
- model,
- tokenizer=tokenizer,
- from_text=from_text,
- generate=generate,
- return_log_probs=return_log_probs,
- )
- self._run_check(
- m,
- tokens,
- attention_mask,
- generate,
- return_log_probs,
- from_text,
- has_logits=True,
- )
-
- @pytest.mark.skip_if_nightly
- @pytest.mark.parametrize(
- "from_text, generate, return_log_probs, tokens, attention_mask",
- [
- (True, True, True, None, None),
- (True, True, False, None, None),
- (True, False, None, None, None),
- (
- False,
- True,
- True,
- torch.randint(1024, (1, 10)),
- torch.ones(1, 10, dtype=torch.int64),
- ),
- (False, True, True, torch.randint(1024, (1, 10)), None),
- (
- False,
- True,
- False,
- torch.randint(1024, (1, 10)),
- torch.ones(1, 10, dtype=torch.int64),
- ),
- (False, True, False, torch.randint(1024, (1, 10)), None),
- ],
- )
- def test_vllm_wrapper(
- self,
- from_text,
- generate,
- return_log_probs,
- tokens,
- attention_mask,
- vllm_instance,
- ):
- torch.manual_seed(0)
-
- model = vllm_instance
- m = vLLMWrapper(
- model,
- from_text=from_text,
- generate=generate,
- return_log_probs=return_log_probs,
- )
- self._run_check(
- m,
- tokens,
- attention_mask,
- generate,
- return_log_probs,
- from_text,
- has_logits=False,
- )
-
- def _make_data(
- self,
- m,
- tokens,
- attention_mask,
- generate,
- from_text,
- has_logits,
- batch_size=1,
- text_response=None,
- tokens_response=None,
- ):
- lp_kwargs = {}
- if from_text:
- if not generate:
- text_response = (
- NonTensorStack(" and another text that follows")
- if text_response is None
- else text_response
- )
- if not isinstance(text_response, NonTensorStack):
- if isinstance(text_response, list):
- text_response = NonTensorStack(*text_response)
- else:
- text_response = NonTensorStack(text_response)
- lp_kwargs.update({"text_response": text_response})
- tdin = LLMData(
- text=NonTensorStack("Somewhere, I lost"),
- **lp_kwargs,
- batch_size=batch_size,
- )
- else:
- if not generate:
- if tokens_response is None:
- shape_response = tokens.shape
- shape_response = shape_response[:-1] + (shape_response[-1] * 2,)
- tokens_response = torch.randint(1024, shape_response)
- lp_kwargs.update({"tokens_response": tokens_response})
- tdin = LLMData(
- tokens=tokens,
- attention_mask=attention_mask,
- **lp_kwargs,
- batch_size=batch_size,
- )
- return tdin
-
- def _run_check(
- self,
- m,
- tokens,
- attention_mask,
- generate,
- return_log_probs,
- from_text,
- has_logits,
- ):
- tdin = self._make_data(
- m, tokens, attention_mask, generate, from_text, has_logits
- )
- if from_text and generate:
- assert tdin.text_response is None
- elif from_text and not generate:
- assert tdin.text_response is not None
-
- tdin.copy()
- td = m(tdin)
- assert td is tdin
- assert isinstance(td, LLMData)
- if from_text and generate:
- assert td.text_response is not None
-
- # TODO: vLLM may produce an attention mask when hf does not - explore consistency!
- # if generate and (from_text or tdincopy.attention_mask is not None):
- # assert td.attention_mask is not None, (generate, from_text, tdincopy.attention_mask is not None)
- # if isinstance(td.attention_mask, torch.Tensor):
- # assert td.attention_mask.shape == td.tokens.shape
- # else:
- # assert td.attention_mask is None, (generate, from_text)
-
- if not generate:
- # logprobs are computed on text response of tokens_response
- assert td.text_response is not None or td.tokens_response is not None
- assert td.log_probs is not None
- if has_logits:
- assert td.logits is not None
- if generate:
- if return_log_probs:
- assert td.log_probs is not None
- assert td.log_probs.shape[-1] == td.tokens_response.shape[-1]
- else:
- assert td.log_probs is None
-
- # Test the shapes
- assert td.tokens_response is not None, (generate, has_logits, from_text)
-
- # If from text and not generating, the tokens are not returned for now
- if not (from_text and not generate):
- assert td.tokens_response is not None
- assert td.tokens is not None
- assert td.tokens_response.shape[:-1] == td.tokens.shape[:-1]
- # The convention is that the response only has new tokens
- assert (
- td.tokens_response[..., : td.tokens.shape[-1]]
- != td.tokens[..., : td.tokens_response.shape[-1]]
- ).any(), (generate, from_text)
-
- @pytest.mark.parametrize(
- "from_text, tokens, attention_mask",
- [
- (
- False,
- torch.randint(1024, (1, 10)),
- torch.ones(1, 10, dtype=torch.int64),
- ),
- (False, torch.randint(1024, (1, 10)), None),
- (True, None, None),
- ],
- )
- def test_transformers_logprobs(
- self, from_text, tokens, attention_mask, transformers_instance
- ):
- torch.manual_seed(0)
- model, tokenizer = transformers_instance
-
- m_generate = TransformersWrapper(
- model,
- tokenizer=tokenizer,
- from_text=from_text,
- generate=True,
- return_log_probs=True,
- )
- m_logprobs = TransformersWrapper(
- model, tokenizer=tokenizer, from_text=from_text, generate=False
- )
- self._check_lps(
- m_generate,
- m_logprobs,
- tokens,
- attention_mask,
- from_text,
- has_logits=False,
- )
-
- @pytest.mark.skip_if_nightly
- @pytest.mark.parametrize(
- "pad_output, from_text, tokens, attention_mask",
- [
- (True, True, None, None),
- (False, True, None, None),
- (
- True,
- False,
- torch.randint(1024, (1, 10)),
- torch.ones(1, 10, dtype=torch.int64),
- ),
- (True, False, torch.randint(1024, (1, 10)), None),
- ],
- )
- def test_vllm_logprobs(
- self, from_text, tokens, attention_mask, pad_output, vllm_instance
- ):
- torch.manual_seed(0)
-
- model = vllm_instance
- m_generate = vLLMWrapper(
- model,
- from_text=from_text,
- generate=True,
- return_log_probs=True,
- pad_output=pad_output,
- )
- m_logprobs = vLLMWrapper(
- model, from_text=from_text, generate=False, pad_output=pad_output
- )
- self._check_lps(
- m_generate,
- m_logprobs,
- tokens,
- attention_mask,
- from_text,
- has_logits=False,
- tol=1e-1,
- )
-
- def _check_lps(
- self,
- model_generate,
- model_logprobs,
- tokens,
- attention_mask,
- from_text,
- has_logits,
- tol=1e-2,
- ):
- # Checks that the log-probs gathered with generate=False equate those with generate=True
- tdin_genetate = self._make_data(
- model_generate, tokens, attention_mask, True, from_text, has_logits
- )
- td_generate = model_generate(tdin_genetate)
- tdin_logprobs = self._make_data(
- model_logprobs,
- tokens,
- attention_mask,
- False,
- from_text,
- has_logits,
- tokens_response=td_generate.tokens_response,
- text_response=td_generate.text_response,
- )
- td_logprobs = model_logprobs(tdin_logprobs)
- assert td_generate.tokens_response.shape == td_logprobs.tokens_response.shape
- assert (td_generate.tokens_response == td_logprobs.tokens_response).all(), (
- td_generate.tokens_response == td_logprobs.tokens_response
- )
- assert td_generate.log_probs.shape == td_generate.tokens_response.shape
- assert td_logprobs.log_probs.shape == td_logprobs.tokens_response.shape
- assert td_logprobs.log_probs.shape == td_generate.tokens_response.shape
- torch.testing.assert_close(
- td_generate.log_probs, td_logprobs.log_probs, rtol=tol, atol=tol
- )
-
- @pytest.mark.skip_if_nightly
- @pytest.mark.parametrize("pad", [True, False])
- @pytest.mark.parametrize("generate", [True, False])
- @pytest.mark.parametrize("use_tensorclass", [True, False])
- def test_vllm_batch_run(self, pad, generate, use_tensorclass, vllm_instance):
- # Test generate - padding combinations
- policy = vLLMWrapper(
- vllm_instance,
- from_text=True,
- generate=generate,
- return_log_probs=True,
- pad_output=pad,
- generate_kwargs={"max_tokens": 10000},
- )
- if generate:
- data = LazyStackedTensorDict(
- *TensorDict(
- text=NonTensorStack("a string", "another very long string"),
- batch_size=[2],
- ).unbind(0)
- )
- else:
- data = LazyStackedTensorDict(
- *TensorDict(
- text=NonTensorStack("a string", "another very long string"),
- text_response=NonTensorStack(
- " is a string", " is still a very long string"
- ),
- batch_size=[2],
- ).unbind(0)
- )
- if use_tensorclass:
- data = LLMData.from_tensordict(data)
- output = policy(data)
- try:
- log_probs = output.get("log_probs")
- except Exception:
- log_probs = output.get("log_probs", as_list=True)
- if pad:
- assert isinstance(log_probs, torch.Tensor)
- else:
- assert isinstance(log_probs, list)
- text = output.get("text", as_list=True)
- # TODO: this is not ideal...
- if use_tensorclass:
- assert isinstance(text, list)
- else:
- assert isinstance(text, NonTensorStack)
- text_response = output.get("text_response", as_list=True)
- if use_tensorclass:
- assert isinstance(text_response, list)
- else:
- assert isinstance(text_response, NonTensorStack)
- try:
- tokens_response = output.get("tokens_response")
- except Exception:
- tokens_response = output.get("tokens_response", as_list=True)
- if pad:
- assert isinstance(tokens_response, torch.Tensor)
- else:
- assert isinstance(tokens_response, list)
- try:
- tokens = output.get("tokens")
- except Exception:
- tokens = output.get("tokens", as_list=True)
- if not generate:
- assert tokens is None
- elif pad:
- assert isinstance(tokens, torch.Tensor), tokens
- else:
- assert isinstance(tokens, list)
-
- @pytest.mark.skip_if_nightly
- @pytest.mark.parametrize("from_text", [True])
- def test_vllm_collection(self, vllm_instance, from_text):
- policy = vLLMWrapper(
- vllm_instance,
- return_log_probs=True,
- generate_kwargs={"max_tokens": 32},
- from_text=from_text in (True, None),
- )
- tokenizer = vllm_instance.get_tokenizer()
- self._run_check_collector(policy, from_text=from_text, tokenizer=tokenizer)
-
- def test_transformers_collection(self):
- ...
-
- @classmethod
- def env_constructor(cls, **kwargs):
- def make():
- # if kwargs.get("from_text", True):
- dl = DummyStrDataLoader(batch_size=32)
- # else:
- # dl = DummyTensorDataLoader(batch_size=32)
- env = LLMEnv.from_dataloader(
- dl,
- batch_size=4,
- repeats=4,
- **kwargs,
- )
- assert env.batch_size == (16,)
- return env
-
- return make
-
- def _run_check_collector(self, policy, from_text, tokenizer):
- if from_text is None:
- kwargs = {"eos_token_id": tokenizer.eos_token_id}
- else:
- kwargs = {
- "from_text": from_text,
- "tokenizer": tokenizer,
- "eos_token_id": tokenizer.eos_token_id,
- }
- collector = LLMCollector(
- self.env_constructor(**kwargs),
- policy=policy,
- dialog_turns_per_batch=32,
- total_dialog_turns=128,
- )
- t = 0
- for data in collector:
- assert isinstance(data, LazyStackedTensorDict)
- assert isinstance(data.reshape(-1).get("text_response"), NonTensorStack)
- # action
- assert "text_response" in data
- assert "tokens_response" in data
- # obs
- assert "text" in data
- assert ("next", "text") in data
- # tokens
- assert "tokens" in data
-
- t += data.numel()
- assert collector._frames == t
- assert t < 512, t # assert ("next", "tokens") in data
-
- @pytest.mark.skip_if_nightly
- def test_vllm_generate_multiple_trajs(self, vllm_instance):
- policy = vLLMWrapper(
- vllm_instance,
- return_log_probs=True,
- generate_kwargs={"n": 10, "max_tokens": 1024},
- inplace=False,
- )
- data = TensorDict(
- text=NonTensorStack("a string", "another very long string"), batch_size=2
- )
- data = policy(data)
-
- @set_list_to_stack(True)
- @pytest.mark.parametrize("from_text", [True, False])
- @pytest.mark.parametrize("generate", [True, False])
- def test_transformers_long_sequences(
- self, from_text, generate, transformers_instance_pretrained
- ):
- torch.manual_seed(42)
- model, tokenizer = transformers_instance_pretrained
- prompts = [
- "The quick brown fox jumps over the lazy dog.", # Likely to finish soon
- "Once upon a time in a land far, far away, there was a", # Likely to continue longer
- "In the beginning, the universe was created. This has made a lot of people very angry and been widely regarded as a bad move.",
- ]
- data = lazy_stack([TensorDict() for _ in range(len(prompts))])
- data["text"] = prompts
- eos_token_id = tokenizer.convert_tokens_to_ids(",")
- if not from_text:
- data["tokens"] = tokenizer(data["text"])["input_ids"]
- data["attention_mask"] = (
- 0 * data.get("tokens", as_nested_tensor=True, layout=torch.strided) + 1
- )
- if not generate:
- # we need responses
- responses = prompts[1:] + [" et dolore magna aliqua."]
- data["text_response"] = responses
- if not from_text:
- data["tokens_response"] = tokenizer(data["text_response"])["input_ids"]
- # make sure dimensions are ragged for tokens entries
- if "tokens" in data:
- assert data.get_item_shape("tokens")[-1] == -1
- if "tokens_response" in data:
- assert data.get_item_shape("tokens_response")[-1] == -1
- generate_kwargs = {}
- if generate:
- generate_kwargs = {
- "max_new_tokens": 128, # Set a reasonable number of new tokens to generate
- "min_length": 20, # Ensure a minimum length for the generated sequence
- "pad_token_id": tokenizer.pad_token_id, # Use the tokenizer's pad token
- "forced_eos_token_id": eos_token_id, # Use comma as an EOS token
- }
- policy = TransformersWrapper(
- model,
- tokenizer=tokenizer,
- from_text=from_text,
- generate=generate,
- return_log_probs=True,
- # TODO: use n trajs
- generate_kwargs=generate_kwargs,
- )
- data_policy = policy(data)
- if "tokens" in data_policy:
- assert data_policy.get_item_shape("tokens")[-1] == -1
- if "tokens_response" in data_policy:
- assert (
- data_policy.get_item_shape("tokens_response")[-1] == -1
- ) # TODO: this fails
-
-
-if __name__ == "__main__":
- args, unknown = argparse.ArgumentParser().parse_known_args()
- pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
diff --git a/test/llm/test_objectives.py b/test/llm/test_objectives.py
index baf301e5f33..9dd0ffb9367 100644
--- a/test/llm/test_objectives.py
+++ b/test/llm/test_objectives.py
@@ -10,19 +10,17 @@
import numpy as np
import pytest
import torch
-from mocking_classes_llm import DummyStrDataLoader
-from tensordict import lazy_stack, set_capture_non_tensor_stack, TensorDict
-from torchrl.data import History, LazyStackStorage, ReplayBuffer, Unbounded
-from torchrl.envs import Transform
-from torchrl.envs.llm import LLMEnv
+from tensordict import lazy_stack, TensorDict
+from torchrl.data import History, LazyStackStorage, ReplayBuffer
from torchrl.envs.llm.transforms.kl import RetrieveLogProb
-from torchrl.modules.llm import TransformersWrapper
-from torchrl.objectives import ClipPPOLoss
-from torchrl.objectives.llm.grpo import GRPOLoss, GRPOLossOutput, MCAdvantage
+from torchrl.modules.llm import Text, TransformersWrapper, vLLMWrapper
+from torchrl.modules.llm.policies.common import ChatHistory, Masks, Tokens
+from torchrl.objectives.llm.grpo import MCAdvantage
from torchrl.objectives.llm.sft import SFTLoss
_has_transformers = importlib.util.find_spec("transformers") is not None
+_has_vllm = importlib.util.find_spec("vllm") is not None
prompts = [
"Lorem ipsum dolor sit amet,",
"consectetur adipiscing elit,",
@@ -55,7 +53,7 @@ def make_silly_trajectory(n_steps=None):
rewards = [torch.randn(n_tokens, 1)]
prompt = np.random.choice(prompts)
td = TensorDict(
- text=prompt,
+ text=Text(prompt=prompt),
next=TensorDict(
reward=rewards, done=torch.zeros(1, dtype=torch.bool)
),
@@ -89,80 +87,6 @@ def test_grpo():
...
-class TestPPO4LLMs:
- @pytest.mark.skipif(
- not _has_transformers, reason="transformers lib required to test PPO with LLMs"
- )
- @set_capture_non_tensor_stack(False)
- @pytest.mark.parametrize("from_text", [True, False])
- @pytest.mark.parametrize("cls", [ClipPPOLoss, GRPOLoss])
- def test_hf(self, from_text, cls):
- from transformers import AutoTokenizer, OPTConfig, OPTForCausalLM
-
- tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
- tokenizer.pad_token = tokenizer.eos_token
-
- model = OPTForCausalLM(OPTConfig()).eval()
- policy_inference = TransformersWrapper(
- model,
- tokenizer=tokenizer,
- generate=True,
- from_text=from_text,
- return_log_probs=True,
- )
- policy_train = TransformersWrapper(
- model, tokenizer=tokenizer, generate=False, from_text=False
- )
- for p in policy_train.parameters():
- assert p.requires_grad
- # Create some fake data
- dl = DummyStrDataLoader(batch_size=32)
- llm_env = LLMEnv.from_dataloader(
- dl,
- tokenizer=tokenizer if not from_text else None,
- batch_size=(32,),
- from_text=True,
- eos_token_id=tokenizer.eos_token_id,
- )
-
- class RewardTransform(Transform):
- def _step(self, td, next_td):
- next_td["reward"] = torch.randn_like(
- td["tokens_response"], dtype=torch.float
- ).unsqueeze(-1)
- return next_td
-
- def transform_reward_spec(self, reward_spec):
- return reward_spec.set(
- "reward", Unbounded((*reward_spec.shape, -1, 1), dtype=torch.float)
- )
-
- llm_env = llm_env.append_transform(RewardTransform())
- with torch.no_grad():
- data = llm_env.rollout(3, policy_inference)
- data = data.view(-1)
- assert data["tokens_response"].shape[-1] == 20
- # Make some fake advantages:
- data["advantage"] = torch.randn_like(data["next", "reward"])
-
- loss = cls(
- actor_network=policy_train,
- )
- loss_vals = loss(data)
- if cls is ClipPPOLoss:
- assert "loss_objective" in loss_vals
- assert "loss_entropy" in loss_vals
- assert loss_vals["loss_objective"].requires_grad
- assert loss_vals["loss_entropy"].requires_grad
- assert "clip_fraction" in loss_vals
- assert "kl_approx" in loss_vals
- assert "entropy" in loss_vals
- assert "ESS" in loss_vals
- assert "loss_critic" not in loss_vals
- else:
- assert isinstance(loss_vals, GRPOLossOutput)
-
-
class TestSFT:
@pytest.fixture(scope="class")
def data(self):
@@ -190,20 +114,21 @@ def data(self):
text = history[:, :-1].apply_chat_template(
tokenizer=tokenizer, chat_template_name="qwen", add_generation_prompt=True
)
- text_response = history.apply_chat_template(
+ full_text = history.apply_chat_template(
tokenizer=tokenizer, chat_template_name="qwen", add_generation_prompt=False
)
text_response = [
- txt[len(txt_start) :] for txt, txt_start in zip(text_response, text)
+ txt[len(txt_start) :] for txt, txt_start in zip(full_text, text)
]
td = TensorDict(
- text=text,
- text_response=text_response,
- history=history,
+ text=Text(prompt=text, response=text_response, full=full_text),
+ history=ChatHistory(
+ full=history, prompt=history[..., :-1], response=history[..., -1:]
+ ),
next=TensorDict(
reward=torch.randn(2, 1),
done=torch.zeros(2, dtype=torch.bool),
- history=history,
+ history=ChatHistory(prompt=history),
),
batch_size=(2,),
)
@@ -227,8 +152,9 @@ def policy_train(self):
model,
tokenizer=tokenizer,
generate=False,
- from_text=True,
chat_template_name="qwen",
+ input_mode="history",
+ pad_output=False,
)
return policy_train, tokenizer
@@ -249,8 +175,6 @@ def test_sft(
data,
policy_train,
):
- pass
-
policy_train, tokenizer = policy_train
loss = SFTLoss(
actor_network=policy_train,
@@ -269,20 +193,21 @@ def test_sft(
policy_train.model,
tokenizer=tokenizer,
generate=False,
- from_text=True,
return_log_probs=True,
chat_template_name="qwen",
+ input_mode="history",
+ pad_output=False,
)
transform = RetrieveLogProb(
policy_ref,
assistant_only=True,
tokenizer_kwargs={"chat_template_name": "qwen"},
tokenizer=tokenizer,
+ log_probs_key=("ref_log_prob", "full"),
)
with torch.no_grad():
# Compute ref log-probs
transform(td)
-
loss_vals = loss(td)
if kl_to_ref_coeff is not None and loss_function != "minor_sft":
assert loss_vals.loss_kl_to_ref.shape == ()
@@ -296,7 +221,7 @@ def test_sft(
assert loss_vals.sum(reduce=True).shape == ()
def test_sft_assistant_only(self, data):
- from torchrl.data.llm.chat import _CHAT_TEMPLATES
+ from torchrl.data.llm.history import _CHAT_TEMPLATES
from transformers import AutoTokenizer, OPTConfig, OPTForCausalLM
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
@@ -308,14 +233,12 @@ def test_sft_assistant_only(self, data):
model,
tokenizer=tokenizer,
generate=False,
- from_text=True,
chat_template_name="qwen",
)
policy_ref = TransformersWrapper(
model,
tokenizer=tokenizer,
generate=False,
- from_text=True,
return_log_probs=True,
chat_template_name="qwen",
)
@@ -324,6 +247,7 @@ def test_sft_assistant_only(self, data):
assistant_only=True,
tokenizer_kwargs={"chat_template_name": "qwen"},
tokenizer=tokenizer,
+ log_probs_key=("ref_log_prob", "full"),
)
td = transform(data)
assert td is data
@@ -338,6 +262,181 @@ def test_sft_assistant_only(self, data):
loss(td)
+class TestGRPOLossIntegration:
+ """Test GRPOLoss integration with the new distribution methods."""
+
+ @pytest.fixture(scope="module")
+ def transformers_instance(self):
+ """Create transformers model and tokenizer for testing."""
+ if not _has_transformers:
+ pytest.skip("transformers not available")
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
+ tokenizer.pad_token = tokenizer.eos_token
+ return model, tokenizer
+
+ @pytest.fixture(scope="module")
+ def vllm_instance(self):
+ """Create vLLM model and tokenizer for testing."""
+ if not _has_vllm:
+ pytest.skip("vllm not available")
+
+ import vllm.envs as envs
+ from transformers import AutoTokenizer
+ from vllm import LLM
+
+ envs.VLLM_HOST_IP = "0.0.0.0" or "127.0.0.1"
+
+ try:
+ model = LLM("Qwen/Qwen2.5-0.5B")
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
+ tokenizer.pad_token = tokenizer.eos_token
+ return model, tokenizer
+ except Exception as e:
+ pytest.skip(f"Failed to load vLLM model: {e}")
+
+ @pytest.fixture(scope="module")
+ def sample_tokens(self, vllm_instance):
+ """Create sample tokens for testing."""
+ model, tokenizer = vllm_instance
+ text = [
+ "Are you happy? Say yes or no.",
+ "Explain the difference between a cat and a dog. Be very detailed.",
+ ]
+ tokenized = tokenizer(
+ text, return_tensors="pt", padding=True, padding_side="left"
+ )
+ return tokenized["input_ids"], tokenized["attention_mask"]
+
+ @pytest.fixture(scope="module")
+ def sample_text(self):
+ """Create sample text for testing."""
+ return [
+ "Are you happy? Say yes or no.",
+ "Explain the difference between a cat and a dog. Be very detailed.",
+ ]
+
+ @pytest.fixture(scope="module")
+ def sample_history(self):
+ """Create sample conversation history for testing."""
+ chats = [
+ [
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": "Are you happy? Say yes or no."},
+ ],
+ [
+ {
+ "role": "system",
+ "content": "You are a very helpful assistant, but more handsome.",
+ },
+ {
+ "role": "user",
+ "content": "Explain the difference between a cat and a dog. Be very detailed.",
+ },
+ ],
+ ]
+ return History.from_chats(chats)
+
+ @pytest.fixture(scope="module")
+ def sample_history_assistant(self):
+ """Create sample conversation history for testing."""
+ chats = [
+ [
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": "Are you happy? Say yes or no."},
+ {"role": "assistant", "content": "Yes."},
+ ],
+ [
+ {
+ "role": "system",
+ "content": "You are a very helpful assistant, but more handsome.",
+ },
+ {
+ "role": "user",
+ "content": "Explain the difference between a cat and a dog. Be very detailed.",
+ },
+ {
+ "role": "assistant",
+ "content": "A cat is a small animal that meows, while a dog is a larger animal that barks.",
+ },
+ ],
+ ]
+ return History.from_chats(chats)
+
+ @pytest.mark.skipif(not _has_vllm, reason="vllm not available")
+ @pytest.mark.parametrize("masking_strategy", ["sft", "rlhf"])
+ def test_grpo_loss_with_transformers(
+ self,
+ vllm_instance,
+ transformers_instance,
+ sample_history,
+ sample_tokens,
+ masking_strategy,
+ ):
+ """Test GRPOLoss with vLLM wrapper and different masking strategies."""
+ from torchrl.objectives.llm.grpo import GRPOLoss
+
+ model, tokenizer = transformers_instance
+ vllm_model, vllm_tokenizer = vllm_instance
+
+ # Use tokens input mode for SFT, history for RLHF/generic
+ if masking_strategy == "sft":
+ input_mode = "tokens"
+ input_ids, attention_mask = sample_tokens
+ input_data = {
+ "tokens": Tokens(prompt=input_ids),
+ "masks": Masks(all_attention_mask=attention_mask),
+ }
+ else:
+ input_mode = "history"
+ input_data = {"history": ChatHistory(prompt=sample_history)}
+
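+        # the vLLM wrapper samples short responses (at most 10 new tokens) to build the training batch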
+ wrapper_gen = vLLMWrapper(
+ vllm_model,
+ tokenizer=vllm_tokenizer,
+ input_mode=input_mode,
+ generate=True,
+ return_log_probs=True,
+ pad_output=True,
+ generate_kwargs={"max_tokens": 10},
+ )
+
+ # Create test data with advantage and correct batch size
+ td = TensorDict(input_data, batch_size=(2,)).to_lazystack(0)
+ td = wrapper_gen(td)
+ # use a shape that can be broadcast
+ td["advantage"] = torch.randn(2, 1, 1)
+
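+        # log-prob-only transformers wrapper acts as the trainable actor passed to GRPOLoss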
+ wrapper = TransformersWrapper(
+ model,
+ tokenizer=tokenizer,
+ input_mode=input_mode,
+ generate=False,
+ return_log_probs=True,
+ pad_output=True,
+ )
+
+ # Create GRPOLoss with specified masking strategy
+ loss_fn = GRPOLoss(
+ actor_network=wrapper,
+ masking_strategy=masking_strategy,
+ )
+
+ # This should work without shape mismatch errors
+ try:
+ result = loss_fn(td)
+ assert result is not None
+ except ValueError as e:
+ if "Shape mismatch" in str(e):
+ # This is expected if the advantage shape doesn't match the log-prob shape
+ # due to different masking strategies
+ assert masking_strategy in str(e)
+ else:
+ raise
+
+
if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
diff --git a/test/llm/test_wrapper.py b/test/llm/test_wrapper.py
new file mode 100644
index 00000000000..49197952972
--- /dev/null
+++ b/test/llm/test_wrapper.py
@@ -0,0 +1,1710 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import argparse
+import importlib.util
+
+import os
+from functools import partial
+
+import pytest
+import torch
+from tensordict import lazy_stack, set_list_to_stack, TensorDict
+
+from tensordict.utils import _zip_strict
+from torchrl.data.llm import History
+from torchrl.envs.llm.transforms.kl import KLComputation, RetrieveKL, RetrieveLogProb
+from torchrl.modules.llm.policies.common import (
+ ChatHistory,
+ LogProbs,
+ Masks,
+ Text,
+ Tokens,
+)
+from torchrl.modules.llm.policies.transformers_wrapper import TransformersWrapper
+from torchrl.modules.llm.policies.vllm_wrapper import vLLMWrapper
+from transformers import AutoTokenizer
+
+
+# Set environment variable for vLLM V0 engine
+os.environ["VLLM_USE_V1"] = "0"
+
+_has_transformers = importlib.util.find_spec("transformers") is not None
+_has_vllm = importlib.util.find_spec("vllm") is not None
+_has_datasets = importlib.util.find_spec("datasets") is not None
+
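+# Convenience preset: generation capped at 10 new tokens with sampling enabled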
+TransformersWrapperMaxTokens = partial(
+ TransformersWrapper, generate_kwargs={"max_new_tokens": 10, "do_sample": True}
+)
+
+
+@pytest.fixture(scope="function", autouse=True)
+def set_seed():
+ torch.manual_seed(0)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed(0)
+ torch.cuda.manual_seed_all(0)
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+ yield
+
+
+@pytest.fixture(scope="module", autouse=True)
+def set_list_to_stack_fixture():
+ with set_list_to_stack(True):
+ yield
+
+
+@pytest.fixture(scope="module")
+def vllm_instance():
+ """Create vLLM model and tokenizer for testing."""
+ if not _has_vllm:
+ pytest.skip("vllm not available")
+
+ import vllm.envs as envs
+ from vllm import LLM
+
+ envs.VLLM_HOST_IP = "0.0.0.0"  # bind on all interfaces ("127.0.0.1" also works for local-only runs)
+
+ assert os.environ.get("VLLM_USE_V1") == "0"
+
+ try:
+ model = LLM("Qwen/Qwen2.5-0.5B")
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
+ tokenizer.pad_token = tokenizer.eos_token
+ return model, tokenizer
+ except Exception as e:
+ pytest.skip(f"Failed to load vLLM model: {e}")
+
+
+@pytest.fixture(scope="module")
+def transformers_instance():
+ """Create transformers model and tokenizer for testing."""
+ if not _has_transformers:
+ pytest.skip("transformers not available")
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
+ tokenizer.pad_token = tokenizer.eos_token
+ return model, tokenizer
+
+
+@pytest.fixture
+def sample_history():
+ """Create sample conversation history for testing."""
+ chats = [
+ [
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": "Are you happy? Say yes or no."},
+ ],
+ [
+ {
+ "role": "system",
+ "content": "You are a very helpful assistant, but more handsome.",
+ },
+ {
+ "role": "user",
+ "content": "Explain the difference between a cat and a dog. Be very detailed.",
+ },
+ ],
+ ]
+ return History.from_chats(chats)
+
+
+@pytest.fixture
+def sample_history_assistant():
+ """Create sample conversation history for testing."""
+ chats = [
+ [
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": "Are you happy? Say yes or no."},
+ {"role": "assistant", "content": "Yes."},
+ ],
+ [
+ {
+ "role": "system",
+ "content": "You are a very helpful assistant, but more handsome.",
+ },
+ {
+ "role": "user",
+ "content": "Explain the difference between a cat and a dog. Be very detailed.",
+ },
+ {
+ "role": "assistant",
+ "content": "A cat is a small animal that meows, while a dog is a larger animal that barks.",
+ },
+ ],
+ ]
+ return History.from_chats(chats)
+
+
+@pytest.fixture
+def sample_text():
+ """Create sample text for testing."""
+ return [
+ "Are you happy? Say yes or no.",
+ "Explain the difference between a cat and a dog. Be very detailed.",
+ ]
+
+
+@pytest.fixture
+def sample_tokens(vllm_instance):
+ """Create sample tokens for testing."""
+ model, tokenizer = vllm_instance
+ text = [
+ "Are you happy? Say yes or no.",
+ "Explain the difference between a cat and a dog. Be very detailed.",
+ ]
+ tokenized = tokenizer(text, return_tensors="pt", padding=True, padding_side="left")
+ return tokenized["input_ids"], tokenized["attention_mask"]
+
+
+def check_output_shapes(out, pad_output, requested_log_probs=False):
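+ """Check that the prompt/response/full entries of a wrapper output have mutually consistent shapes."""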
+ if pad_output:
+ # We can get all tensors or they are none
+ log_probs = out.get("log_probs")
+ masks = out.get("masks")
+ tokens = out.get("tokens")
+ text = out.get("text")
+ history = out.get("history")
+
+ # Test the all_ tensors
+ if log_probs is not None:
+ assert isinstance(log_probs, LogProbs)
+ all_logprobs = log_probs.full
+ else:
+ all_logprobs = None
+ if masks is not None:
+ assert isinstance(masks, Masks)
+ all_attention_masks = masks.all_attention_mask
+ all_assistant_masks = masks.all_assistant_mask
+ else:
+ all_attention_masks = None
+ all_assistant_masks = None
+ if tokens is not None:
+ assert isinstance(tokens, Tokens)
+ all_tokens = tokens.full
+ else:
+ all_tokens = None
+ if text is not None:
+ assert isinstance(text, Text)
+ text.full  # attribute access should not raise
+ if history is not None:
+ assert isinstance(history, ChatHistory)
+ history.full  # attribute access should not raise
+
+ shapes = set()
+ if all_logprobs is not None:
+ shapes.add(all_logprobs.shape)
+ if all_attention_masks is not None:
+ shapes.add(all_attention_masks.shape)
+ if all_assistant_masks is not None:
+ shapes.add(all_assistant_masks.shape)
+ if all_tokens is not None:
+ shapes.add(all_tokens.shape)
+ assert len(shapes) <= 1, ("all_tensors shapes differ", out)
+
+ # Check the response tensors
+ shapes = set()
+ if log_probs is not None and log_probs.response is not None:
+ shapes.add(log_probs.response.shape)
+ if tokens is not None and tokens.response is not None:
+ shapes.add(tokens.response.shape)
+ assert len(shapes) <= 1, (shapes, out)
+
+ # Check the prompt tensors
+ shapes = set()
+ if log_probs is not None and log_probs.prompt is not None:
+ shapes.add(log_probs.prompt.shape)
+ if tokens is not None and tokens.prompt is not None:
+ shapes.add(tokens.prompt.shape)
+
+ if (
+ log_probs is not None
+ and log_probs.response is not None
+ and log_probs.prompt is not None
+ ):
+ assert (
+ log_probs.response.shape[-1] + log_probs.prompt.shape[-1]
+ == log_probs.full.shape[-1]
+ )
+ if (
+ tokens is not None
+ and tokens.response is not None
+ and tokens.prompt is not None
+ ):
+ assert (
+ tokens.response.shape[-1] + tokens.prompt.shape[-1]
+ == tokens.full.shape[-1]
+ )
+
+ assert len(shapes) <= 1, shapes
+
+ # Check that if 'full' is defined, either both 'prompt' and 'response' must be set or neither of them
+ if requested_log_probs:
+ for obj_name, obj in [
+ ("log_probs", log_probs),
+ ("tokens", tokens),
+ ("text", text),
+ ]:
+ if obj is not None and obj.get("full", as_list=True) is not None:
+ has_prompt = obj.get("prompt", as_list=True) is not None
+ has_response = obj.get("response", as_list=True) is not None
+ assert (has_prompt and has_response) or (
+ not has_prompt and not has_response
+ ), (
+ f"{obj_name}: if 'full' is defined, either both 'prompt' and 'response' must be set or neither of them. "
+ f"prompt={has_prompt}, response={has_response}, full={obj.full is not None}"
+ )
+ else:
+ # unpadded output: iterate over the batch and check each element individually
+ for _out in out.unbind(0):
+ check_output_shapes(
+ _out, pad_output=not _out.ndim, requested_log_probs=requested_log_probs
+ )
+
+
+@pytest.mark.skipif(not _has_vllm, reason="vllm not available")
+class TestWrappers:
+ """Comprehensive tests for vLLMWrapper and TransformersWrapper covering all modalities and configurations."""
+
+ # ================================================
+ # History Input Mode Tests
+ # ================================================
+
+ @pytest.mark.parametrize(
+ "wrapper_class",
+ [vLLMWrapper, TransformersWrapperMaxTokens],
+ ids=["vllm", "transformers"],
+ )
+ @pytest.mark.parametrize("generate", [True, False], ids=["generate", "no_generate"])
+ @pytest.mark.parametrize("pad_output", [True, False], ids=["padded", "unpadded"])
+ def test_history_input_mode(
+ self,
+ wrapper_class,
+ vllm_instance,
+ transformers_instance,
+ sample_history,
+ sample_history_assistant,
+ generate,
+ pad_output,
+ ):
+ """Test history input mode with various configurations."""
+
+ if wrapper_class == vLLMWrapper:
+ model, tokenizer = vllm_instance
+ else:
+ model, tokenizer = transformers_instance
+ wrapper = wrapper_class(
+ model,
+ tokenizer=tokenizer,
+ input_mode="history",
+ generate=generate,
+ pad_output=pad_output,
+ )
+
+ # Check input keys
+ if generate:
+ assert wrapper.in_keys == [("history", "prompt")]
+ else:
+ assert wrapper.in_keys == [("history", "full")]
+
+ # Check output keys - always return everything
+ expected_out_keys = ["text", "masks", "tokens", "log_probs", "history"]
+ assert wrapper.out_keys == expected_out_keys
+
+ # Create input data
+ if generate:
+ data = TensorDict(
+ history=ChatHistory(prompt=sample_history),
+ batch_size=(2,),
+ )
+ else:
+ data = TensorDict(
+ history=ChatHistory(full=sample_history_assistant),
+ batch_size=(2,),
+ )
+
+ # Run wrapper
+ result = wrapper(data)
+ check_output_shapes(result, pad_output, requested_log_probs=not generate)
+
+ # Check output structure
+ for key in expected_out_keys:
+ assert key in result
+ assert hasattr(result[key], "__class__")
+
+ # Check specific outputs - always check everything
+ text_obj = result["text"]
+ assert hasattr(text_obj, "prompt")
+ assert hasattr(text_obj, "response")
+ assert hasattr(text_obj, "full")
+
+ if generate:
+ assert text_obj.response is not None
+ assert isinstance(text_obj.response, list)
+ assert isinstance(text_obj.response[0], str)
+
+ tokens_obj = result["tokens"]
+ if pad_output:
+ assert hasattr(tokens_obj, "prompt")
+ assert hasattr(tokens_obj, "response")
+ assert hasattr(tokens_obj, "full")
+ assert hasattr(tokens_obj, "padded")
+ assert all(tokens_obj.padded) == pad_output
+
+ if generate:
+ if pad_output:
+ assert tokens_obj.response is not None
+ else:
+ assert tokens_obj.get("response", as_list=True) is not None
+ if not pad_output:
+ response_tokens = result["tokens"].get("response", as_list=True)
+ assert isinstance(response_tokens, list)
+ else:
+ assert isinstance(tokens_obj.response, torch.Tensor)
+
+ masks_obj = result["masks"]
+ if pad_output:
+ assert hasattr(masks_obj, "all_attention_mask")
+ assert hasattr(masks_obj, "all_assistant_mask")
+ assert hasattr(masks_obj, "padded")
+ assert all(masks_obj.padded) == pad_output
+
+ log_probs_obj = result["log_probs"]
+ if pad_output:
+ assert hasattr(log_probs_obj, "prompt")
+ assert hasattr(log_probs_obj, "response")
+ assert hasattr(log_probs_obj, "full")
+ assert hasattr(log_probs_obj, "padded")
+ assert all(log_probs_obj.padded) == pad_output
+
+ # ================================================
+ # Text Input Mode Tests
+ # ================================================
+
+ @pytest.mark.parametrize(
+ "wrapper_class",
+ [vLLMWrapper, TransformersWrapperMaxTokens],
+ ids=["vllm", "transformers"],
+ )
+ @pytest.mark.parametrize("generate", [True, False], ids=["generate", "no_generate"])
+ @pytest.mark.parametrize("pad_output", [True, False], ids=["padded", "unpadded"])
+ def test_text_input_mode(
+ self,
+ wrapper_class,
+ vllm_instance,
+ transformers_instance,
+ sample_text,
+ generate,
+ pad_output,
+ ):
+ """Test text input mode with various configurations."""
+
+ if wrapper_class == vLLMWrapper:
+ model, tokenizer = vllm_instance
+ else:
+ model, tokenizer = transformers_instance
+ wrapper = wrapper_class(
+ model,
+ tokenizer=tokenizer,
+ input_mode="text",
+ generate=generate,
+ pad_output=pad_output,
+ )
+
+ # Check input keys
+ if generate:
+ assert wrapper.in_keys == [("text", "prompt")]
+ else:
+ assert wrapper.in_keys == [("text", "full")]
+
+ # Create input data
+ if generate:
+ data = TensorDict(text=Text(prompt=sample_text), batch_size=(2,))
+ else:
+ data = TensorDict(text=Text(full=sample_text), batch_size=(2,))
+
+ # Run wrapper
+ result = wrapper(data)
+ check_output_shapes(result, pad_output, requested_log_probs=not generate)
+
+ # Check output structure - always return everything
+ expected_keys = ["text", "masks", "tokens", "log_probs"]
+ for key in expected_keys:
+ assert key in result
+
+ # Check text output
+ text_obj = result["text"]
+ if generate:
+ assert text_obj.prompt == sample_text
+ else:
+ assert text_obj.full == sample_text
+ if generate:
+ assert text_obj.response is not None
+
+ # Check tokens output
+ tokens_obj = result["tokens"]
+ if generate:
+ if not pad_output:
+ response_tokens = tokens_obj.get("response", as_list=True)
+ assert isinstance(response_tokens, list)
+ else:
+ assert isinstance(tokens_obj.response, torch.Tensor)
+
+ # ================================================
+ # Tokens Input Mode Tests
+ # ================================================
+
+ @pytest.mark.parametrize(
+ "wrapper_class",
+ [vLLMWrapper, TransformersWrapperMaxTokens],
+ ids=["vllm", "transformers"],
+ )
+ @pytest.mark.parametrize("generate", [True, False], ids=["generate", "no_generate"])
+ @pytest.mark.parametrize("pad_output", [True, False], ids=["padded", "unpadded"])
+ def test_tokens_input_mode(
+ self,
+ wrapper_class,
+ vllm_instance,
+ transformers_instance,
+ sample_tokens,
+ generate,
+ pad_output,
+ ):
+ """Test tokens input mode with various configurations."""
+ if wrapper_class == vLLMWrapper:
+ model, tokenizer = vllm_instance
+ else:
+ model, tokenizer = transformers_instance
+
+ input_ids, attention_mask = sample_tokens
+
+ wrapper = wrapper_class(
+ model,
+ tokenizer=tokenizer,
+ input_mode="tokens",
+ attention_mask_key="attention_mask",
+ generate=generate,
+ pad_output=pad_output,
+ )
+
+ # Check input keys
+ if generate:
+ assert wrapper.in_keys == [("tokens", "prompt")]
+ else:
+ assert wrapper.in_keys == [("tokens", "full")]
+
+ # Create input data
+ data = TensorDict(
+ tokens=Tokens(prompt=input_ids) if generate else Tokens(full=input_ids),
+ attention_mask=attention_mask,
+ batch_size=(2,),
+ )
+
+ # Run wrapper
+ result = wrapper(data)
+ check_output_shapes(result, pad_output, requested_log_probs=not generate)
+
+ # Check output structure
+ expected_keys = ["masks", "tokens", "log_probs"]
+ for key in expected_keys:
+ assert key in result
+
+ # Check tokens output
+ tokens_obj = result["tokens"]
+ if generate:
+ if not pad_output:
+ response_tokens = result["tokens"].get("response", as_list=True)
+ assert isinstance(response_tokens, list)
+ else:
+ assert isinstance(tokens_obj.response, torch.Tensor)
+
+ # ================================================
+ # Error Handling Tests
+ # ================================================
+
+ @pytest.mark.parametrize(
+ "wrapper_class",
+ [vLLMWrapper, TransformersWrapperMaxTokens],
+ ids=["vllm", "transformers"],
+ )
+ def test_invalid_input_mode(
+ self, wrapper_class, vllm_instance, transformers_instance
+ ):
+ """Test that invalid input_mode raises an error."""
+ if wrapper_class == vLLMWrapper:
+ model, tokenizer = vllm_instance
+ else:
+ model, tokenizer = transformers_instance
+
+ with pytest.raises(ValueError, match="input_mode must be one of"):
+ wrapper_class(
+ model,
+ tokenizer=tokenizer,
+ input_mode="invalid_mode",
+ )
+
+ @pytest.mark.parametrize(
+ "wrapper_class",
+ [vLLMWrapper, TransformersWrapperMaxTokens],
+ ids=["vllm", "transformers"],
+ )
+ def test_missing_input_key(
+ self, wrapper_class, vllm_instance, transformers_instance, sample_history
+ ):
+ """Test that missing input key raises an error."""
+ if wrapper_class == vLLMWrapper:
+ model, tokenizer = vllm_instance
+ else:
+ model, tokenizer = transformers_instance
+
+ wrapper = wrapper_class(
+ model,
+ tokenizer=tokenizer,
+ input_mode="history",
+ input_key="history",
+ )
+
+ # Create data without the required key
+ data = TensorDict(batch_size=(2,))
+
+ with pytest.raises(ValueError, match="Expected 'history' key"):
+ wrapper(data)
+
+ @pytest.mark.parametrize(
+ "wrapper_class",
+ [vLLMWrapper, TransformersWrapperMaxTokens],
+ ids=["vllm", "transformers"],
+ )
+ def test_invalid_history_type(
+ self, wrapper_class, vllm_instance, transformers_instance
+ ):
+ """Test that invalid history type raises an error."""
+ if wrapper_class == vLLMWrapper:
+ model, tokenizer = vllm_instance
+ else:
+ model, tokenizer = transformers_instance
+
+ wrapper = wrapper_class(
+ model,
+ tokenizer=tokenizer,
+ input_mode="history",
+ )
+
+ # Create data with wrong type
+ data = TensorDict(
+ history=ChatHistory(prompt="not a history object"), batch_size=(2,)
+ )
+
+ with pytest.raises(TypeError, match="Expected History object"):
+ wrapper(data)
+
+ @pytest.mark.parametrize(
+ "wrapper_class",
+ [vLLMWrapper, TransformersWrapperMaxTokens],
+ ids=["vllm", "transformers"],
+ )
+ def test_generate_false_without_log_probs(
+ self, wrapper_class, vllm_instance, transformers_instance
+ ):
+ """Test that generate=False without return_log_probs=True raises an error."""
+ if wrapper_class == vLLMWrapper:
+ model, tokenizer = vllm_instance
+ else:
+ model, tokenizer = transformers_instance
+
+ with pytest.raises(ValueError, match="return_log_probs must be True"):
+ wrapper_class(
+ model,
+ tokenizer=tokenizer,
+ generate=False,
+ return_log_probs=False,
+ )
+
+ # ================================================
+ # Batch Size Tests
+ # ================================================
+
+ @pytest.mark.parametrize(
+ "batch_size", [1, 2, 3], ids=["batch_size_1", "batch_size_2", "batch_size_3"]
+ )
+ @pytest.mark.parametrize("pad_output", [True, False], ids=["padded", "unpadded"])
+ @pytest.mark.parametrize(
+ "wrapper_class",
+ [vLLMWrapper, TransformersWrapperMaxTokens],
+ ids=["vllm", "transformers"],
+ )
+ def test_batch_sizes(
+ self,
+ wrapper_class,
+ vllm_instance,
+ transformers_instance,
+ batch_size,
+ pad_output,
+ ):
+ """Test wrapper with different batch sizes."""
+ if wrapper_class == vLLMWrapper:
+ model, tokenizer = vllm_instance
+ else:
+ model, tokenizer = transformers_instance
+
+ # Create history with specified batch size
+ chats = [
+ [
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": f"Question {i}?"},
+ ]
+ for i in range(batch_size)
+ ]
+ history = History.from_chats(chats)
+
+ wrapper = wrapper_class(
+ model,
+ tokenizer=tokenizer,
+ input_mode="history",
+ generate=True,
+ return_log_probs=True,
+ pad_output=pad_output,
+ )
+
+ data = TensorDict(history=ChatHistory(prompt=history), batch_size=(batch_size,))
+ result = wrapper(data)
+ check_output_shapes(
+ result, pad_output=wrapper.pad_output, requested_log_probs=False
+ )
+
+ # Check that all expected keys are present
+ expected_keys = ["text", "masks", "tokens", "log_probs"]
+ for key in expected_keys:
+ assert key in result
+
+ # Check batch size consistency
+ if pad_output:
+ # For padded output, tensors should have the correct batch dimension
+ assert len(result["text"].response) == batch_size
+ assert len(result["tokens"].response) == batch_size
+ else:
+ # For unpadded output, use as_list=True to get lists
+ response_text = result["text"].get("response", as_list=True)
+ response_tokens = result["tokens"].get("response", as_list=True)
+ assert len(response_text) == batch_size
+ assert len(response_tokens) == batch_size
+
+ # ================================================
+ # Custom Input Key Tests
+ # ================================================
+
+ @pytest.mark.parametrize(
+ "wrapper_class",
+ [vLLMWrapper, TransformersWrapperMaxTokens],
+ ids=["vllm", "transformers"],
+ )
+ def test_custom_input_key(
+ self, wrapper_class, vllm_instance, transformers_instance, sample_history
+ ):
+ """Test wrapper with custom input key."""
+ if wrapper_class == vLLMWrapper:
+ model, tokenizer = vllm_instance
+ else:
+ model, tokenizer = transformers_instance
+
+ wrapper = wrapper_class(
+ model,
+ tokenizer=tokenizer,
+ input_mode="history",
+ input_key=("custom_history_key", "prompt"),
+ generate=True,
+ return_log_probs=True,
+ )
+
+ # Check input keys
+ assert wrapper.in_keys == [("custom_history_key", "prompt")]
+
+ # Create data with custom key
+ data = TensorDict(
+ custom_history_key=ChatHistory(prompt=sample_history), batch_size=(2,)
+ )
+ result = wrapper(data)
+ check_output_shapes(
+ result, pad_output=wrapper.pad_output, requested_log_probs=False
+ )
+
+ # Check that wrapper works correctly
+ expected_keys = ["text", "masks", "tokens", "log_probs"]
+ for key in expected_keys:
+ assert key in result
+
+ # ================================================
+ # Selective Output Tests
+ # ================================================
+
+ @pytest.mark.parametrize(
+ "return_log_probs", [True, False], ids=["log_probs", "no_log_probs"]
+ )
+ @pytest.mark.parametrize(
+ "wrapper_class",
+ [vLLMWrapper, TransformersWrapperMaxTokens],
+ ids=["vllm", "transformers"],
+ )
+ def test_selective_outputs(
+ self,
+ wrapper_class,
+ vllm_instance,
+ transformers_instance,
+ sample_history,
+ return_log_probs,
+ ):
+ """Test wrapper with selective output configurations."""
+ if wrapper_class == vLLMWrapper:
+ model, tokenizer = vllm_instance
+ else:
+ model, tokenizer = transformers_instance
+
+ wrapper = wrapper_class(
+ model,
+ tokenizer=tokenizer,
+ input_mode="history",
+ generate=True,
+ return_log_probs=return_log_probs,
+ )
+
+ # Check output keys
+ expected_out_keys = []
+ if wrapper.return_text:
+ expected_out_keys.append("text")
+ if wrapper.return_masks:
+ expected_out_keys.append("masks")
+ if wrapper.return_tokens:
+ expected_out_keys.append("tokens")
+ if return_log_probs:
+ expected_out_keys.append("log_probs")
+ if wrapper.return_history:
+ expected_out_keys.append("history")
+
+ assert wrapper.out_keys == expected_out_keys
+
+ # Run wrapper
+ data = TensorDict(history=ChatHistory(prompt=sample_history), batch_size=(2,))
+ result = wrapper(data)
+ check_output_shapes(
+ result, pad_output=wrapper.pad_output, requested_log_probs=False
+ )
+
+ # Check that only expected keys are present
+ for key in expected_out_keys:
+ assert key in result
+
+ # Check that unexpected keys are not present
+ all_possible_keys = ["text", "masks", "tokens", "log_probs"]
+ for key in all_possible_keys:
+ if key not in expected_out_keys:
+ assert key not in result
+
+ # ================================================
+ # Log-probs Only Mode Tests
+ # ================================================
+
+ @pytest.mark.parametrize(
+ "wrapper_class",
+ [vLLMWrapper, TransformersWrapperMaxTokens],
+ ids=["vllm", "transformers"],
+ )
+ def test_log_probs_only_mode(
+ self,
+ wrapper_class,
+ vllm_instance,
+ transformers_instance,
+ sample_history_assistant,
+ ):
+ """Test wrapper in log-probs only mode (generate=False)."""
+ if wrapper_class == vLLMWrapper:
+ model, tokenizer = vllm_instance
+ else:
+ model, tokenizer = transformers_instance
+
+ wrapper = wrapper_class(
+ model,
+ tokenizer=tokenizer,
+ input_mode="history",
+ generate=False, # Only compute log-probs
+ return_log_probs=True, # Must be True when generate=False
+ )
+
+ data = TensorDict(
+ history=ChatHistory(full=sample_history_assistant), batch_size=(2,)
+ )
+ result = wrapper(data)
+ check_output_shapes(
+ result, pad_output=wrapper.pad_output, requested_log_probs=True
+ )
+
+ # Check that log_probs are present
+ assert "log_probs" in result
+
+ # Check that response_text is None (no generation)
+ assert result["text"].response is None
+
+ # Check that the full-sequence log-probs are present
+ log_probs_obj = result["log_probs"]
+ assert log_probs_obj.get("full", as_list=True) is not None
+
+ # ================================================
+ # TensorClass Structure Tests
+ # ================================================
+
+ @pytest.mark.parametrize(
+ "wrapper_class",
+ [vLLMWrapper, TransformersWrapperMaxTokens],
+ ids=["vllm", "transformers"],
+ )
+ def test_tensorclass_structure(
+ self, wrapper_class, vllm_instance, transformers_instance, sample_history
+ ):
+ """Test that TensorClass objects have the correct structure."""
+ if wrapper_class == vLLMWrapper:
+ model, tokenizer = vllm_instance
+ else:
+ model, tokenizer = transformers_instance
+ pad_output = False
+
+ wrapper = wrapper_class(
+ model,
+ tokenizer=tokenizer,
+ input_mode="history",
+ generate=True,
+ return_log_probs=True,
+ )
+
+ data = TensorDict(history=ChatHistory(prompt=sample_history), batch_size=(2,))
+ result = wrapper(data)
+
+ # Test Text TensorClass
+ text_obj = result["text"]
+ assert hasattr(text_obj, "prompt")
+ assert hasattr(text_obj, "response")
+ assert hasattr(text_obj, "full")
+
+ # Test Tokens TensorClass
+ tokens_obj = result["tokens"]
+ if pad_output:
+ # if not padded, we will fail to stack
+ assert hasattr(tokens_obj, "prompt")
+ assert hasattr(tokens_obj, "response")
+ assert hasattr(tokens_obj, "full")
+ assert hasattr(tokens_obj, "padded")
+ else:
+ assert (
+ tokens_obj.get("response", as_list=True) is not None
+ ) # if not padded, we will fail to stack
+
+ # Test LogProbs TensorClass
+ log_probs_obj = result["log_probs"]
+ if pad_output:
+ # if not padded, we will fail to stack
+ assert hasattr(log_probs_obj, "prompt")
+ assert hasattr(log_probs_obj, "response")
+ assert hasattr(log_probs_obj, "full")
+ assert hasattr(log_probs_obj, "padded")
+ else:
+ assert (
+ log_probs_obj.get("response", as_list=True) is not None
+ ) # if not padded, we will fail to stack
+
+ # Test Masks TensorClass
+ masks_obj = result["masks"]
+ if pad_output:
+ # if not padded, we will fail to stack
+ assert hasattr(masks_obj, "all_attention_mask")
+ assert hasattr(masks_obj, "all_assistant_mask")
+ assert hasattr(masks_obj, "padded")
+
+ # ================================================
+ # Unpadded Output Tests (with as_list=True)
+ # ================================================
+
+ @pytest.mark.parametrize(
+ "wrapper_class",
+ [vLLMWrapper, TransformersWrapperMaxTokens],
+ ids=["vllm", "transformers"],
+ )
+ def test_unpadded_output_with_as_list(
+ self, wrapper_class, vllm_instance, transformers_instance, sample_history
+ ):
+ """Test unpadded output using as_list=True to avoid stacking issues."""
+ if wrapper_class == vLLMWrapper:
+ model, tokenizer = vllm_instance
+ else:
+ model, tokenizer = transformers_instance
+
+ wrapper = wrapper_class(
+ model,
+ tokenizer=tokenizer,
+ input_mode="history",
+ generate=True,
+ return_log_probs=True,
+ pad_output=False, # Unpadded output
+ )
+
+ data = TensorDict(history=ChatHistory(prompt=sample_history), batch_size=(2,))
+ result = wrapper(data)
+ check_output_shapes(
+ result, pad_output=wrapper.pad_output, requested_log_probs=False
+ )
+
+ # Use as_list=True to get lists instead of trying to stack
+ text_list = result.get("text", as_list=True)
+ tokens_list = result.get("tokens", as_list=True)
+ masks_list = result.get("masks", as_list=True)
+ log_probs_list = result.get("log_probs", as_list=True)
+
+ # Check that we get lists
+ assert isinstance(text_list.response, list)
+ assert isinstance(tokens_list.get("response", as_list=True), list)
+ assert isinstance(log_probs_list.get("response", as_list=True), list)
+
+ # Check list lengths
+ assert len(text_list.response) == 2
+ assert len(tokens_list.get("response", as_list=True)) == 2
+ assert len(log_probs_list.get("response", as_list=True)) == 2
+
+ # Check that individual elements have the expected types
+ assert isinstance(text_list.response[0], str)
+ assert isinstance(tokens_list.get("response", as_list=True)[0], torch.Tensor)
+ assert isinstance(log_probs_list.get("response", as_list=True)[0], torch.Tensor)
+
+ @pytest.mark.parametrize("num_samples", [2], ids=["num_samples_2"])
+ @pytest.mark.parametrize("pad_output", [True, False], ids=["padded", "unpadded"])
+ @pytest.mark.parametrize(
+ "return_log_probs", [True, False], ids=["log_probs", "no_log_probs"]
+ )
+ @pytest.mark.parametrize(
+ "input_mode", ["history", "text", "tokens"], ids=["history", "text", "tokens"]
+ )
+ @pytest.mark.parametrize(
+ "wrapper_class",
+ [vLLMWrapper, TransformersWrapperMaxTokens],
+ ids=["vllm", "transformers"],
+ )
+ def test_num_samples(
+ self,
+ wrapper_class,
+ vllm_instance,
+ transformers_instance,
+ sample_history,
+ sample_text,
+ sample_tokens,
+ num_samples,
+ pad_output,
+ return_log_probs,
+ input_mode,
+ ):
+ """Test wrapper with num_samples."""
+ if wrapper_class == vLLMWrapper:
+ model, tokenizer = vllm_instance
+ else:
+ model, tokenizer = transformers_instance
+
+ wrapper = wrapper_class(
+ model,
+ tokenizer=tokenizer,
+ input_mode=input_mode,
+ generate=True,
+ return_log_probs=return_log_probs,
+ pad_output=pad_output,
+ num_samples=num_samples,
+ )
+ if input_mode == "history":
+ data = TensorDict(
+ history=ChatHistory(prompt=sample_history), batch_size=(2,)
+ )
+ elif input_mode == "text":
+ data = TensorDict(text=Text(prompt=sample_text), batch_size=(2,))
+ elif input_mode == "tokens":
+ data = TensorDict(tokens=Tokens(prompt=sample_tokens[0]), batch_size=(2,))
+ else:
+ raise ValueError(f"Invalid input mode: {input_mode}")
+ result = wrapper(data)
+ assert result.batch_size == (2, num_samples)
+ check_output_shapes(
+ result, pad_output=wrapper.pad_output, requested_log_probs=False
+ )
+
+
+class TestKLTransforms:
+ """Comprehensive tests for KL-related transforms with different input modes and configurations."""
+
+ @pytest.mark.skipif(not _has_transformers, reason="transformers not available")
+ @pytest.mark.parametrize("pad_output", [True, False], ids=["padded", "unpadded"])
+ @pytest.mark.parametrize(
+ "assistant_only", [True, False], ids=["assistant_only", "all_tokens"]
+ )
+ @pytest.mark.parametrize(
+ "input_mode", ["history", "text", "tokens"], ids=["history", "text", "tokens"]
+ )
+ def test_retrieve_log_prob_input_modes(
+ self,
+ transformers_instance,
+ sample_history_assistant,
+ sample_text,
+ sample_tokens,
+ pad_output,
+ assistant_only,
+ input_mode,
+ ):
+ """Test RetrieveLogProb with different input modes and assistant_only settings."""
+ model, tokenizer = transformers_instance
+
+ # Skip invalid combinations
+ if assistant_only and input_mode != "history":
+ pytest.skip("assistant_only=True requires input_mode='history'")
+
+ # Create test data based on input mode
+ if input_mode == "history":
+ history = sample_history_assistant
+ data = TensorDict(history=ChatHistory(full=history), batch_size=(2,))
+ elif input_mode == "text":
+ history = None # Not used in text mode
+ prompts = sample_text
+ data = TensorDict(text=Text(full=prompts), batch_size=(2,))
+ elif input_mode == "tokens":
+ history = None # Not used in tokens mode
+ prompts = sample_tokens
+ data = TensorDict(
+ tokens=Tokens(full=prompts[0]),
+ masks=Masks(all_attention_mask=prompts[1]),
+ batch_size=(2,),
+ )
+ else:
+ raise ValueError(f"Invalid input_mode: {input_mode}")
+
+ # Create reference model with appropriate input mode
+ ref_model = TransformersWrapper(
+ model,
+ tokenizer=tokenizer,
+ input_mode=input_mode,
+ generate=False,
+ pad_output=pad_output,
+ )
+
+ # Create RetrieveLogProb transform
+ transform = RetrieveLogProb(
+ ref_model,
+ assistant_only=assistant_only,
+ tokenizer=tokenizer,
+ )
+
+ # Apply transform
+ result = transform(data)
+
+ # The log-probs key should be based on the model's log_probs_key
+ log_probs_key = (ref_model.log_probs_key, "full")
+ assert log_probs_key in result
+
+ # Check log-probs structure
+ if pad_output:
+ log_probs = result.get(log_probs_key)
+ assert isinstance(log_probs, torch.Tensor)
+ assert log_probs.shape[0] == 2 # batch size
+ else:
+ # For unpadded output, we get a list of tensors
+ log_probs = result.get(log_probs_key, as_list=True)
+ assert isinstance(log_probs, list)
+ assert len(log_probs) == 2 # batch size
+
+ @pytest.mark.skipif(not _has_transformers, reason="transformers not available")
+ @pytest.mark.parametrize("pad_output", [True, False], ids=["padded", "unpadded"])
+ @pytest.mark.parametrize(
+ "assistant_only", [True, False], ids=["assistant_only", "all_tokens"]
+ )
+ @pytest.mark.parametrize(
+ "input_mode", ["history", "text", "tokens"], ids=["history", "text", "tokens"]
+ )
+ def test_retrieve_kl_input_modes(
+ self,
+ transformers_instance,
+ sample_history_assistant,
+ sample_text,
+ sample_tokens,
+ pad_output,
+ assistant_only,
+ input_mode,
+ ):
+ """Test RetrieveKL with different input modes and assistant_only settings."""
+ model, tokenizer = transformers_instance
+
+ # Skip invalid combinations
+ if assistant_only and input_mode != "history":
+ pytest.skip("assistant_only=True requires input_mode='history'")
+
+ # Create test data based on input mode
+ if input_mode == "history":
+ history = sample_history_assistant
+ data = TensorDict(history=ChatHistory(full=history), batch_size=(2,))
+ elif input_mode == "text":
+ history = None # Not used in text mode
+ prompts = sample_text
+ data = TensorDict(text=Text(full=prompts), batch_size=(2,))
+ elif input_mode == "tokens":
+ history = None # Not used in tokens mode
+ prompts = sample_tokens
+ data = TensorDict(
+ tokens=Tokens(full=prompts[0]),
+ masks=Masks(all_attention_mask=prompts[1]),
+ batch_size=(2,),
+ )
+ else:
+ raise ValueError(f"Invalid input_mode: {input_mode}")
+
+ # Create generation and reference models with appropriate input mode
+ gen_model = TransformersWrapper(
+ model,
+ tokenizer=tokenizer,
+ input_mode=input_mode,
+ generate=False,
+ pad_output=pad_output,
+ log_probs_key="gen_log_probs",
+ )
+
+ ref_model = TransformersWrapper(
+ model,
+ tokenizer=tokenizer,
+ input_mode=input_mode,
+ generate=False,
+ pad_output=pad_output,
+ log_probs_key="ref_log_probs",
+ )
+
+ # Create RetrieveKL transform
+ transform = RetrieveKL(
+ gen_model=gen_model,
+ ref_model=ref_model,
+ assistant_only=assistant_only,
+ tokenizer=tokenizer,
+ )
+
+ # Apply transform
+ data = data.to_lazystack(0)
+ result = transform(data)
+
+ # Check that both log-probs and KL are present
+ assert ("gen_log_probs", "full") in result
+ assert ("ref_log_probs", "full") in result
+ assert "kl" in result
+
+ # Check KL structure
+ if pad_output:
+ kl = result.get("kl")
+ assert isinstance(kl, torch.Tensor)
+ assert kl.shape[0] == 2 # batch size
+ else:
+ kl = result.get("kl", as_list=True)
+ # For unpadded output, we get a list of tensors
+ assert isinstance(kl, list)
+ assert len(kl) == 2 # batch size
+
+ @pytest.mark.skipif(not _has_transformers, reason="transformers not available")
+ def test_retrieve_log_prob_assistant_only_validation(
+ self, transformers_instance, sample_text
+ ):
+ """Test that assistant_only=True with non-history input_mode raises an error."""
+ model, tokenizer = transformers_instance
+
+ # Create reference model with text input mode
+ ref_model = TransformersWrapper(
+ model,
+ tokenizer=tokenizer,
+ input_mode="text",
+ generate=False,
+ return_log_probs=True,
+ pad_output=True,
+ )
+
+ # This should raise an error
+ with pytest.raises(
+ ValueError, match="The model must have `input_mode='history'` when"
+ ):
+ RetrieveLogProb(
+ ref_model,
+ assistant_only=True, # This should fail with text input_mode
+ tokenizer=tokenizer,
+ )
+
+ @pytest.mark.skipif(not _has_transformers, reason="transformers not available")
+ def test_retrieve_kl_assistant_only_validation(
+ self, transformers_instance, sample_text
+ ):
+ """Test that assistant_only=True with non-history input_mode raises an error."""
+ model, tokenizer = transformers_instance
+
+ # Create models with text input mode
+ gen_model = TransformersWrapper(
+ model,
+ tokenizer=tokenizer,
+ input_mode="text",
+ generate=False,
+ return_log_probs=True,
+ pad_output=True,
+ log_probs_key="gen_log_probs",
+ )
+
+ ref_model = TransformersWrapper(
+ model,
+ tokenizer=tokenizer,
+ input_mode="text",
+ generate=False,
+ return_log_probs=True,
+ pad_output=True,
+ log_probs_key="ref_log_probs",
+ )
+
+ # This should raise an error
+ with pytest.raises(
+ ValueError, match="The model must have `input_mode='history'` when"
+ ):
+ RetrieveKL(
+ gen_model=gen_model,
+ ref_model=ref_model,
+ assistant_only=True, # This should fail with text input_mode
+ tokenizer=tokenizer,
+ )
+
+ @pytest.mark.skipif(not _has_transformers, reason="transformers not available")
+ @pytest.mark.parametrize("pad_output", [True, False], ids=["padded", "unpadded"])
+ def test_retrieve_kl_pad_output_consistency(
+ self, transformers_instance, sample_history_assistant, pad_output
+ ):
+ """Test that RetrieveKL enforces pad_output consistency between models."""
+ model, tokenizer = transformers_instance
+
+ # Create models with different pad_output settings
+ gen_model = TransformersWrapper(
+ model,
+ tokenizer=tokenizer,
+ input_mode="history",
+ generate=False,
+ return_log_probs=True,
+ pad_output=pad_output,
+ log_probs_key="gen_log_probs",
+ )
+
+ ref_model = TransformersWrapper(
+ model,
+ tokenizer=tokenizer,
+ input_mode="history",
+ generate=False,
+ return_log_probs=True,
+ pad_output=not pad_output, # Different pad_output setting
+ log_probs_key="ref_log_probs",
+ )
+
+ # This should raise an error
+ with pytest.raises(ValueError, match="pad_output mismatch"):
+ RetrieveKL(
+ gen_model=gen_model,
+ ref_model=ref_model,
+ assistant_only=False,
+ tokenizer=tokenizer,
+ )
+
+ @pytest.mark.skipif(not _has_transformers, reason="transformers not available")
+ @pytest.mark.parametrize("pad_output", [True, False], ids=["padded", "unpadded"])
+ def test_kl_computation_transform(
+ self, transformers_instance, sample_history_assistant, pad_output
+ ):
+ """Test the KLComputation transform directly."""
+ model, tokenizer = transformers_instance
+
+ # Create models
+ gen_model = TransformersWrapper(
+ model,
+ tokenizer=tokenizer,
+ input_mode="history",
+ generate=False,
+ return_log_probs=True,
+ pad_output=pad_output,
+ log_probs_key="gen_log_probs",
+ )
+
+ ref_model = TransformersWrapper(
+ model,
+ tokenizer=tokenizer,
+ input_mode="history",
+ generate=False,
+ return_log_probs=True,
+ pad_output=pad_output,
+ log_probs_key="ref_log_probs",
+ )
+
+ # Create data
+ data = TensorDict(
+ history=ChatHistory(full=sample_history_assistant), batch_size=(2,)
+ )
+
+ # Get log-probs from both models
+ data = data.to_lazystack(0)
+ gen_result = gen_model(data)
+ ref_result = ref_model(data)
+
+ # Create next tensordict with log-probs and reward
+ next_td = TensorDict(batch_size=(2,)).to_lazystack(0)
+ next_td.update(gen_result, keys_to_update=[("gen_log_probs", "full")])
+ next_td.update(ref_result, keys_to_update=[("ref_log_probs", "full")])
+ next_td.update({"reward": torch.randn(2, 1, 1)})
+
+ # Create KLComputation transform
+ kl_transform = KLComputation(
+ gen_log_probs_full_key=("gen_log_probs", "full"),
+ ref_log_probs_full_key=("ref_log_probs", "full"),
+ kl_key="kl",
+ add_to_reward=True,
+ coeff=1.0,
+ )
+
+ # Apply transform
+ result = kl_transform(data.set("next", next_td))
+
+ # Check that KL is computed
+ result = result["next"]
+ assert "kl" in result
+
+ if pad_output:
+ kl = result.get("kl")
+ assert isinstance(kl, torch.Tensor)
+ assert kl.shape[0] == 2 # batch size
+ else:
+ kl = result.get("kl", as_list=True)
+ assert isinstance(kl, list)
+ assert len(kl) == 2 # batch size
+
+ # Check that reward is modified
+ assert "reward" in result
+ reward = result.get("reward")
+ assert reward is not None
+
+
+class TestLogProbsComparison:
+ """Test log-probability consistency between vLLM and Transformers wrappers."""
+
+ @pytest.mark.skipif(not _has_vllm, reason="vllm not available")
+ @pytest.mark.skipif(not _has_transformers, reason="transformers not available")
+ @pytest.mark.parametrize(
+ "input_mode", ["history", "text", "tokens"], ids=["history", "text", "tokens"]
+ )
+ @pytest.mark.parametrize("pad_output", [True, False], ids=["padded", "unpadded"])
+ def test_log_probs_consistency(
+ self,
+ vllm_instance,
+ transformers_instance,
+ input_mode,
+ pad_output,
+ sample_history,
+ sample_text,
+ sample_tokens,
+ ):
+ """Test that log-probabilities are consistent between vLLM and Transformers wrappers."""
+ vllm_model, vllm_tokenizer = vllm_instance
+ tf_model, tf_tokenizer = transformers_instance
+
+ # Create test data based on input mode
+ if input_mode == "history":
+ history = sample_history
+ data = TensorDict(history=history, batch_size=(2,))
+ input_key = "history"
+ elif input_mode == "text":
+ history = None # Not used in text mode
+ prompts = sample_text
+ data = TensorDict(text=prompts, batch_size=(2,))
+ input_key = "text"
+ elif input_mode == "tokens":
+ history = None # Not used in tokens mode
+ prompts = sample_tokens
+ data = TensorDict(
+ input_ids=prompts[0],
+ attention_mask=prompts[1],
+ batch_size=(2,),
+ )
+ input_key = "input_ids"
+ else:
+ raise ValueError(f"Invalid input_mode: {input_mode}")
+
+ # Create vLLM wrapper for generation
+ vllm_gen_wrapper = vLLMWrapper(
+ vllm_model,
+ tokenizer=vllm_tokenizer,
+ input_mode=input_mode,
+ input_key=input_key,
+ generate=True,
+ pad_output=pad_output,
+ generate_kwargs={"max_tokens": 5, "temperature": 0.0}, # Deterministic
+ )
+
+ # Create Transformers wrapper for generation
+ tf_gen_wrapper = TransformersWrapper(
+ tf_model,
+ tokenizer=tf_tokenizer,
+ input_mode=input_mode,
+ input_key=input_key,
+ generate=True,
+ pad_output=pad_output,
+ generate_kwargs={
+ "max_new_tokens": 5,
+ "do_sample": False,
+ "temperature": 0.0,
+ }, # Deterministic
+ )
+
+ # Step 1: Generate tokens with both wrappers
+ vllm_gen_result = vllm_gen_wrapper(data.copy())
+ tf_gen_wrapper(data.copy())
+
+ # Step 2: Extract generated tokens and create new input for log-probs computation
+ if input_mode == "history":
+ # For history mode, we need to create new history with generated responses
+ generated_texts = vllm_gen_result["text"].response
+ new_chats = []
+ assert history is not None # Type assertion for linter
+ for chat, gen_text in _zip_strict(history.unbind(0), generated_texts):
+ new_chat = chat.copy().append(
+ History(role="assistant", content=gen_text)
+ )
+ new_chats.append(new_chat)
+ new_history = lazy_stack(new_chats)
+ new_data = TensorDict(history=new_history, batch_size=(2,))
+ elif input_mode == "text":
+ # For text mode, concatenate original text with generated text
+ original_texts = data["text"]
+ generated_texts = vllm_gen_result["text"].response
+ new_texts = [
+ orig + gen for orig, gen in zip(original_texts, generated_texts)
+ ]
+ new_data = TensorDict(text=new_texts, batch_size=(2,))
+ elif input_mode == "tokens":
+ # For tokens mode, concatenate original tokens with generated tokens
+ original_tokens = data["input_ids"]
+ generated_tokens = vllm_gen_result["tokens"].response
+ if pad_output:
+ # Remove padding from generated tokens
+ mask = generated_tokens != vllm_tokenizer.pad_token_id
+ new_tokens = []
+ for i in range(len(original_tokens)):
+ valid_tokens = generated_tokens[i][mask[i]]
+ combined = torch.cat([original_tokens[i], valid_tokens])
+ new_tokens.append(combined)
+ new_tokens = torch.stack(new_tokens)
+ else:
+ new_tokens = []
+ for i in range(len(original_tokens)):
+ combined = torch.cat([original_tokens[i], generated_tokens[i]])
+ new_tokens.append(combined)
+ new_data = TensorDict(input_ids=new_tokens, batch_size=(2,))
+ else:
+ raise ValueError(f"Invalid input_mode: {input_mode}")
+
+ # Step 3: Create log-probs only wrappers
+ vllm_lp_wrapper = vLLMWrapper(
+ vllm_model,
+ tokenizer=vllm_tokenizer,
+ input_mode=input_mode,
+ input_key=input_key,
+ generate=False,
+ pad_output=pad_output,
+ )
+
+ tf_lp_wrapper = TransformersWrapper(
+ tf_model,
+ tokenizer=tf_tokenizer,
+ input_mode=input_mode,
+ input_key=input_key,
+ generate=False,
+ pad_output=pad_output,
+ )
+
+ # Step 4: Compute log-probs for the full sequence (original + generated)
+ vllm_lp_result = vllm_lp_wrapper(new_data.copy())
+ tf_lp_result = tf_lp_wrapper(new_data.copy())
+
+ from tensordict import assert_close
+
+ assert_close(
+ vllm_lp_result, tf_lp_result, atol=1e-1, rtol=1e-1, intersection=True
+ )
+
+
+class TestDistributionMethods:
+ """Test the new distribution methods and masking strategies."""
+
+ @pytest.mark.skipif(not _has_vllm, reason="vllm not available")
+ @pytest.mark.parametrize("masking_strategy", ["sft", "rlhf", "generic"])
+ def test_vllm_distribution_methods(
+ self, vllm_instance, sample_history_assistant, sample_tokens, masking_strategy
+ ):
+ """Test that vLLM wrapper distribution methods work correctly."""
+ model, tokenizer = vllm_instance
+
+ # vLLM doesn't support get_dist methods
+ wrapper = vLLMWrapper(
+ model,
+ tokenizer=tokenizer,
+ input_mode="history",
+ generate=False,
+ return_log_probs=True,
+ )
+
+ # Create test data
+ td = TensorDict({"history": sample_history_assistant}, batch_size=(2,))
+
+ # Test that all distribution methods raise NotImplementedError
+ with pytest.raises(NotImplementedError, match="vLLM does not return logits"):
+ wrapper.get_dist(td)
+
+ with pytest.raises(NotImplementedError, match="vLLM does not return logits"):
+ wrapper._get_sft_dist(td)
+
+ with pytest.raises(NotImplementedError, match="vLLM does not return logits"):
+ wrapper._get_rlhf_dist(td)
+
+ with pytest.raises(NotImplementedError, match="vLLM does not return logits"):
+ wrapper._get_generic_dist(td)
+
+ @pytest.mark.skipif(not _has_transformers, reason="transformers not available")
+ @pytest.mark.parametrize("masking_strategy", ["sft", "rlhf", "generic"])
+ @pytest.mark.parametrize("pad_output", [True, False], ids=["padded", "unpadded"])
+ def test_transformers_distribution_methods(
+ self,
+ transformers_instance,
+ sample_history_assistant,
+ sample_tokens,
+ masking_strategy,
+ pad_output,
+ ):
+ """Test that Transformers wrapper distribution methods work correctly."""
+ model, tokenizer = transformers_instance
+
+ # Use tokens input mode for SFT, history for RLHF/generic
+ if masking_strategy == "sft":
+ input_mode = "tokens"
+ input_ids, attention_mask = sample_tokens
+ assistant_mask = attention_mask.bool().clone()
+ assistant_mask[:, : attention_mask.shape[-1] // 2] = False
+ input_data = {
+ "tokens": Tokens(full=input_ids),
+ "masks": Masks(
+ all_attention_mask=attention_mask.bool(),
+ all_assistant_mask=assistant_mask,
+ ),
+ }
+
+ # Create test data with correct batch size
+ td = TensorDict(input_data, batch_size=(2,)).to_lazystack(0)
+ if not pad_output:
+ for _td in td.unbind(0):
+ _td["tokens"].full = _td["tokens"].full[
+ _td["masks"].all_attention_mask
+ ]
+ _td["masks"].all_assistant_mask = _td["masks"].all_assistant_mask[
+ _td["masks"].all_attention_mask
+ ]
+ _td["masks"].all_attention_mask = _td["masks"].all_attention_mask[
+ _td["masks"].all_attention_mask
+ ]
+ else:
+ input_mode = "history"
+ input_data = {"history": ChatHistory(full=sample_history_assistant)}
+
+ # Create test data with correct batch size
+ td = TensorDict(input_data, batch_size=(2,)).to_lazystack(0)
+
+ wrapper = TransformersWrapper(
+ model,
+ tokenizer=tokenizer,
+ input_mode=input_mode,
+ generate=False,
+ pad_output=pad_output,
+ )
+
+ # Test the appropriate distribution method
+ if masking_strategy == "sft":
+ dist = wrapper._get_sft_dist(td)
+ elif masking_strategy == "rlhf":
+ dist = wrapper._get_rlhf_dist(td)
+ elif masking_strategy == "generic":
+ dist = wrapper._get_generic_dist(td)
+
+ # Verify that we get a distribution
+ assert dist is not None
+ assert hasattr(dist, "log_prob")
+ assert hasattr(dist, "sample")
+
+ # Test that logits are available in the output
+ td_out = wrapper(td.copy())
+
+ # Test log_prob computation
+ if masking_strategy == "sft":
+ # For SFT, we need tokens to compute log_prob
+ tokens = td_out.get(
+ ("tokens", "full"),
+ as_padded_tensor=True,
+ padding_side="left",
+ padding_value=tokenizer.pad_token_id,
+ )
+ if tokens is not None:
+ log_probs = dist.log_prob(tokens.long())
+ assert log_probs.shape == tokens.shape
+ else:
+ # For RLHF/generic, we can test with dummy tokens
+ logits = td_out.get("logits")
+ if logits is not None:
+ dummy_tokens = torch.randint(0, logits.shape[-1], logits.shape[:-1])
+ log_probs = dist.log_prob(dummy_tokens)
+ assert log_probs.shape == dummy_tokens.shape
+
+ @pytest.mark.skipif(not _has_transformers, reason="transformers not available")
+ def test_transformers_custom_masking(
+ self, transformers_instance, sample_history_assistant
+ ):
+ """Test custom masking functionality."""
+ model, tokenizer = transformers_instance
+
+ wrapper = TransformersWrapper(
+ model,
+ tokenizer=tokenizer,
+ input_mode="history",
+ generate=False,
+ return_log_probs=True,
+ pad_output=True,
+ )
+
+ td = TensorDict(
+ {"history": ChatHistory(full=sample_history_assistant)}, batch_size=(2,)
+ )
+
+ # Get the actual logits shape from the wrapper
+ result = wrapper(td)
+ lp = result["log_probs"].get("full")
+
+ # Create a custom mask matching the logits shape
+ custom_mask = torch.zeros_like(lp, dtype=torch.bool)
+ custom_mask[:, :5] = True # Only first 5 tokens
+
+ dist = wrapper._get_dist_with_custom_mask(td, custom_mask)
+
+ assert dist is not None
+ assert hasattr(dist, "log_prob")
+
+
+if __name__ == "__main__":
+ args, unknown = argparse.ArgumentParser().parse_known_args()
+ pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
diff --git a/test/test_specs.py b/test/test_specs.py
index d984db64c3a..44a10ea0a3e 100644
--- a/test/test_specs.py
+++ b/test/test_specs.py
@@ -3938,6 +3938,16 @@ def test_sample(self):
assert nts.rand((2,)).data == "example_data"
assert nts.zero((2,)).data == "example_data"
+ def test_feature_dims(self):
+ nts = NonTensor(shape=(3, 4), example_data="example_data")
+ assert nts.feature_dims == 2
+ nts = NonTensor(shape=(3, 4), example_data="example_data", feature_dims=1)
+ assert nts.feature_dims == 1
+ assert isinstance(nts.zeros(), NonTensorStack)
+ assert isinstance(nts.zeros(2), NonTensorStack)
+ assert isinstance(nts.zeros()[0], NonTensorData)
+ assert nts.rand((2,)).shape == (2, 3, 4)
+
def test_example_data_ineq(self):
nts0 = NonTensor(shape=(3, 4), example_data="example_data")
nts1 = NonTensor(shape=(3, 4), example_data="example_data 2")
diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py
index 0d76124b73f..bf255f40d78 100644
--- a/torchrl/collectors/collectors.py
+++ b/torchrl/collectors/collectors.py
@@ -56,7 +56,6 @@
WeightUpdaterBase,
)
from torchrl.data import ReplayBuffer
-from torchrl.data.tensor_specs import TensorSpec
from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING
from torchrl.envs.common import _do_nothing, EnvBase
from torchrl.envs.env_creator import EnvCreator
@@ -176,7 +175,6 @@ def weight_updater(self, value: WeightUpdaterBase | None):
def _get_policy_and_device(
self,
policy: Callable[[Any], Any] | None = None,
- observation_spec: TensorSpec = None,
policy_device: Any = NO_DEFAULT,
env_maker: Any | None = None,
env_maker_kwargs: dict[str, Any] | None = None,
@@ -187,7 +185,6 @@ def _get_policy_and_device(
Args:
policy (TensorDictModule, optional): a policy to be used
- observation_spec (TensorSpec, optional): spec of the observations
policy_device (torch.device, optional): the device where the policy should be placed.
Defaults to self.policy_device
env_maker (a callable or a batched env, optional): the env_maker function for this device/policy pair.
@@ -201,7 +198,7 @@ def _get_policy_and_device(
env = getattr(self, "env", None)
policy = _make_compatible_policy(
policy,
- observation_spec,
+ getattr(env, "observation_spec", None),
env=env,
env_maker=env_maker,
env_maker_kwargs=env_maker_kwargs,
@@ -800,9 +797,13 @@ def __init__(
self.reset_when_done = reset_when_done
self.n_env = self.env.batch_size.numel()
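+ # Give the policy and the env a back-reference to this collector when they support it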
+ if hasattr(policy, "register_collector"):
+ policy.register_collector(self)
+ if hasattr(self.env, "register_collector"):
+ self.env.register_collector(self)
+
(self.policy, self.get_weights_fn,) = self._get_policy_and_device(
policy=policy,
- observation_spec=self.env.observation_spec,
)
if isinstance(self.policy, nn.Module):
self.policy_weights = TensorDict.from_module(
@@ -1271,7 +1272,8 @@ def cuda_check(tensor: torch.Tensor):
self.replay_buffer.extend(tensordict_out)
if self.verbose:
torchrl_logger.info(
- f"Collector: Added {tensordict_out.numel()} frames to replay buffer. Yielding."
+ f"Collector: Added {tensordict_out.numel()} frames to replay buffer. "
+ "Buffer write count: {self.replay_buffer.write_count}. Yielding."
)
yield
else:
@@ -1356,7 +1358,7 @@ def start(self):
"""
if self.replay_buffer is None:
raise RuntimeError("Replay buffer must be defined for execution.")
- if not hasattr(self, "_thread") or not self._thread.is_alive():
+ if not self.is_running():
self._stop = False
self._thread = threading.Thread(target=self._run_iterator)
self._thread.daemon = (
@@ -1369,6 +1371,9 @@ def _run_iterator(self):
if self._stop:
return
+ def is_running(self):
+ return hasattr(self, "_thread") and self._thread.is_alive()
+
def async_shutdown(
self, timeout: float | None = None, close_env: bool = True
) -> None:
diff --git a/torchrl/collectors/llm/base.py b/torchrl/collectors/llm/base.py
index 830eff36b85..a76bb2f3662 100644
--- a/torchrl/collectors/llm/base.py
+++ b/torchrl/collectors/llm/base.py
@@ -173,7 +173,7 @@ def __init__(
# disguise the queue as a replay buffer
replay_buffer = _QueueAsRB(queue)
if dialog_turns_per_batch is None and yield_completed_trajectories:
- dialog_turns_per_batch = 0
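+ # default to at least one completed turn per yielded batch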
+ dialog_turns_per_batch = 1
super().__init__(
create_env_fn=env,
policy=policy,
@@ -189,6 +189,9 @@ def __init__(
extend_buffer=True,
postproc=postproc,
)
+ if hasattr(self.policy, "register_collector"):
+ self.policy.register_collector(self)
+
if yield_only_last_steps is None:
yield_only_last_steps = False
@@ -322,6 +325,8 @@ def _rollout_all(self) -> TensorDictBase: # A simplified version of rollout
return trajectory.view(-1)
return trajectory
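+ # Number of steps currently buffered in the trajectory queue (consumed by _rollout_yield_trajs)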
+ _result_numel = 0
+
def _rollout_yield_trajs(self) -> TensorDictBase: # A simplified version of rollout
if self._shuttle is None:
raise RuntimeError("Data shuttle not found")
@@ -332,7 +337,7 @@ def _rollout_yield_trajs(self) -> TensorDictBase: # A simplified version of rol
collected_steps = 0
dones = torch.zeros(self.env.batch_size, dtype=torch.bool)
while True:
- if self._trajectory_queue:
+ if self._result_numel >= self.dialog_turns_per_batch:
break
env_input = self.policy(next_output)
cur_output, next_output = self.env.step_and_maybe_reset(env_input)
@@ -356,18 +361,24 @@ def _rollout_yield_trajs(self) -> TensorDictBase: # A simplified version of rol
if dones.any():
for idx in dones.nonzero(as_tuple=True)[0].tolist():
if not self.yield_only_last_steps:
- self._trajectory_queue.append(
- lazy_stack(self._yield_queues[idx], -1)
- )
+ _result = lazy_stack(self._yield_queues[idx], -1)
+ self._trajectory_queue.append(_result)
else:
# FIXME: We need to increment the step count here because iterator() won't
# see the extra steps
# We use lazy-stack because unsqueeze doesn't nest the strings in lists
- self._trajectory_queue.append(
- lazy_stack([self._yield_queues[idx][-1]])
- )
+ _result = lazy_stack([self._yield_queues[idx][-1]])
+ self._trajectory_queue.append(_result)
+ self._result_numel += _result.numel()
self._yield_queues[idx].clear()
- result = self._trajectory_queue.popleft()
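+ # Pop queued trajectories until at least dialog_turns_per_batch steps have been gathered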
+ result = [self._trajectory_queue.popleft()]
+ elt = result[0].numel()
+ self._result_numel -= result[0].numel()
+ while elt < self.dialog_turns_per_batch:
+ result.append(self._trajectory_queue.popleft())
+ elt += result[-1].numel()
+ self._result_numel -= result[-1].numel()
+ result = torch.cat(result, -1)
if self.verbose:
torchrl_logger.info(
f"LLMCollector: Yielding completed trajectory with shape {result.shape}."
diff --git a/torchrl/collectors/llm/ray_collector.py b/torchrl/collectors/llm/ray_collector.py
index 255d92f2192..32330e6ed9d 100644
--- a/torchrl/collectors/llm/ray_collector.py
+++ b/torchrl/collectors/llm/ray_collector.py
@@ -170,6 +170,9 @@ def start(self):
pending_task = self._collector.start.remote()
return ray.get(pending_task)
+ def is_running(self):
+ return ray.get(self._collector.is_running.remote())
+
def shutdown(self):
"""Shuts down the collector."""
pending_task = self._collector.shutdown.remote()
diff --git a/torchrl/data/llm/__init__.py b/torchrl/data/llm/__init__.py
index b7a5d1323f2..4ecf4d61098 100644
--- a/torchrl/data/llm/__init__.py
+++ b/torchrl/data/llm/__init__.py
@@ -3,7 +3,6 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-from .chat import ContentBase, History
from .common import LLMData
from .dataset import (
create_infinite_iterator,
@@ -11,6 +10,7 @@
TensorDictTokenizer,
TokenizedDatasetLoader,
)
+from .history import add_chat_template, ContentBase, History
from .prompt import PromptData, PromptTensorDictTokenizer
from .reward import PairwiseDataset, RewardData
from .topk import TopKRewardSelector
@@ -24,6 +24,7 @@
"LLMData",
"PairwiseDataset",
"PromptData",
+ "add_chat_template",
"PromptTensorDictTokenizer",
"RewardData",
"RolloutFromModel",
diff --git a/torchrl/data/llm/chat.py b/torchrl/data/llm/history.py
similarity index 54%
rename from torchrl/data/llm/chat.py
rename to torchrl/data/llm/history.py
index 5391b883c11..8cfe713f386 100644
--- a/torchrl/data/llm/chat.py
+++ b/torchrl/data/llm/history.py
@@ -21,7 +21,13 @@
from tensordict.utils import _maybe_correct_neg_dim
from torchrl._utils import logger as torchrl_logger
+try:
+ import transformers
+except ImportError:
+ transformers = None
+
+# Global storage for custom templates and their metadata
_CHAT_TEMPLATES = {
"chatml_format": """{% for message in messages %}
{%- if message['role'] == 'assistant' %}
@@ -40,7 +46,7 @@
{%- if messages[0]['role'] == 'system' %}
{{- messages[0]['content'] }}
{%- else %}
- {{- 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' }}
+ {{- 'You are a helpful assistant.' }}
{%- endif %}
{{- "\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n" }}
{%- for tool in tools %}
@@ -52,7 +58,7 @@
{%- if messages[0]['role'] == 'system' %}
{{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}
{%- else %}
- {{- '<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n' }}
+ {{- '<|im_start|>system\\nYou are a helpful assistant.<|im_end|>\\n' }}
{%- endif %}
{%- endif %}
{%- for message in messages %}
@@ -92,8 +98,176 @@
{% generation %}{{- '<|im_start|>assistant\\n' }}{% endgeneration %}
{%- endif %}
""",
+ "dialogpt": """{% for message in messages %}{% if message['role'] == 'assistant' %}{% generation %}{{ message['content'] }}{% endgeneration %}{{ eos_token }}{% elif message['role'] == 'user' %}{{ message['content'] }}{{ eos_token }}{% endif %}{% endfor %}{% if add_generation_prompt %}{% generation %}{{ ' ' }}{% endgeneration %}{% endif %}""",
+ "falcon": """{% for message in messages %}{% if message['role'] == 'assistant' %}{% generation %}{{ 'Assistant: ' + message['content'] }}{% endgeneration %}\n\n{% elif message['role'] == 'user' %}{{ 'User: ' + message['content'] }}\n\n{% elif message['role'] == 'system' %}{{ message['content'] }}\n\n{% endif %}{% endfor %}{% if add_generation_prompt %}{% generation %}{{ 'Assistant: ' }}{% endgeneration %}{% endif %}""",
+ "deepseek": """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{% generation %}{{ 'Assistant: ' + message['content'] + eos_token }}{% endgeneration %}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{% generation %}{{ 'Assistant:' }}{% endgeneration %}{% endif %}""",
+ "llama": """{{- bos_token }}
+{%- if messages[0]['role'] == 'system' %}
+ {%- set system_message = messages[0]['content']|trim %}
+ {%- set messages = messages[1:] %}
+{%- else %}
+ {%- set system_message = "" %}
+{%- endif %}
+{%- if system_message %}
+ {{- "<|header_start|>system<|header_end|>\n\n" }}
+ {{- system_message }}
+ {{- "<|eot|>" }}
+{%- endif %}
+{%- for message in messages %}
+ {%- if message['role'] == 'assistant' %}
+ {% generation %}{{- '<|header_start|>' + message['role'] + '<|header_end|>\n\n' }}
+ {%- if message['content'] is string %}
+ {{- message['content'] }}
+ {%- else %}
+ {%- for content in message['content'] %}
+ {%- if content['type'] == 'text' %}
+ {{- content['text'] | trim }}
+ {%- endif %}
+ {%- endfor %}
+ {%- endif %}
+ {{- "<|eot|>" }}{% endgeneration %}
+ {%- else %}
+ {{- '<|header_start|>' + message['role'] + '<|header_end|>\n\n' }}
+ {%- if message['content'] is string %}
+ {{- message['content'] }}
+ {%- else %}
+ {%- for content in message['content'] %}
+ {%- if content['type'] == 'text' %}
+ {{- content['text'] | trim }}
+ {%- endif %}
+ {%- endfor %}
+ {%- endif %}
+ {{- "<|eot|>" }}
+ {%- endif %}
+{%- endfor %}
+{%- if add_generation_prompt %}
+ {% generation %}{{- '<|header_start|>assistant<|header_end|>\n\n' }}{% endgeneration %}
+{%- endif %}""",
}
+# Global storage for custom template metadata
+_CUSTOM_INVERSE_PARSERS = {}
+_CUSTOM_MODEL_FAMILY_KEYWORDS = {}
+
+
+def add_chat_template(
+ template_name: str,
+ template: str,
+ inverse_parser: callable | None = None,
+ model_family_keywords: list[str] | None = None,
+) -> None:
+ r"""Add a custom chat template to the global template dictionary.
+
+ This function allows you to add custom chat templates for new model families
+ that support assistant token masking via the `{% generation %}` keyword.
+
+ Args:
+ template_name (str): The name of the template (e.g., "llama", "mistral").
+ This name will be used in the `chat_template_name` parameter of
+ `History.apply_chat_template()` and `History.from_text()`.
+ template (str): The Jinja2 template string. Must include `{% generation %}`
+ blocks around assistant message content to enable token masking.
+ inverse_parser (callable, optional): A function that parses formatted text back
+ into a History object. Should have signature `(text: str) -> History`.
+            If `None`, :meth:`History.from_text` will not be able to parse text formatted with this template.
+ model_family_keywords (list[str], optional): Keywords to detect this model family
+ in the auto-detection logic. For example, ["llama", "meta-llama"] for Llama models.
+ If provided, the template will be automatically selected for models containing
+ these keywords in their name.
+
+ Example:
+        >>> from torchrl.data.llm.history import add_chat_template, History
+ >>> from transformers import AutoTokenizer
+ >>>
+ >>> # Add a custom template for Llama models
+ >>> llama_template = '''
+ ... {% for message in messages %}
+ ... {%- if message['role'] == 'user' %}
+ ... {{ '[INST] ' + message['content'] + ' [/INST]' }}
+ ... {%- elif message['role'] == 'assistant' %}
+ ... {% generation %}{{ message['content'] + ' ' }}{% endgeneration %}
+ ... {%- endif %}
+ ... {% endfor %}
+ ... {%- if add_generation_prompt %}
+ ... {% generation %}{{ ' ' }}{% endgeneration %}
+ ... {%- endif %}
+ ... '''
+ >>>
+ >>> def parse_llama_text(text: str) -> History:
+ ... # Custom parser for Llama format
+ ... import re
+ ... pattern = r'\[INST\]\s*(.*?)\s*\[/INST\]\s*(.*?) '
+ ... matches = re.findall(pattern, text, re.DOTALL)
+ ... messages = []
+ ... for user_content, assistant_content in matches:
+ ... messages.append(History(role="user", content=user_content.strip()))
+ ... messages.append(History(role="assistant", content=assistant_content.strip()))
+ ... return lazy_stack(messages)
+ >>>
+ >>> # Add the template with auto-detection
+ >>> add_chat_template(
+ ... template_name="llama",
+ ... template=llama_template,
+ ... inverse_parser=parse_llama_text,
+ ... model_family_keywords=["llama", "meta-llama"]
+ ... )
+ >>>
+ >>> # Now you can use it with auto-detection
+ >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
+ >>> history = History.from_chats([[
+ ... {"role": "user", "content": "Hello"},
+ ... {"role": "assistant", "content": "Hi there!"}
+ ... ]])
+ >>>
+ >>> # Auto-detection will use the llama template
+ >>> result = history.apply_chat_template(
+ ... tokenizer=tokenizer,
+ ... add_generation_prompt=False,
+ ... return_dict=True,
+ ... return_assistant_tokens_mask=True,
+ ... )
+ >>>
+ >>> # Or use it explicitly
+ >>> result = history.apply_chat_template(
+ ... tokenizer=tokenizer,
+ ... chat_template_name="llama",
+ ... add_generation_prompt=False,
+ ... return_dict=True,
+ ... return_assistant_tokens_mask=True,
+ ... )
+
+    .. note::
+ - The template must include `{% generation %}` blocks around assistant message
+ content to enable assistant token masking.
+ - The inverse parser should handle the specific format of your template.
+ - Model family keywords are case-insensitive and matched against the tokenizer's
+ `name_or_path` attribute.
+ - Templates are stored globally and persist for the duration of the Python session.
+ """
+ global _CHAT_TEMPLATES, _CUSTOM_INVERSE_PARSERS, _CUSTOM_MODEL_FAMILY_KEYWORDS
+
+ # Validate template contains generation blocks
+ if "{% generation %}" not in template:
+ raise ValueError(
+ f"Template '{template_name}' must include '{{% generation %}}' blocks "
+ "around assistant message content to enable token masking."
+ )
+
+ # Add template to dictionary
+ _CHAT_TEMPLATES[template_name] = template
+
+ # Store inverse parser if provided
+ if inverse_parser is not None:
+ _CUSTOM_INVERSE_PARSERS[template_name] = inverse_parser
+
+ # Store model family keywords if provided
+ if model_family_keywords is not None:
+ _CUSTOM_MODEL_FAMILY_KEYWORDS[template_name] = model_family_keywords
+
+ torchrl_logger.info(
+ f"Added custom chat template '{template_name}' with assistant token masking support"
+ )
+
# We need the 'shadow' flag to avoid having tensordict complaining about 'type'/'size' etc. fields
class ContentBase(TensorClass["nocast", "shadow"]):
@@ -197,12 +371,93 @@ class History(TensorClass["nocast"]):
- Efficient methods to append, extend, and reshape history elements, enabling dynamic construction of conversation
trajectories, especially useful in reinforcement learning environments.
- Interoperability with the `transformers` API, allowing for easy tokenization and preparation of input data.
+ - **Assistant token masking support** across multiple model families for reinforcement learning applications.
+
+ **Recent Changes:**
+ - **ChatHistory Integration**: History objects are now used within :class:`~torchrl.modules.llm.policies.ChatHistory`
+ containers for structured conversation management in LLM environments.
+ - **Modular Wrapper Support**: Both vLLMWrapper and TransformersWrapper now use History objects when `input_mode="history"`
+ is specified, providing consistent conversation state management.
+ - **Environment Integration**: ChatEnv and related environments use History objects for state management and conversation tracking.
.. note:: The `""` role is used to indicate that the element is a placeholder,
for example when the tool call was not executed but a stack requires a certain number of elements
per batch to have congruent shapes. The :meth:`~torchrl.data.llm.chat.History.apply_chat_template`
method will remove the `` role from the history.
+ **Assistant Token Masking Support:**
+
+ The class supports assistant token masking across multiple model families, allowing you to identify which tokens
+ in a conversation were generated by the assistant. This is crucial for reinforcement learning applications.
+
+ **Supported Model Families:**
+
+ - **Qwen family** (e.g., `Qwen/Qwen2.5-0.5B`): Custom template with full tool calling support
+ - **DialoGPT family** (e.g., `microsoft/DialoGPT-medium`): Custom template for conversation format
+ - **Falcon family** (e.g., `tiiuae/falcon-7b-instruct`): Custom template for instruction format
+    - **DeepSeek family** (e.g., `deepseek-ai/deepseek-coder-6.7b-base`): Custom template with native format
+    - **Llama family** (e.g., models with "llama" in the tokenizer name): Custom template with header-based format
+    - **Other models** (OPT, GPT, MPT, BLOOM, Pythia, Phi, etc.): Default `chatml_format` template
+
+ **Example with Assistant Token Masking:**
+
+ .. code-block:: python
+
+        >>> from torchrl.data.llm.history import History
+ >>> from torchrl.modules.llm.policies import ChatHistory
+ >>> from transformers import AutoTokenizer
+ >>>
+ >>> # Create a conversation history
+ >>> history = History.from_chats([[
+ ... {"role": "user", "content": "Hello"},
+ ... {"role": "assistant", "content": "Hi there!"},
+ ... {"role": "user", "content": "How are you?"},
+ ... {"role": "assistant", "content": "I'm doing well, thanks!"}
+ ... ]])
+ >>>
+ >>> # Create ChatHistory container for LLM wrapper
+ >>> chat_history = ChatHistory(prompt=history)
+ >>>
+ >>> # Load any supported tokenizer
+ >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
+ >>>
+ >>> # Apply chat template with assistant token masking
+ >>> result = history.apply_chat_template(
+ ... tokenizer=tokenizer,
+ ... add_generation_prompt=False,
+ ... return_dict=True,
+ ... return_assistant_tokens_mask=True,
+ ... )
+ >>>
+ >>> # The result contains an assistant_masks tensor
+ >>> assistant_masks = result["assistant_masks"]
+ >>> print(f"Assistant tokens: {assistant_masks.sum().item()}")
+
+ **Integration with LLM Wrappers:**
+
+ History objects work seamlessly with the new modular wrapper design:
+
+ .. code-block:: python
+
+ >>> from torchrl.modules.llm import TransformersWrapper
+ >>> from torchrl.modules.llm.policies import ChatHistory
+ >>>
+ >>> # Create wrapper with history input mode
+ >>> wrapper = TransformersWrapper(
+ ... model, tokenizer=tokenizer,
+ ... input_mode="history",
+ ... generate=True,
+ ... return_log_probs=True
+ ... )
+ >>>
+ >>> # Use History with ChatHistory container
+ >>> history = History.from_chats([[
+ ... {"role": "user", "content": "Hello"},
+ ... {"role": "assistant", "content": "Hi there!"}
+ ... ]])
+ >>> chat_history = ChatHistory(prompt=history)
+ >>> result = wrapper(TensorDict(history=chat_history, batch_size=(1,)))
+ >>> print(result["history"].response) # New response from LLM
+
Attributes:
role (str): The role of the message sender.
content (str): The content of the message.
@@ -256,6 +511,10 @@ class History(TensorClass["nocast"]):
<|im_start|>assistant
+ .. seealso::
+ :class:`~torchrl.modules.llm.policies.ChatHistory`: Container for managing conversation data in LLM environments.
+ :class:`~torchrl.modules.llm.policies.Text`: Container for text data.
+ :class:`~torchrl.modules.llm.policies.Tokens`: Container for token data.
"""
role: str
@@ -277,7 +536,7 @@ def apply_chat_template(
tokenizer: transformers.AutoTokenizer | transformers.AutoProcessor, # noqa
add_generation_prompt: bool = True,
chat_template: str | None = None,
- chat_template_name: Literal["chatml_format", "qwen"] | None = None,
+ chat_template_name: str | None = None,
continue_final_message: bool = False,
tokenize: bool | None = None,
padding: bool | str = False,
@@ -286,15 +545,16 @@ def apply_chat_template(
return_dict: bool | None = None,
return_assistant_tokens_mask: bool = False,
**kwargs,
- ):
+ ) -> str | list[str] | TensorDict:
"""Applies a chat template to the history.
Keyword Args:
tokenizer (transformers.PreTrainedTokenizer | transformers.AutoProcessor): The tokenizer to use.
add_generation_prompt (bool, optional): Whether to add a generation prompt (e.g. `"<|im_start|>assistant"`). Defaults to `True`.
chat_template (str, optional): The chat template to use. Defaults to the tokenizer's default template.
- chat_template_name (Literal["chatml_format", "qwen"], optional): The name of the chat template to use.
- Prevalent over `tokenizer.chat_template`. Defaults to `None`.
+ chat_template_name (str, optional): The name of the chat template to use.
+                Takes precedence over `tokenizer.chat_template`. If `None`, the method will automatically detect the model family and use the appropriate template.
+ Defaults to `None`.
continue_final_message (bool, optional): Whether to continue the final message. Defaults to `False`.
tokenize (bool, optional): Whether to tokenize the output. Defaults to `False`.
padding (bool | str, optional): The padding strategy to use. Defaults to `False`.
@@ -308,9 +568,14 @@ def apply_chat_template(
This functionality is only available for chat templates that support it via the `{% generation %}` keyword.
Defaults to `False`.
- .. note:: By default, the `"qwen"` chat template does not support this functionality. A modified version of the template
- can be used by setting `chat_template_name="qwen"`, which will override the default template from the tokenizer.
- For other tokenizers, similar edits can be made to the template and passed to the method via the `chat_template` argument.
+ .. note:: Assistant token masking is supported across multiple model families:
+ - **Qwen family**: Uses custom template with full tool calling support
+ - **DialoGPT family**: Uses custom template for conversation format
+ - **Falcon family**: Uses custom template for instruction format
+            - **DeepSeek family**: Uses custom template with native format
+            - **Llama family**: Uses custom template with header-based format
+            - **Other models**: Use the default `chatml_format` template
+
+ The method automatically detects the model family and selects the appropriate template.
**kwargs: Additional keyword arguments to pass to the tokenizer `apply_chat_template` method.
@@ -325,13 +590,54 @@ def apply_chat_template(
raise RuntimeError(
"You must specify a tokenizer to use when chat_template is not specified."
)
- elif "qwen" in getattr(tokenizer, "name_or_path", "").lower():
- # We prefer our implementation of the Qwen template,
- # since it accounts for the assistant's masking.
- chat_template = _CHAT_TEMPLATES["qwen"]
- chat_template_name = None
else:
- chat_template = tokenizer.chat_template
+ # Auto-detect model family and use appropriate template
+ model_name = getattr(tokenizer, "name_or_path", "").lower()
+
+ # First check for custom model family keywords
+ custom_template_found = False
+ for template_name, keywords in _CUSTOM_MODEL_FAMILY_KEYWORDS.items():
+ if any(keyword.lower() in model_name for keyword in keywords):
+ chat_template = _CHAT_TEMPLATES[template_name]
+ chat_template_name = None
+ custom_template_found = True
+ break
+
+ if not custom_template_found:
+ # Fall back to built-in model family detection
+ if "qwen" in model_name:
+ # We prefer our implementation of the Qwen template,
+ # since it accounts for the assistant's masking.
+ chat_template = _CHAT_TEMPLATES["qwen"]
+ chat_template_name = None
+ elif "dialogpt" in model_name or "microsoft/dialo" in model_name:
+ # DialoGPT family - use our custom template
+ chat_template = _CHAT_TEMPLATES["dialogpt"]
+ chat_template_name = None
+ elif "falcon" in model_name or "tiiuae/falcon" in model_name:
+ # Falcon family - use our custom template
+ chat_template = _CHAT_TEMPLATES["falcon"]
+ chat_template_name = None
+ elif "deepseek" in model_name:
+ # DeepSeek family - use our custom template with generation keyword
+ chat_template = _CHAT_TEMPLATES["deepseek"]
+ chat_template_name = None
+ elif "llama" in model_name:
+ # Llama family - use our custom template
+ chat_template = _CHAT_TEMPLATES["llama"]
+ chat_template_name = None
+ else:
+ # For other models, check if their default template supports generation
+ if (
+ hasattr(tokenizer, "chat_template")
+ and tokenizer.chat_template
+ and "{% generation %}" in tokenizer.chat_template
+ ):
+ # Use the model's own template if it supports generation
+ chat_template = tokenizer.chat_template
+ else:
+ # Use our default chatml_format template
+ chat_template = _CHAT_TEMPLATES["chatml_format"]
if chat_template is None:
chat_template = _CHAT_TEMPLATES["chatml_format"]
if tokenize is None:
@@ -402,26 +708,65 @@ def apply_chat_template(
def from_text(
cls,
text: str | list[str],
- chat_template_name: Literal["chatml_format", "qwen"] | None = None,
+ chat_template_name: str | None = None,
+ # currently without effect
chat_template: str | None = None,
tokenizer: transformers.AutoTokenizer # noqa: F821
| transformers.AutoProcessor # noqa: F821
| None = None,
) -> History:
- if chat_template_name is None and chat_template is None:
- if "qwen" in getattr(tokenizer, "name_or_path", "").lower():
- # We can automatically detect the template name from the tokenizer
- # and use the precoded parser.
- chat_template_name = "qwen"
- else:
- chat_template_name = "chatml_format"
- elif chat_template_name in ("chatml_format",):
+ if chat_template_name is None:
+ if chat_template is not None:
+ # TODO: find best match given template
+ pass
+
+ model_name = getattr(tokenizer, "name_or_path", "").lower()
+ # First check for custom model family keywords
+ custom_template_found = False
+ for template_name, keywords in _CUSTOM_MODEL_FAMILY_KEYWORDS.items():
+ if any(keyword.lower() in model_name for keyword in keywords):
+ chat_template_name = template_name
+ custom_template_found = True
+ break
+
+ if not custom_template_found:
+ # Fall back to built-in model family detection
+ if "qwen" in model_name:
+ # We can automatically detect the template name from the tokenizer
+ # and use the precoded parser.
+ chat_template_name = "qwen"
+ elif "dialogpt" in model_name or "microsoft/dialo" in model_name:
+ chat_template_name = "dialogpt"
+ elif "falcon" in model_name or "tiiuae/falcon" in model_name:
+ chat_template_name = "falcon"
+ elif "deepseek" in model_name:
+ chat_template_name = "deepseek"
+ elif "llama" in model_name:
+ chat_template_name = "llama"
+ else:
+ chat_template_name = "chatml_format"
+
+ # Get the appropriate inverse parser function
+ if chat_template_name in ("chatml_format",):
func = cls._inv_chatml
elif chat_template_name in ("qwen",):
func = cls._inv_qwen
+ elif chat_template_name in ("dialogpt",):
+ func = cls._inv_dialogpt
+ elif chat_template_name in ("falcon",):
+ func = cls._inv_falcon
+ elif chat_template_name in ("deepseek",):
+ func = cls._inv_deepseek
+ elif chat_template_name in ("llama",):
+ func = cls._inv_llama
+ elif chat_template_name in _CUSTOM_INVERSE_PARSERS:
+ # Use custom inverse parser
+ func = _CUSTOM_INVERSE_PARSERS[chat_template_name]
else:
raise NotImplementedError(
- "chat_template_name must be one of ('chatml_format', 'qwen')"
+ f"chat_template_name '{chat_template_name}' is not supported. "
+                "Supported templates: 'chatml_format', 'qwen', 'dialogpt', 'falcon', 'deepseek', 'llama'. "
+ "Use add_chat_template() to add custom templates."
)
if isinstance(text, list):
list_of_histories = [func(t) for t in text]
@@ -598,6 +943,218 @@ def _inv_qwen(cls, template):
return lazy_stack(parsed_messages)
+ @classmethod
+ def _inv_dialogpt(cls, text: str) -> History:
+        """Inverts a DialoGPT string into a History object.
+
+        Args:
+            text (str): The DialoGPT string to invert.
+
+        Returns:
+            History: The inverted History object.
+        """
+        torchrl_logger.debug(f"Inverting DialoGPT:\n{text}")
+
+        # DialoGPT format is simple: alternating user/assistant messages
+ # Split by lines and parse
+ lines = text.strip().split("\n")
+ parsed_messages = []
+
+ for line in lines:
+ line = line.strip()
+ if not line:
+ continue
+
+ # Determine role based on content
+ if line.startswith("Assistant:"):
+ role = "assistant"
+ content = line[len("Assistant:") :].strip()
+ elif line.startswith("User:"):
+ role = "user"
+ content = line[len("User:") :].strip()
+ else:
+ # Default to user if no prefix
+ role = "user"
+ content = line
+
+ message_dict = {
+ "role": role,
+ "content": content,
+                "is_complete": True,  # DialoGPT doesn't have explicit end tokens
+ "tool_calls": None,
+ "tool_responses": None,
+ }
+
+ parsed_messages.append(cls(**message_dict))
+
+ if not parsed_messages:
+ raise RuntimeError(f"Couldn't get a single item out of text {text}.")
+
+ return lazy_stack(parsed_messages)
+
+ @classmethod
+ def _inv_falcon(cls, text: str) -> History:
+ """Inverts a Falcon string into a History object.
+
+ Args:
+ text (str): The Falcon string to invert.
+
+ Returns:
+ History: The inverted History object.
+ """
+ torchrl_logger.debug(f"Inverting Falcon:\n{text}")
+
+ # Falcon format: "User: ... Assistant: ..."
+ # Split by "User:" and "Assistant:" prefixes
+ import re
+
+        # Pattern to match User: and Assistant: messages. The lookahead group is
+        # non-capturing so that re.findall returns (prefix, content) pairs.
+        pattern = r"(User:|Assistant:)\s*(.*?)(?=(?:User:|Assistant:)|$)"
+ matches = re.findall(pattern, text, re.DOTALL)
+
+ parsed_messages = []
+ for match in matches:
+ if len(match) != 2:
+ continue
+ prefix, content = match
+ content = content.strip()
+ if not content:
+ continue
+
+ if prefix == "User:":
+ role = "user"
+ elif prefix == "Assistant:":
+ role = "assistant"
+ else:
+ continue
+
+ message_dict = {
+ "role": role,
+ "content": content,
+ "is_complete": True, # Falcon doesn't have explicit end tokens
+ "tool_calls": None,
+ "tool_responses": None,
+ }
+
+ parsed_messages.append(cls(**message_dict))
+
+ if not parsed_messages:
+ raise RuntimeError(f"Couldn't get a single item out of text {text}.")
+
+ return lazy_stack(parsed_messages)
+
+ @classmethod
+ def _inv_deepseek(cls, text: str) -> History:
+ """Inverts a DeepSeek string into a History object.
+
+ Args:
+ text (str): The DeepSeek string to invert.
+
+ Returns:
+ History: The inverted History object.
+ """
+ torchrl_logger.debug(f"Inverting DeepSeek:\n{text}")
+ import re
+
+        # Remove leading/trailing special tokens (e.g. BOS/EOS markers wrapped in angle brackets)
+ text = re.sub(r"^<[^>]+>", "", text) # Remove leading <...>
+ text = re.sub(r"<[^>]+>$", "", text) # Remove trailing <...>
+ # Remove any REDACTED_SPECIAL_TOKEN if present
+ text = re.sub(r"REDACTED_SPECIAL_TOKEN", "", text)
+ # Pattern to match User: and Assistant: messages
+ pattern = r"(User:|Assistant:)\s*(.*?)(?=(User:|Assistant:)|$)"
+ matches = re.findall(pattern, text, re.DOTALL)
+ parsed_messages = []
+ for match in matches:
+ if len(match) < 2:
+ continue
+ prefix, content = match[0], match[1]
+ content = content.strip()
+ if not content:
+ continue
+ if prefix == "User:":
+ role = "user"
+ elif prefix == "Assistant:":
+ role = "assistant"
+ else:
+ continue
+ message_dict = {
+ "role": role,
+ "content": content,
+ "is_complete": True, # DeepSeek doesn't have explicit end tokens
+ "tool_calls": None,
+ "tool_responses": None,
+ }
+ parsed_messages.append(cls(**message_dict))
+ if not parsed_messages:
+ raise RuntimeError(f"Couldn't get a single item out of text {text}.")
+ return lazy_stack(parsed_messages)
+
+ @classmethod
+ def _inv_llama(cls, text: str) -> History:
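+        """Inverts a Llama-formatted string into a History object.
+
+        Args:
+            text (str): The Llama-formatted string to invert.
+
+        Returns:
+            History: The inverted History object.
+        """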
+ import re
+
+ messages = []
+
+ # Remove BOS token if present
+ if text.startswith("<|begin_of_text|>"):
+ text = text[len("<|begin_of_text|>") :]
+
+ # Pattern to match complete message blocks: <|header_start|>role<|header_end|>\n\ncontent<|eot|>
+ complete_pattern = r"<\|header_start\|>(\w+)<\|header_end\|>\n\n(.*?)<\|eot\|>"
+ complete_matches = re.findall(complete_pattern, text, re.DOTALL)
+
+ # Pattern to match incomplete message blocks: <|header_start|>role<|header_end|>\n\ncontent (without <|eot|>)
+ incomplete_pattern = r"<\|header_start\|>(\w+)<\|header_end\|>\n\n(.*?)$"
+
+ # Find any incomplete message at the end
+ incomplete_matches = []
+ if complete_matches:
+ # Look for incomplete message after the last complete one
+ last_complete_end = text.rfind("<|eot|>")
+ if last_complete_end != -1:
+ remaining_text = text[last_complete_end + len("<|eot|>") :]
+ if remaining_text.strip():
+ incomplete_match = re.search(
+ incomplete_pattern, remaining_text, re.DOTALL
+ )
+ if incomplete_match:
+ incomplete_matches = [
+ (
+ incomplete_match.group(1),
+ incomplete_match.group(2),
+ False,
+ )
+ ]
+ else:
+ # No complete messages, check entire text for incomplete message
+ incomplete_match = re.search(incomplete_pattern, text, re.DOTALL)
+ if incomplete_match:
+ incomplete_matches = [
+ (incomplete_match.group(1), incomplete_match.group(2), False)
+ ]
+
+ # Process complete messages
+ for role, content in complete_matches:
+ if content.strip():
+ messages.append(
+ cls(role=role, content=content.strip(), is_complete=True)
+ )
+
+ # Process incomplete messages
+ for role, content, is_complete in incomplete_matches:
+ if content.strip():
+ messages.append(
+ cls(role=role, content=content.strip(), is_complete=is_complete)
+ )
+
+ if not messages:
+ raise RuntimeError(f"Couldn't parse Llama format from text: {text}")
+
+ from tensordict import lazy_stack
+
+ return lazy_stack(messages)
+
def append(
self, history: History, *, inplace: bool = True, dim: int = -1
) -> History:
diff --git a/torchrl/data/replay_buffers/ray_buffer.py b/torchrl/data/replay_buffers/ray_buffer.py
index 7ab7ce2958c..e9e7ad23812 100644
--- a/torchrl/data/replay_buffers/ray_buffer.py
+++ b/torchrl/data/replay_buffers/ray_buffer.py
@@ -163,6 +163,10 @@ def _replay_lock(self):
"""
return contextlib.nullcontext()
+ @property
+ def batch_size(self):
+ return ray.get(self._rb._getattr.remote("_batch_size"))
+
def sample(self, *args, **kwargs):
pending_task = self._rb.sample.remote(*args, **kwargs)
return ray.get(pending_task)
@@ -196,8 +200,8 @@ def loads(self, path):
def load(self, *args, **kwargs):
return ray.get(self._rb.load.remote(*args, **kwargs))
- def empty(self):
- return ray.get(self._rb.empty.remote())
+ def empty(self, empty_write_count: bool = True):
+ return ray.get(self._rb.empty.remote(empty_write_count=empty_write_count))
def __getitem__(self, index):
return ray.get(self._rb.__getitem__.remote(index))
diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py
index 93dc73d2fb1..cb097bb0097 100644
--- a/torchrl/data/replay_buffers/replay_buffers.py
+++ b/torchrl/data/replay_buffers/replay_buffers.py
@@ -383,6 +383,17 @@ def set_rng(self, generator):
def dim_extend(self):
return self._dim_extend
+ @property
+ def batch_size(self):
+ """The batch size of the replay buffer.
+
+        The batch size can be overridden by passing a `batch_size` argument to the :meth:`sample` method.
+
+ It defines both the number of samples returned by :meth:`sample` and the number of samples that are
+ yielded by the :class:`ReplayBuffer` iterator.
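+
+        Example:
+            >>> # Illustrative sketch: the batch size passed at construction time is exposed read-only.
+            >>> from torchrl.data import LazyTensorStorage, ReplayBuffer
+            >>> rb = ReplayBuffer(storage=LazyTensorStorage(100), batch_size=32)
+            >>> rb.batch_size
+            32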
+ """
+ return self._batch_size
+
@dim_extend.setter
def dim_extend(self, value):
if (
@@ -783,9 +794,13 @@ def _sample(self, batch_size: int) -> tuple[Any, dict]:
return data, info
- def empty(self):
- """Empties the replay buffer and reset cursor to 0."""
- self._writer._empty()
+ def empty(self, empty_write_count: bool = True):
+        """Empties the replay buffer and resets the cursor to 0.
+
+        Args:
+            empty_write_count (bool, optional): Whether to reset the `write_count` attribute. Defaults to `True`.
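+
+        Example:
+            >>> # Illustrative sketch: clear the contents but keep the write count
+            >>> # (assumes `rb` is a ReplayBuffer that has already been filled with data).
+            >>> rb.empty(empty_write_count=False)
+            >>> assert len(rb) == 0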
+ """
+ self._writer._empty(empty_write_count=empty_write_count)
self._sampler._empty()
self._storage._empty()
diff --git a/torchrl/data/replay_buffers/writers.py b/torchrl/data/replay_buffers/writers.py
index 245e0b55913..547534652c2 100644
--- a/torchrl/data/replay_buffers/writers.py
+++ b/torchrl/data/replay_buffers/writers.py
@@ -58,7 +58,7 @@ def extend(self, data: Sequence) -> torch.Tensor:
...
@abstractmethod
- def _empty(self):
+ def _empty(self, empty_write_count: bool = True) -> None:
...
@abstractmethod
@@ -122,7 +122,7 @@ def add(self, data: Any) -> int:
def extend(self, data: Sequence) -> torch.Tensor:
raise RuntimeError(self.WRITING_ERR)
- def _empty(self):
+ def _empty(self, empty_write_count: bool = True) -> None:
raise RuntimeError(self.WRITING_ERR)
def dumps(self, path):
@@ -189,7 +189,7 @@ def extend(self, data: Sequence) -> torch.Tensor:
else:
batch_size = len(tree_leaves(data)[0])
if batch_size == 0:
- raise RuntimeError("Expected at least one element in extend.")
+ raise RuntimeError(f"Expected at least one element in extend. Got {data=}")
device = data.device if hasattr(data, "device") else None
max_size_along0 = self._storage._max_size_along_dim0(batched_data=data)
index = (
@@ -215,9 +215,10 @@ def state_dict(self) -> dict[str, Any]:
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
self._cursor = state_dict["_cursor"]
- def _empty(self):
+ def _empty(self, empty_write_count: bool = True) -> None:
self._cursor = 0
- self._write_count = 0
+ if empty_write_count:
+ self._write_count = 0
@property
def _cursor(self):
@@ -572,9 +573,11 @@ def extend(self, data: TensorDictBase) -> None:
ent.mark_update(index)
return index
- def _empty(self) -> None:
+ def _empty(self, empty_write_count: bool = True) -> None:
self._cursor = 0
self._current_top_values = []
+ if empty_write_count:
+ self._write_count = 0
def __getstate__(self):
if get_spawning_popen() is not None:
@@ -664,7 +667,7 @@ def _rng(self, value):
for writer in self._writers:
writer._rng = value
- def _empty(self):
+ def _empty(self, empty_write_count: bool = True) -> None:
raise NotImplementedError
def dumps(self, path: Path):
diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py
index b9546a485a9..02e1572aec1 100644
--- a/torchrl/data/tensor_specs.py
+++ b/torchrl/data/tensor_specs.py
@@ -36,6 +36,7 @@
import torch
from tensordict import (
is_tensor_collection,
+ lazy_stack,
LazyStackedTensorDict,
NonTensorData,
NonTensorStack,
@@ -2747,6 +2748,13 @@ class NonTensor(TensorSpec):
batched (bool, optional): Indicates whether the data is batched. If `True`, the `rand`, `zero`, and `one` methods
will generate data with an additional batch dimension, stacking copies of the `example_data` across this dimension.
Defaults to `False`.
+ Exclusive with `feature_dims`.
+ feature_dims (int, optional): The number of dimensions that are features.
+ The feature dimensions are the trailing dimensions that are not batch dimensions.
+            All feature dimensions are contained within a single NonTensorData object, whereas
+            batch dimensions are obtained by stacking such objects.
+            Exclusive with `batched`.
+            Defaults to `None` (all dimensions are feature dims if `batched=False`, none if `batched=True`).
**kwargs: Additional keyword arguments passed to the parent class.
.. seealso:: :class:`~torchrl.data.Choice` which allows to randomly choose among different specs when calling
@@ -2773,7 +2781,8 @@ def __init__(
device: DEVICE_TYPING | None = None,
dtype: torch.dtype | None = None,
example_data: Any = None,
- batched: bool = False,
+ batched: bool | None = None,
+ feature_dims: int | None = None,
**kwargs,
):
if isinstance(shape, int):
@@ -2784,7 +2793,17 @@ def __init__(
shape=shape, space=None, device=device, dtype=dtype, domain=domain, **kwargs
)
self.example_data = example_data
+ if batched is None and feature_dims is None:
+ batched = False
+ feature_dims = len(self.shape)
+ elif batched is None and feature_dims is not None:
+ batched = False
+ elif batched is not None and feature_dims is not None:
+ raise ValueError("Cannot specify both batched and feature_dims.")
+ else:
+ feature_dims = 0 if batched else len(self.shape)
self.batched = batched
+ self.feature_dims = feature_dims
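+        # Example: with shape=(3, 4) and feature_dims=1, `rand()` lazily stacks 3
+        # NonTensorData objects, each holding `example_data` with batch_size=(4,).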
self.encode = self._encode_eager
def __repr__(self):
@@ -2835,7 +2854,7 @@ def to(self, dest: torch.dtype | DEVICE_TYPING) -> NonTensor:
device=dest_device,
dtype=None,
example_data=self.example_data,
- batched=self.batched,
+ feature_dims=self.feature_dims,
)
def clone(self) -> NonTensor:
@@ -2844,29 +2863,32 @@ def clone(self) -> NonTensor:
device=self.device,
dtype=self.dtype,
example_data=self.example_data,
- batched=self.batched,
+ feature_dims=self.feature_dims,
)
def rand(self, shape=None):
if shape is None:
shape = ()
if self.batched:
- with set_capture_non_tensor_stack(False):
- val = NonTensorData(
- data=self.example_data,
- batch_size=(),
- device=self.device,
- )
- shape = (*shape, *self._safe_shape)
- if shape:
- for i in shape:
- val = torch.stack([val.copy() for _ in range(i)], -1)
- return val
- return NonTensorData(
- data=self.example_data,
- batch_size=(*shape, *self._safe_shape),
- device=self.device,
- )
+            # when batched, all dims are batch dims (no feature dims)
+ feature_dims = 0
+ else:
+ feature_dims = self.feature_dims
+ if isinstance(shape, int):
+ shape = _size([shape])
+ total_shape = (*shape, *self._safe_shape)
+        ndims = len(total_shape)
+        # Slicing with -feature_dims would be wrong when feature_dims == 0
+        # (total_shape[:-0] is empty), so compute the split point explicitly.
+        batch_shape = total_shape[: ndims - feature_dims]
+        feature_shape = total_shape[ndims - feature_dims :]
+ with set_capture_non_tensor_stack(False):
+ val = NonTensorData(
+ data=self.example_data,
+ batch_size=feature_shape,
+ device=self.device,
+ )
+ if batch_shape:
+ for i in reversed(batch_shape):
+ val = lazy_stack([val.copy() for _ in range(i)])
+ return val
def zero(self, shape=None):
return self.rand(shape=shape)
@@ -2877,10 +2899,18 @@ def one(self, shape=None):
def is_in(self, val: Any) -> bool:
if not isinstance(val, torch.Tensor) and not is_tensor_collection(val):
return True
- shape = torch.broadcast_shapes(self._safe_shape, val.shape)
+        # Since we don't really share NonTensor specs across processes, it's ok to modify the shape.
+        # We do this when the shape has been determined from a single sample gathered
+        # from a dataloader, since the shapes of the non-tensor data may actually vary.
+ if any(v < 0 for v in val.shape):
+ self.shape = torch.Size(
+ (self.shape[i] if s >= 0 else -1 for i, s in enumerate(val.shape))
+ )
+ _safe_val_shape = torch.Size(s if s >= 0 else 1 for s in val.shape)
+ shape = torch.broadcast_shapes(self._safe_shape, _safe_val_shape)
return (
is_non_tensor(val)
- and val.shape == shape
+ and _safe_val_shape == shape
# We relax constrains on device as they're hard to enforce for non-tensor
# tensordicts and pointless
# and val.device == self.device
@@ -2904,18 +2934,20 @@ def expand(self, *shape):
device=self.device,
dtype=None,
example_data=self.example_data,
- batched=self.batched,
+ feature_dims=self.feature_dims,
)
def unsqueeze(self, dim: int) -> NonTensor:
unsq = super().unsqueeze(dim=dim)
unsq.example_data = self.example_data
+ unsq.feature_dims = self.feature_dims
unsq.batched = self.batched
return unsq
def squeeze(self, dim: int | None = None) -> NonTensor:
sq = super().squeeze(dim=dim)
sq.example_data = self.example_data
+ sq.feature_dims = self.feature_dims
sq.batched = self.batched
return sq
@@ -2925,7 +2957,7 @@ def _reshape(self, shape):
device=self.device,
dtype=self.dtype,
example_data=self.example_data,
- batched=self.batched,
+ feature_dims=self.feature_dims,
)
def _unflatten(self, dim, sizes):
@@ -2935,7 +2967,7 @@ def _unflatten(self, dim, sizes):
device=self.device,
dtype=self.dtype,
example_data=self.example_data,
- batched=self.batched,
+ feature_dims=self.feature_dims,
)
def __getitem__(self, idx: SHAPE_INDEX_TYPING):
@@ -2946,7 +2978,7 @@ def __getitem__(self, idx: SHAPE_INDEX_TYPING):
device=self.device,
dtype=self.dtype,
example_data=self.example_data,
- batched=self.batched,
+ feature_dims=self.feature_dims,
)
def unbind(self, dim: int = 0):
@@ -2964,7 +2996,7 @@ def unbind(self, dim: int = 0):
device=self.device,
dtype=self.dtype,
example_data=self.example_data,
- batched=self.batched,
+ feature_dims=self.feature_dims,
)
for i in range(self.shape[dim])
)
@@ -2980,7 +3012,12 @@ def _encode_eager(
*,
ignore_device: bool = False,
) -> torch.Tensor | TensorDictBase:
- return NonTensorData(val, device=self.device, batch_size=self.shape)
+ return NonTensorData(
+ val,
+ device=self.device,
+ batch_size=self.shape,
+ feature_dims=self.feature_dims,
+ )
class _UnboundedMeta(abc.ABCMeta):
@@ -4969,6 +5006,9 @@ class Composite(TensorSpec):
to the batch-size of the corresponding tensordicts.
data_cls (type, optional): the tensordict subclass (TensorDict, TensorClass, tensorclass...) that should be
enforced in the env. Defaults to ``None``.
+ step_mdp_static (bool, optional): whether the spec is static under step_mdp. Defaults to ``False``.
+            Defining a `Composite` as a step_mdp_static spec means that the entire related TensorDict/TensorClass
+            instance is copied during calls to `step_mdp`, rather than updated in-place.
Examples:
>>> pixels_spec = Bounded(
@@ -5044,6 +5084,7 @@ def __init__(
shape: tuple | torch.Size | None = None,
device: torch.device | None = None,
data_cls: type | None = None,
+ step_mdp_static: bool = False,
**kwargs,
):
# For compatibility with TensorDict
@@ -5057,6 +5098,7 @@ def __init__(
shape = _size(())
self._shape = _size(shape)
self._specs = {}
+ self.step_mdp_static = step_mdp_static
_device = (
_make_ordinal_device(torch.device(device)) if device is not None else device
@@ -5548,6 +5590,7 @@ def keys(
leaves_only: bool = False,
*,
is_leaf: Callable[[type], bool] | None = None,
+ step_mdp_static_only: bool = False,
) -> _CompositeSpecKeysView: # noqa: D417
"""Keys of the Composite.
@@ -5568,6 +5611,8 @@ def keys(
is_leaf (callable, optional): reads a type and returns a boolean indicating if that type
should be seen as a leaf. By default, all non-Composite nodes are considered as
leaves.
+ step_mdp_static_only (bool, optional): if ``True``, only keys that are static under step_mdp will be returned.
+ Default is ``False``.
"""
return _CompositeSpecItemsView(
@@ -5575,6 +5620,7 @@ def keys(
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
+ step_mdp_static_only=step_mdp_static_only,
)._keys()
def items(
@@ -5583,6 +5629,7 @@ def items(
leaves_only: bool = False,
*,
is_leaf: Callable[[type], bool] | None = None,
+ step_mdp_static_only: bool = False,
) -> _CompositeSpecItemsView: # noqa: D417
"""Items of the Composite.
@@ -5601,12 +5648,15 @@ def items(
is_leaf (callable, optional): reads a type and returns a boolean indicating if that type
should be seen as a leaf. By default, all non-Composite nodes are considered as
leaves.
+ step_mdp_static_only (bool, optional): if ``True``, only keys that are static under step_mdp will be returned.
+ Default is ``False``.
"""
return _CompositeSpecItemsView(
self,
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
+ step_mdp_static_only=step_mdp_static_only,
)
def values(
@@ -5615,6 +5665,7 @@ def values(
leaves_only: bool = False,
*,
is_leaf: Callable[[type], bool] | None = None,
+ step_mdp_static_only: bool = False,
) -> _CompositeSpecValuesView: # noqa: D417
"""Values of the Composite.
@@ -5633,24 +5684,31 @@ def values(
is_leaf (callable, optional): reads a type and returns a boolean indicating if that type
should be seen as a leaf. By default, all non-Composite nodes are considered as
leaves.
+ step_mdp_static_only (bool, optional): if ``True``, only keys that are static under step_mdp will be returned.
+ Default is ``False``.
"""
return _CompositeSpecItemsView(
self,
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
+ step_mdp_static_only=step_mdp_static_only,
)._values()
- def _reshape(self, shape):
+ def _reshape(self, shape: torch.Size) -> Composite:
_specs = {
key: val.reshape((*shape, *val.shape[self.ndimension() :]))
for key, val in self._specs.items()
}
return self.__class__(
- _specs, shape=shape, device=self.device, data_cls=self.data_cls
+ _specs,
+ shape=shape,
+ device=self.device,
+ data_cls=self.data_cls,
+ step_mdp_static=self.step_mdp_static,
)
- def _unflatten(self, dim, sizes):
+ def _unflatten(self, dim: int, sizes: tuple[int, ...]) -> Composite:
shape = torch.zeros(self.shape, device="meta").unflatten(dim, sizes).shape
return self._reshape(shape)
@@ -5669,7 +5727,11 @@ def to(self, dest: torch.dtype | DEVICE_TYPING) -> Composite:
continue
kwargs[key] = value.to(dest)
return self.__class__(
- **kwargs, device=self.device, shape=self.shape, data_cls=self.data_cls
+ **kwargs,
+ device=self.device,
+ shape=self.shape,
+ data_cls=self.data_cls,
+ step_mdp_static=self.step_mdp_static,
)
if not isinstance(dest, (str, int, torch.device)):
raise ValueError(
@@ -5687,7 +5749,11 @@ def to(self, dest: torch.dtype | DEVICE_TYPING) -> Composite:
continue
kwargs[key] = value.to(dest)
return self.__class__(
- **kwargs, device=_device, shape=self.shape, data_cls=self.data_cls
+ **kwargs,
+ device=_device,
+ shape=self.shape,
+ data_cls=self.data_cls,
+ step_mdp_static=self.step_mdp_static,
)
def clone(self) -> Composite:
@@ -5707,6 +5773,7 @@ def clone(self) -> Composite:
device=device,
shape=self.shape,
data_cls=self.data_cls,
+ step_mdp_static=self.step_mdp_static,
)
def cardinality(self) -> int:
@@ -5754,7 +5821,7 @@ def enumerate(self, use_mask: bool = False) -> TensorDictBase:
samples = cls.from_dict({}, batch_size=self.shape, device=self.device)
return samples
- def empty(self):
+ def empty(self) -> Composite:
"""Create a spec like self, but with no entries."""
try:
device = self.device
@@ -5765,6 +5832,7 @@ def empty(self):
device=device,
shape=self.shape,
data_cls=self.data_cls,
+ step_mdp_static=self.step_mdp_static,
)
def to_numpy(self, val: TensorDict, safe: bool | None = None) -> dict:
@@ -5793,7 +5861,7 @@ def zero(self, shape: torch.Size = None) -> TensorDictBase:
device=device,
)
- def __eq__(self, other):
+ def __eq__(self, other: object) -> bool:
return (
type(self) == type(other)
and self.shape == other.shape
@@ -5828,7 +5896,7 @@ def update(self, dict_or_spec: Composite | dict[str, TensorSpec]) -> None:
self[key] = item
return self
- def expand(self, *shape):
+ def expand(self, *shape: tuple[int, ...] | torch.Size) -> Composite:
if len(shape) == 1 and isinstance(shape[0], (tuple, list, torch.Size)):
shape = shape[0]
if any(s1 != s2 and s2 != 1 for s1, s2 in zip(shape[-self.ndim :], self.shape)):
@@ -5851,10 +5919,11 @@ def expand(self, *shape):
shape=shape,
device=device,
data_cls=self.data_cls,
+ step_mdp_static=self.step_mdp_static,
)
return out
- def squeeze(self, dim: int | None = None):
+ def squeeze(self, dim: int | None = None) -> Composite:
if dim is not None:
if dim < 0:
dim += len(self.shape)
@@ -5873,6 +5942,7 @@ def squeeze(self, dim: int | None = None):
shape=shape,
device=device,
data_cls=self.data_cls,
+ step_mdp_static=self.step_mdp_static,
)
if self.shape.count(1) == 0:
@@ -5884,7 +5954,7 @@ def squeeze(self, dim: int | None = None):
out = self.squeeze(self.shape.index(1))
return out.squeeze()
- def unsqueeze(self, dim: int):
+ def unsqueeze(self, dim: int) -> Composite:
if dim < 0:
dim += len(self.shape) + 1
@@ -5903,9 +5973,10 @@ def unsqueeze(self, dim: int):
shape=shape,
device=device,
data_cls=self.data_cls,
+ step_mdp_static=self.step_mdp_static,
)
- def unbind(self, dim: int = 0):
+ def unbind(self, dim: int = 0) -> tuple[Composite, ...]:
orig_dim = dim
if dim < 0:
dim = len(self.shape) + dim
@@ -5921,6 +5992,7 @@ def unbind(self, dim: int = 0):
shape=shape,
device=self.device,
data_cls=self.data_cls,
+ step_mdp_static=self.step_mdp_static,
)
for i in range(self.shape[dim])
)
@@ -5937,14 +6009,14 @@ def is_locked(self, value: bool) -> None:
else:
self.unlock_()
- def __getstate__(self):
+ def __getstate__(self) -> dict:
result = self.__dict__.copy()
__lock_parents_weakrefs = result.pop("__lock_parents_weakrefs", None)
if __lock_parents_weakrefs is not None:
result["_lock_recurse"] = True
return result
- def __setstate__(self, state):
+ def __setstate__(self, state: dict) -> None:
_lock_recurse = state.pop("_lock_recurse", False)
for key, value in state.items():
setattr(self, key, value)
@@ -5953,8 +6025,12 @@ def __setstate__(self, state):
self.lock_(recurse=_lock_recurse)
def _propagate_lock(
- self, *, recurse: bool, lock_parents_weakrefs=None, is_compiling
- ):
+ self,
+ *,
+ recurse: bool,
+ lock_parents_weakrefs: list[weakref.ref] | None = None,
+ is_compiling: bool,
+ ) -> None:
"""Registers the parent composite that handles the lock."""
self._is_locked = True
if lock_parents_weakrefs is not None:
@@ -5984,7 +6060,7 @@ def _propagate_lock(
)
@property
- def _lock_parents_weakrefs(self):
+ def _lock_parents_weakrefs(self) -> list[weakref.ref]:
_lock_parents_weakrefs = self.__dict__.get("__lock_parents_weakrefs")
if _lock_parents_weakrefs is None:
self.__dict__["__lock_parents_weakrefs"] = []
@@ -5992,10 +6068,10 @@ def _lock_parents_weakrefs(self):
return _lock_parents_weakrefs
@_lock_parents_weakrefs.setter
- def _lock_parents_weakrefs(self, value: list):
+ def _lock_parents_weakrefs(self, value: list[weakref.ref]) -> None:
self.__dict__["__lock_parents_weakrefs"] = value
- def lock_(self, recurse: bool | None = None) -> T:
+    def lock_(self, recurse: bool | None = None) -> Composite:
"""Locks the Composite and prevents modification of its content.
The recurse argument control whether the lock will be propagated to sub-specs.
@@ -6045,7 +6121,7 @@ def lock_(self, recurse: bool | None = None) -> T:
self._propagate_lock(recurse=recurse, is_compiling=is_comp)
return self
- def _propagate_unlock(self, recurse: bool):
+ def _propagate_unlock(self, recurse: bool) -> list[Composite]:
# if we end up here, we can clear the graph associated with this td
self._is_locked = False
@@ -6061,7 +6137,7 @@ def _propagate_unlock(self, recurse: bool):
return sub_specs
return []
- def _check_unlock(self, first_attempt=True):
+ def _check_unlock(self, first_attempt: bool = True) -> None:
if not first_attempt:
gc.collect()
obj = None
@@ -6208,12 +6284,14 @@ def keys(
leaves_only: bool = False,
*,
is_leaf: Callable[[type], bool] | None = None,
+ step_mdp_static_only: bool = False,
) -> _CompositeSpecKeysView:
return _CompositeSpecItemsView(
self,
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
+ step_mdp_static_only=step_mdp_static_only,
)._keys()
def items(
@@ -6222,6 +6300,7 @@ def items(
leaves_only: bool = False,
*,
is_leaf: Callable[[type], bool] | None = None,
+ step_mdp_static_only: bool = False,
) -> _CompositeSpecItemsView:
return list(
_CompositeSpecItemsView(
@@ -6229,6 +6308,7 @@ def items(
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
+ step_mdp_static_only=step_mdp_static_only,
)
)
@@ -6238,12 +6318,14 @@ def values(
leaves_only: bool = False,
*,
is_leaf: Callable[[type], bool] | None = None,
+ step_mdp_static_only: bool = False,
) -> _CompositeSpecValuesView:
return _CompositeSpecItemsView(
self,
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
+ step_mdp_static_only=step_mdp_static_only,
)._values()
def project(self, val: TensorDictBase) -> TensorDictBase:
@@ -6634,15 +6716,17 @@ class _CompositeSpecItemsView:
def __init__(
self,
composite: Composite,
- include_nested,
- leaves_only,
+ include_nested: bool,
+ leaves_only: bool,
*,
- is_leaf,
+ is_leaf: Callable[[type], bool] | None,
+ step_mdp_static_only: bool = False,
):
self.composite = composite
self.leaves_only = leaves_only
self.include_nested = include_nested
self.is_leaf = is_leaf
+ self.step_mdp_static_only = step_mdp_static_only
def __iter__(self):
from tensordict.base import _NESTED_TENSORS_AS_LISTS
@@ -6662,23 +6746,29 @@ def _iter_from_item(key, item):
include_nested=True,
leaves_only=self.leaves_only,
is_leaf=is_leaf,
+ step_mdp_static_only=self.step_mdp_static_only,
):
if not isinstance(subkey, tuple):
subkey = (subkey,)
yield (key, *subkey), subitem
- if not self.leaves_only and not _is_leaf(type(item)):
+ if (
+ (self.step_mdp_static_only and getattr(item, "step_mdp_static", False))
+ or (not self.leaves_only and not _is_leaf(type(item)))
+ or (not self.leaves_only or _is_leaf(type(item)))
+ ):
yield (key, item)
- elif not self.leaves_only or _is_leaf(type(item)):
- yield key, item
- for key, item in self._get_composite_items(is_leaf):
- if is_leaf is _NESTED_TENSORS_AS_LISTS and isinstance(
- item, _LazyStackedMixin
- ):
- for (i, spec) in enumerate(item._specs):
- yield from _iter_from_item(unravel_key((key, str(i))), spec)
- else:
- yield from _iter_from_item(key, item)
+ if not self.step_mdp_static_only or not getattr(
+ self.composite, "step_mdp_static", False
+ ):
+ for key, item in self._get_composite_items(is_leaf):
+ if is_leaf is _NESTED_TENSORS_AS_LISTS and isinstance(
+ item, _LazyStackedMixin
+ ):
+ for (i, spec) in enumerate(item._specs):
+ yield from _iter_from_item(unravel_key((key, str(i))), spec)
+ else:
+ yield from _iter_from_item(key, item)
def _get_composite_items(self, is_leaf):
diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py
index b84cee1f75f..9f0d6a6b7b2 100644
--- a/torchrl/envs/common.py
+++ b/torchrl/envs/common.py
@@ -7,6 +7,7 @@
import abc
import warnings
+import weakref
from copy import deepcopy
from functools import partial, wraps
from typing import Any, Callable, Iterator
@@ -539,6 +540,25 @@ def __init__(
self._run_type_checks = run_type_checks
self._allow_done_after_reset = allow_done_after_reset
+    _collector: weakref.ReferenceType[
+        DataCollectorBase  # noqa: F821 # type: ignore
+    ] | None = None
+
+ def register_collector(
+ self, collector: DataCollectorBase # noqa: F821 # type: ignore
+ ):
+ """Registers a collector with the environment.
+
+ Args:
+ collector (DataCollectorBase): The collector to register.
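+
+        .. note:: Only a weak reference to the collector is stored, so registering a collector
+            does not keep it alive; :attr:`collector` returns ``None`` once the collector has
+            been garbage-collected.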
+ """
+ self._collector = weakref.ref(collector)
+
+ @property
+ def collector(self) -> DataCollectorBase | None: # noqa: F821 # type: ignore
+ """Returns the collector associated with the container, if it exists."""
+ return self._collector() if self._collector is not None else None
+
def set_spec_lock_(self, mode: bool = True) -> EnvBase:
"""Locks or unlocks the environment's specs.
@@ -1222,6 +1242,56 @@ def observation_keys(self) -> list[NestedKey]:
)
return observation_keys
+ @property
+ @_cache_value
+ def _observation_keys_step_mdp(self) -> list[NestedKey]:
+ """The observation keys of an environment that are static under step_mdp (i.e. to be copied as-is during step_mdp)."""
+ observation_keys_leaves = sorted(
+ self.full_observation_spec.keys(True, True, step_mdp_static_only=True),
+ key=_repr_by_depth,
+ )
+ return observation_keys_leaves
+
+ @property
+ @_cache_value
+ def _state_keys_step_mdp(self) -> list[NestedKey]:
+ """The state keys of an environment that are static under step_mdp (i.e. to be copied as-is during step_mdp)."""
+ state_keys_leaves = sorted(
+ self.full_state_spec.keys(True, True, step_mdp_static_only=True),
+ key=_repr_by_depth,
+ )
+ return state_keys_leaves
+
+ @property
+ @_cache_value
+ def _action_keys_step_mdp(self) -> list[NestedKey]:
+ """The action keys of an environment that are static under step_mdp (i.e. to be copied as-is during step_mdp)."""
+ action_keys_leaves = sorted(
+ self.full_action_spec.keys(True, True, step_mdp_static_only=True),
+ key=_repr_by_depth,
+ )
+ return action_keys_leaves
+
+ @property
+ @_cache_value
+ def _done_keys_step_mdp(self) -> list[NestedKey]:
+ """The done keys of an environment that are static under step_mdp (i.e. to be copied as-is during step_mdp)."""
+ done_keys_leaves = sorted(
+ self.full_done_spec.keys(True, True, step_mdp_static_only=True),
+ key=_repr_by_depth,
+ )
+ return done_keys_leaves
+
+ @property
+ @_cache_value
+ def _reward_keys_step_mdp(self) -> list[NestedKey]:
+ """The reward keys of an environment that are static under step_mdp (i.e. to be copied as-is during step_mdp)."""
+ reward_keys_leaves = sorted(
+ self.full_reward_spec.keys(True, True, step_mdp_static_only=True),
+ key=_repr_by_depth,
+ )
+ return reward_keys_leaves
+
@property
def reward_key(self):
"""The reward key of an environment.
@@ -3409,7 +3479,7 @@ def step_mdp(self, next_tensordict: TensorDictBase) -> TensorDictBase:
@property
@_cache_value
- def _step_mdp(self):
+ def _step_mdp(self) -> Callable[[TensorDictBase], TensorDictBase]:
return _StepMDP(self, exclude_action=False)
def _rollout_stop_early(
@@ -3586,9 +3656,13 @@ def step_and_maybe_reset(
# done and truncated are in done_keys
# We read if any key is done.
tensordict_ = self._step_mdp(tensordict)
+ # if self._post_step_mdp_hooks is not None:
+ # tensordict_ = self._post_step_mdp_hooks(tensordict_)
tensordict_ = self.maybe_reset(tensordict_)
return tensordict, tensordict_
+ # _post_step_mdp_hooks: Callable[[TensorDictBase], TensorDictBase] | None = None
+
@property
@_cache_value
def _simple_done(self):
diff --git a/torchrl/envs/llm/__init__.py b/torchrl/envs/llm/__init__.py
index 42d1098f9d6..38457d0dd62 100644
--- a/torchrl/envs/llm/__init__.py
+++ b/torchrl/envs/llm/__init__.py
@@ -20,9 +20,11 @@
as_padded_tensor,
BrowserTransform,
DataLoadingPrimer,
+ KLComputation,
KLRewardTransform,
MCPToolTransform,
PythonInterpreter,
+ RetrieveKL,
RetrieveLogProb,
TemplateTransform,
Tokenizer,
@@ -33,12 +35,14 @@
"RetrieveLogProb",
"ChatEnv",
"DataLoadingPrimer",
+ "KLComputation",
"DatasetChatEnv",
"AddThinkingPrompt",
"GSM8KEnv",
"GSM8KPrepareQuestion",
"GSM8KRewardParser",
"IFEvalData",
+ "RetrieveKL",
"IFEvalEnv",
"IFEvalScoreData",
"IfEvalScorer",
diff --git a/torchrl/envs/llm/chat.py b/torchrl/envs/llm/chat.py
index e89cde26d7c..402b754f7da 100644
--- a/torchrl/envs/llm/chat.py
+++ b/torchrl/envs/llm/chat.py
@@ -4,134 +4,152 @@
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
-import warnings
-
from typing import Any, Callable, Literal
import torch
-from tensordict import lazy_stack, TensorDict, TensorDictBase
+from tensordict import lazy_stack, TensorDictBase
+from tensordict.utils import _zip_strict
from torch.utils.data import DataLoader
from torchrl.data import Composite, NonTensor
-
-from torchrl.data.llm.chat import History
+from torchrl.data.llm.history import History
from torchrl.envs import EnvBase, TransformedEnv
from torchrl.envs.llm.transforms.dataloading import DataLoadingPrimer
+from torchrl.modules.llm.policies.common import ChatHistory, Text, Tokens
+
+
+def _default_collate_fn(batch):
+ # We want to rename the "text" key to "query"
+ # otherwise it will conflict with the "text" key in the tensordict returned by TorchRL components
+ if isinstance(batch, dict) and "text" in batch:
+ batch["query"] = batch.pop("text")
+ elif isinstance(batch, list):
+ for item in batch:
+ if "text" in item:
+ item["query"] = item.pop("text")
+ return batch
class ChatEnv(EnvBase):
- r"""A chat-based environment.
+ r"""A chat-based environment for LLMs, designed as a blank canvas for conversation and RL.
+
+ This environment is designed to work seamlessly with both :class:`~torchrl.modules.llm.policies.TransformersWrapper` and
+ :class:`~torchrl.modules.llm.policies.vLLMWrapper`. It provides the fundamental structure for managing conversation state
+ using the :class:`~torchrl.data.llm.History` format (or, alternatively, tokens or text), but is intentionally minimal to allow
+ maximum flexibility through transforms.
+
+ Core Functionality
+ The environment operates in three main modes:
+
+ - **History mode**: Uses :class:`~torchrl.data.llm.History` objects for conversation management
+ - **Text mode**: Uses simple text strings for input/output
+ - **Tokens mode**: Uses tokenized data for input/output
+
+ Reset Operation
+ During reset, the environment:
+
+ 1. Takes input text from the `data_key` (default: `"query"`) in the tensordict
+ 2. Creates a :class:`~torchrl.data.llm.History` object with the user's message
+ 3. Optionally prepends a system prompt if provided
+ 4. Formats the conversation according to the selected input mode (history, text, or tokens)
+ 5. Returns the formatted prompt ready for the LLM
- ChatEnv relies on the :class:`~torchrl.data.llm.History` format to output observations framed as a chat between
- various entities (typically with roles such as `"system"`, `"user"`, `"assistant"` etc.)
+ Step Operation
+ During step, the environment:
- The step function will execute the following operations:
+ 1. Takes the LLM's response (containing both prompt and generated text)
+ 2. Extracts the full conversation history
+ 3. Prepares the next prompt by setting the full history as the new prompt
+ 4. Returns the updated conversation state
- - Given a prompt (key `"text"`) and an answer string (key `"text_response"`, which is our action), the environment
- will generate a single string that is the concatenation of the two.
- - The text is fed to :meth:`torchrl.data.llm.History.from_text` to produce a full history of the chat so far. This
- should hopefully match the state of the history in the previous step, plus an extra step generated by the new
- action.
- - The last item of that history is then appended to the previous history (we don't replace the history in case
- it contains metadata that cannot be inferred directly from the prompt and response).
- - Optionally, the history is mapped back to a `"text"` entry that can be used to query the LLM in the next round
- of the policy.
+ This design enables natural multi-turn conversations where each step extends the conversation
+ history, making it ideal for dialogue systems and reinforcement learning applications.
- Args:
+ Integration with Transforms
+ ChatEnv is designed to be extended with transforms that add specific capabilities:
+
+ - **Reward computation**: :class:`~torchrl.envs.llm.transforms.KLRewardTransform` for KL divergence rewards
+ - **Tool execution**: :class:`~torchrl.envs.llm.transforms.PythonInterpreter` for Python code execution
+ - **Data loading**: :class:`~torchrl.envs.llm.transforms.DataLoadingPrimer` for loading prompts from datasets
+ - **Thinking prompts**: :class:`~torchrl.envs.llm.transforms.AddThinkingPrompt` for chain-of-thought reasoning
+
+ Keyword Args:
+ input_mode (Literal["history", "text", "tokens"]): The mode of input to the environment.
+ Defaults to `"history"`.
batch_size (torch.Size): Expected batch size of the input. Defaults to `(1,)` (null batch sizes such as `()`
are not recommended as they don't play well with generators).
- system_prompt (str, optional): an optional `"system"` prompt string to use during reset calls.
+ system_prompt (str, optional): An optional `"system"` prompt string to use during reset calls.
Defaults to `None`.
- apply_template (bool, optional): if `True` (and a tokenizer is passed), the history will be parsed to a string
- in the `"text"` entry of the output tensordict at reset time. Defaults to `False`.
-
- .. note:: If transforms are appended to the environment, the template will be applied to the history before the transform is applied.
- As transforms can encode tools, this means that the text returned by the environment may be incomplete.
- The :class:`~torchrl.modules.llm.vLLMWrapper` and :class:`~torchrl.modules.llm.TransformersWrapper`
- will apply the template to the history when queried if no `"text"` input is provided.
-
- tokenizer (transformers.PreTrainedTokenizer, *optional*): A tokenizer that will be used to tokenize the text.
+ tokenizer (transformers.PreTrainedTokenizer, optional): A tokenizer that will be used to tokenize the text.
Defaults to `None`.
- template_kwargs (dict[str, any], optional): keyword arguments passed to :meth:`~torchrl.data.llm.History.apply_chat_template`.
+ template_kwargs (dict[str, any], optional): Keyword arguments passed to :meth:`~torchrl.data.llm.History.apply_chat_template`.
Defaults to `None`.
- system_role (str, optional): the role of the system (at reset time). Defaults to `"system"`.
- user_role (str, optional): the role of the user (at reset time). Defaults to `"user"`.
- make_lazy (bool, optional): if `True`, the environment will return a lazy stack of tensordicts. This is the recommended setting
- for training, since it allows for efficient batching of environment outputs that may have different shapes or contents.
- Defaults to `True`.
+ system_role (str, optional): The role of the system (at reset time). Defaults to `"system"`.
+ user_role (str, optional): The role of the user (at reset time). Defaults to `"user"`.
+ policy_role (str, optional): The role of the policy/assistant. Defaults to `"assistant"`.
+ data_key (str, optional): The key of the data input to the env at reset time (from dataloader). Defaults to `"query"`.
+ device (torch.device, optional): The device to use for computations. Defaults to `None`.
Methods:
- reset (TensorDict): Resets the state of the environment. A tensordict or equivalent with a `"text"` entry must be passed.
- step (TensorDict): Makes a step in the environment (see above for a description of what `step` does).
- A tensordict or equivalent with a `"text"` entry must be passed.
+ reset (TensorDict): Resets the state of the environment. A tensordict or equivalent with a `"query"` entry
+ (originating from the dataloader) must be passed. This key name is defined as a class attribute `data_key`.
+ step (TensorDict): Makes a step in the environment. A tensordict or equivalent with the LLM's response must be passed.
+ The response key is defined as a class attribute `response_key`.
.. seealso:: To see examples of a `ChatEnv` in action, see :class:`~torchrl.envs.llm.chat.DatasetChatEnv`,
:class:`~torchrl.envs.llm.GSM8KEnv` and :class:`~torchrl.envs.llm.IFEvalEnv`.
Examples:
- >>> import pprint
- >>>
- >>> import transformers
- >>> from tensordict import TensorDict, set_list_to_stack
>>> from torchrl.envs.llm import ChatEnv
- >>> set_list_to_stack(True).set()
+ >>> from torchrl.data.llm import History
+ >>> from tensordict import TensorDict
+ >>>
+ >>> # Create a basic chat environment
+ >>> env = ChatEnv(
+ ... system_prompt="You are a helpful assistant.",
+ ... input_mode="history"
+ ... )
>>>
- >>> tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")
+ >>> # Reset with a user query
+ >>> reset_data = TensorDict({"query": "Hello, how are you?"}, batch_size=(1,))
+ >>> obs = env.reset(reset_data)
+ >>> print(obs["history"].prompt) # History with system prompt + user message
>>>
- >>> env = ChatEnv(batch_size=(1,), tokenizer=tokenizer, apply_template=True, system_prompt="I'm system, do what I want.")
- >>> td_reset = env.reset(TensorDict(text=["I'm the user. I'm going to tell you a little about something."], batch_size=(1,)))
- >>> pprint.pprint(f'{td_reset["history"]=}')
- ('td_reset["history"]=History(\n'
- ' content=NonTensorStack(\n'
- ' [["I\'m system, do what I want.", "I\'m the user. I\'...,\n'
- ' batch_size=torch.Size([1, 2]),\n'
- ' device=None),\n'
- ' role=NonTensorStack(\n'
- " [['system', 'user']],\n"
- ' batch_size=torch.Size([1, 2]),\n'
- ' device=None),\n'
- ' batch_size=torch.Size([1, 2]),\n'
- ' device=None,\n'
- ' is_shared=False)')
- >>> pprint.pprint(f'{td_reset["text"]=}')
- ('td_reset["text"]=["<|im_start|>system\\nI\'m system, do what I '
- "want.<|im_end|>\\n<|im_start|>user\\nI'm the user. I'm going to tell you a "
- 'little about something.<|im_end|>\\n<|im_start|>assistant\\n"]')
- >>> td_action = td_reset.set("text_response", ["This is the action from the assistant!<|im_end|>"])
- >>> td_next = env.step(td_action)
- >>> pprint.pprint(f'{td_next["next", "history"]=}')
- ('td_next["next", "history"]=History(\n'
- ' content=NonTensorStack(\n'
- ' [["I\'m system, do what I want.", "I\'m the user. I\'...,\n'
- ' batch_size=torch.Size([1, 3]),\n'
- ' device=None),\n'
- ' role=NonTensorStack(\n'
- " [['system', 'user', 'assistant']],\n"
- ' batch_size=torch.Size([1, 3]),\n'
- ' device=None),\n'
- ' batch_size=torch.Size([1, 3]),\n'
- ' device=None,\n'
- ' is_shared=False)')
- >>> pprint.pprint(f'{td_next["next", "text"]=}')
- ('td_next["next", "text"]=["<|im_start|>system\\nI\'m system, do what I '
- "want.<|im_end|>\\n<|im_start|>user\\nI'm the user. I'm going to tell you a "
- 'little about something.<|im_end|>\\n<|im_start|>assistant\\nThis is the '
- 'action from the assistant!<|im_end|>\\n<|im_start|>assistant\\n"]')
+ >>> # Simulate LLM response and step
+ >>> response_data = TensorDict({
+ ... "history": History.from_chats([[
+ ... {"role": "system", "content": "You are a helpful assistant."},
+ ... {"role": "user", "content": "Hello, how are you?"},
+ ... {"role": "assistant", "content": "I'm doing well, thank you!"}
+ ... ]])
+ ... }, batch_size=(1,))
+ >>> next_obs = env.step(response_data)
+ >>> print(next_obs["history"].prompt) # Full conversation history
"""
+ # Nested key corresponding to the text input to the LLM
+ text_key = ("text", "prompt")
+ # Nested key corresponding to the response from the LLM
+ response_key = ("text", "response")
+ # Nested key corresponding to the data input to the env at reset time (from dataloader)
+ data_key = "query"
+
def __init__(
self,
+ *,
+ input_mode: Literal["history", "text", "tokens"] = "history",
batch_size: tuple | torch.Size | None = None,
system_prompt: str | None = None,
- apply_template: bool | None = None,
tokenizer: transformers.AutoTokenizer | None = None, # noqa: F821
template_kwargs: dict[str, Any] | None = None,
system_role: str = "system",
user_role: str = "user",
policy_role: str | None = "assistant",
- make_lazy: bool = True,
+ data_key: str | None = None,
+ device: torch.device | None = None,
):
+ self.input_mode = input_mode
if batch_size is None:
batch_size = (1,)
if isinstance(batch_size, int):
@@ -140,100 +158,235 @@ def __init__(
batch_size = torch.Size(batch_size)
if batch_size == ():
raise ValueError(f"{type(self).__name__} must have at least one dimension")
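+ # Allow overriding the class-level data_key (defaults to "query") used to read the prompt from the reset tensordict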
+ if data_key is not None:
+ self.data_key = data_key
+ super().__init__(batch_size=batch_size, device=device)
+ self.batch_size = batch_size
- super().__init__(batch_size=batch_size)
- self.full_observation_spec = Composite(
- history=History.default_spec(shape=batch_size + (-1,)),
- shape=batch_size,
- )
- self.full_state_spec = self.full_observation_spec.clone()
- self.full_state_spec["text"] = NonTensor(
- shape=self.batch_size, example_data="a string", device=self.device
- )
self.system_prompt = system_prompt
- self.apply_template = (
- apply_template or (template_kwargs is not None) or (tokenizer is not None)
- )
- self.tokenizer = tokenizer
+
if template_kwargs is None:
template_kwargs = {}
- # FIXME: what to do if True?
- template_kwargs.setdefault("tokenize", False)
self.template_kwargs = template_kwargs
- if self.apply_template:
- self.full_observation_spec["text"] = NonTensor(
- shape=self.batch_size, example_data="a string", device=self.device
- )
- self.full_action_spec = Composite(
- text_response=NonTensor(
- shape=self.batch_size, example_data="a string", device=self.device
- ),
- batch_size=self.batch_size,
- )
+
self.system_role = system_role
self.user_role = user_role
self.policy_role = policy_role
- self.make_lazy = make_lazy
+ self.tokenizer = tokenizer
- def _step(self, tensordict):
- # Expect action to be a "text_response" string
- action = tensordict["text_response"]
- # Find the total text
- text = tensordict["text"]
- if isinstance(text, str):
- text = [text]
- action = [action]
- text = [t + a for t, a in zip(text, action)]
- # Convert text to a history
- chat_template_name = None
- if self.tokenizer is not None:
- name_or_path = self.tokenizer.name_or_path
- if "qwen" in name_or_path.lower():
- chat_template_name = "qwen"
- parsed_history = History.from_text(text, chat_template_name=chat_template_name)
- # Isolate last element, which should be our action
- local_history = parsed_history[..., -1]
- # Get previous history
- history = tensordict["history"]
- # Check that history has one more item than before
- if history.shape[-1] <= parsed_history.shape[-1]:
- warnings.warn(
- "The parsed history has fewer or the same number than the last element in history."
- )
- if self.policy_role is not None:
- # Iterate over batch and check policy role
- for lh in local_history.unbind(0):
- if lh.role != self.policy_role:
- raise ValueError(
- "The role received in the last block parsed from the policy "
- f"output does not match the expected policy role: received {lh.role} but expected {self.policy_role}.\n"
- f"Parsed input: {text=}\n"
- f"Parsed history: {parsed_history=}\n"
- f"Final element: {local_history=}"
- )
- # Append history item
- history = history.append(local_history, inplace=False)
- # FIXME: consider done to be always False
- td_out = lazy_stack(
- list(
- TensorDict(
- history=history,
- done=torch.zeros(tensordict.shape + (1,), dtype=torch.bool),
- batch_size=self.batch_size,
- ).unbind(0)
- )
+ self._make_specs()
+
+ def _make_specs(self):
+ if self.input_mode == "history":
+ self._make_specs_history()
+ elif self.input_mode == "text":
+ self._make_specs_text()
+ elif self.input_mode == "tokens":
+ self._make_specs_tokens()
+ else:
+ raise ValueError(f"Invalid input mode: {self.input_mode}")
+
+ def _make_specs_history(self):
+ # we output prompt
+ self.full_observation_spec = Composite(
+ history=ChatHistory.default_spec(shape=self.batch_size, keys=["prompt"]).to(
+ self.device
+ ),
+ shape=self.batch_size,
+ device=self.device,
+ )
+ # We receive prompt, response and full
+ self.full_action_spec = Composite(
+ history=ChatHistory.default_spec(shape=self.batch_size, keys=["full"]).to(
+ self.device
+ ),
+ shape=self.batch_size,
+ device=self.device,
+ )
+ self.full_state_spec = Composite(
+ {
+ self.data_key: NonTensor(
+ example_data="a string", shape=self.batch_size, device=self.device
+ )
+ },
+ shape=self.batch_size,
+ device=self.device,
+ )
+
+ def _make_specs_text(self):
+ # we output prompt
+ self.full_observation_spec = Composite(
+ text=Text.default_spec(shape=self.batch_size, keys=["prompt"]).to(
+ self.device
+ ),
+ shape=self.batch_size,
+ device=self.device,
+ )
+ # We receive prompt, response and full
+ self.full_action_spec = Composite(
+ text=Text.default_spec(shape=self.batch_size, keys=["full"]).to(
+ self.device
+ ),
+ shape=self.batch_size,
+ device=self.device,
+ )
+ self.full_state_spec = Composite(
+ {
+ self.data_key: NonTensor(
+ example_data="a string", shape=self.batch_size, device=self.device
+ )
+ },
+ shape=self.batch_size,
+ device=self.device,
+ )
+
+ def _make_specs_tokens(self):
+ # we output prompt
+ self.full_observation_spec = Composite(
+ tokens=Tokens.default_spec(shape=self.batch_size, keys=["prompt"]).to(
+ self.device
+ ),
+ shape=self.batch_size,
+ device=self.device,
+ )
+ # We receive prompt, response and full
+ self.full_action_spec = Composite(
+ tokens=Tokens.default_spec(shape=self.batch_size, keys=["full"]).to(
+ self.device
+ ),
+ shape=self.batch_size,
+ device=self.device,
+ )
+ self.full_state_spec = Composite(
+ {
+ self.data_key: NonTensor(
+ example_data="a string", shape=self.batch_size, device=self.device
+ )
+ },
+ shape=self.batch_size,
+ device=self.device,
+ )
+
+ @classmethod
+ def from_dataloader(
+ cls,
+ dataloader: DataLoader,
+ *,
+ repeats: int | None = None,
+ device: torch.device | None = None,
+ group_repeats: bool = False,
+ batch_size: tuple | torch.Size | None = None,
+ primers: Composite | None = None,
+ tokenizer: transformers.AutoTokenizer | None = None, # noqa: F821
+ template_kwargs: dict[str, Any] | None = None,
+ input_mode: Literal["history", "text", "tokens"] = "history",
+ data_key: str | None = None,
+ system_prompt: str | None = None,
+ ):
+ """Create a chat environment from a dataloader.
+
+ Args:
+ dataloader (DataLoader): The dataloader to use.
+
+ Keyword Args:
+ repeats (int | None, optional): The number of times to repeat each sample from the dataset (mainly for Monte-Carlo
+ based value estimation). If `None`, the dataset is not repeated. Defaults to `None`.
+ device (torch.device | None, optional): The device to use for computations. Defaults to None.
+ group_repeats (bool, optional): Whether to group repeated samples together. Defaults to `False`.
+ batch_size (tuple | torch.Size | None, optional): The batch size for data loading. Defaults to `1`.
+ primers (Composite | None, optional): The primers to use for data loading. Defaults to `None`.
+ tokenizer (transformers.AutoTokenizer | None, optional): The tokenizer to use for text processing. Defaults to `None`.
+ template_kwargs (dict[str, Any] | None, optional): Additional keyword arguments for the template. Defaults to `None`.
+ input_mode (Literal["history", "text", "tokens"], optional): The mode of input to the environment. Defaults to `"history"`.
+ data_key (str, optional): The key of the data returned by the dataloader (or rather, by its collate_fn).
+ Defaults to `None` (automatically determined based on the input_mode).
+ system_prompt (str | None, optional): The system prompt to use for the environment. Defaults to `None`.
+
+ Returns:
+ DatasetChatEnv: The chat environment.
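+
+ Examples:
+ >>> # A minimal sketch (not from the codebase): assumes a dataloader whose collate_fn
+ >>> # yields a dict with a "query" entry, which ChatEnv reads at reset time.
+ >>> from torch.utils.data import DataLoader
+ >>> prompts = ["Hello, how are you?", "What is the capital of France?"]
+ >>> dataloader = DataLoader(prompts, batch_size=1, collate_fn=lambda batch: {"query": batch})
+ >>> env = ChatEnv.from_dataloader(
+ ... dataloader,
+ ... batch_size=(1,),
+ ... input_mode="history",
+ ... system_prompt="You are a helpful assistant.",
+ ... )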
+ """
+ return DatasetChatEnv.from_dataloader(
+ dataloader=dataloader,
+ repeats=repeats,
+ device=device,
+ group_repeats=group_repeats,
+ batch_size=batch_size,
+ primers=primers,
+ tokenizer=tokenizer,
+ template_kwargs=template_kwargs,
+ input_mode=input_mode,
+ data_key=data_key,
+ system_prompt=system_prompt,
)
- if self.apply_template:
- td_out["text"] = history.apply_chat_template(
- tokenizer=self.tokenizer, **self.template_kwargs
- )
- return td_out
- def _reset(self, tensordict: TensorDictBase | None):
+ # def _post_step_mdp_hooks(self, tensordict: TensorDictBase) -> TensorDictBase:
+ # """Allows modification of the tensordict after the step_mdp."""
+ # if self.input_mode == "history":
+ # tensordict.exclude(
+ # ("history", "response"), ("history", "full"), inplace=True
+ # )
+ # if self.input_mode in ("text", "history"):
+ # tensordict.exclude(("text", "response"), ("text", "full"), inplace=True)
+ # if self.input_mode in ("tokens", "history", "text"):
+ # tensordict.exclude(("tokens", "response"), ("tokens", "full"), inplace=True)
+ # if "log_probs" in tensordict.keys():
+ # tensordict.exclude(
+ # ("log_probs", "response"), ("log_probs", "full"), inplace=True
+ # )
+ # return tensordict
+
+ def _step(self, tensordict):
+ if self.input_mode == "history":
+ return self._step_history(tensordict)
+ if self.input_mode in ("text", "history"):
+ return self._step_text(tensordict)
+ if self.input_mode in ("tokens", "history", "text"):
+ return self._step_tokens(tensordict)
+ else:
+ raise ValueError(f"Invalid input mode: {self.input_mode}")
+
+ def _step_history(self, tensordict):
+ """Step the environment in history mode."""
+ # get history from tensordict
+ chat_history: ChatHistory = tensordict["history"]
+ # prompt = chat_history.prompt
+ full = chat_history.full
+ # response = chat_history.response
+ empty_td = tensordict.empty(device=self.device)
+ # Old full will be new prompt - can be modified at will
+ new_history = ChatHistory(prompt=full)
+ empty_td.set("history", new_history)
+ return empty_td
+
+ def _step_text(self, tensordict):
+ """Step the environment in text mode."""
+ # get text from tensordict
+ text: Text = tensordict["text"]
+ full = text.full
+ empty_td = tensordict.empty(device=self.device)
+ new_history = Text(prompt=full)
+ empty_td.set("text", new_history)
+ return empty_td
+
+ def _step_tokens(self, tensordict):
+ """Step the environment in tokens mode."""
+ # get tokens from tensordict
+ tokens: Tokens = tensordict["tokens"]
+ full = tokens.full
+ empty_td = tensordict.empty(device=self.device)
+ new_history = Tokens(prompt=full)
+ empty_td.set("tokens", new_history)
+ return empty_td
+
+ def _reset(self, tensordict: TensorDictBase | None, **kwargs):
if tensordict is None:
raise RuntimeError(f"{type(self).__name__} expects a tensordict as input")
# Find the total text
- content = tensordict.get("text")
+ content = tensordict.get(self.data_key)
+ if content is None:
+ raise RuntimeError(
+ f"{type(self).__name__} expects a tensordict with a {self.data_key} key, got {tensordict.keys()}"
+ )
if content.batch_size != self.batch_size:
for s in reversed(self.batch_size):
content = [content for _ in range(s)]
@@ -254,23 +407,58 @@ def _reset(self, tensordict: TensorDictBase | None):
history = lazy_stack([history_system, history], -1)
else:
history = history.unsqueeze(-1)
- result = TensorDict(
- history=history,
- done=torch.zeros(tensordict.shape + (1,), dtype=torch.bool),
- batch_size=self.batch_size,
- )
- if self.make_lazy:
- result = result.unbind(0)
- result = lazy_stack(list(result), dim=0)
- elif tensordict._lazy:
- result = result.unbind(tensordict.stack_dim)
- result = lazy_stack(list(result), dim=tensordict.stack_dim)
- result.update(tensordict.exclude(*result.keys(True)))
- if self.apply_template:
- template = history.apply_chat_template(
- tokenizer=self.tokenizer, **self.template_kwargs
+
+ # Now that we have the history, call the specific reset method
+ if self.input_mode == "history":
+ return (
+ self._reset_history(tensordict, history)
+ .update(tensordict)
+ .to_lazystack(0)
+ )
+ elif self.input_mode == "text":
+ return (
+ self._reset_text(tensordict, history).update(tensordict).to_lazystack(0)
+ )
+ elif self.input_mode == "tokens":
+ return (
+ self._reset_tokens(tensordict, history)
+ .update(tensordict)
+ .to_lazystack(0)
)
- result["text"] = template
+ else:
+ raise ValueError(f"Invalid input mode: {self.input_mode}")
+
+ def _reset_history(self, tensordict: TensorDictBase, history: History):
+ # Simplest case: history is the prompt
+ chat_history = ChatHistory._from_tensordict(
+ tensordict.empty(device=self.device)
+ )
+ chat_history.prompt = history
+ return tensordict.empty(device=self.device).set("history", chat_history)
+
+ def _reset_text(self, tensordict: TensorDictBase, history: History):
+ # Render the history into a text prompt using the chat template
+ text = history.apply_chat_template(
+ tokenizer=self.tokenizer, add_generation_prompt=True, **self.template_kwargs
+ )
+ txt = Text._from_tensordict(tensordict.empty())
+ txt.prompt = text
+ result = tensordict.empty(device=self.device).set("text", txt)
+ return result
+
+ def _reset_tokens(self, tensordict: TensorDictBase, history: History):
+ # Render the history into tokens using the chat template
+ tokens = history.apply_chat_template(
+ tokenizer=self.tokenizer,
+ add_generation_prompt=True,
+ return_tensors="pt",
+ return_dict=True,
+ **self.template_kwargs,
+ )
+ tokens_obj = Tokens._from_tensordict(tensordict.empty().to_lazystack(0))
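+ # Fill the prompt tokens sample-by-sample: tokenized prompts may have different lengths, so we keep a lazy stack rather than a dense tensor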
+ for to, tok in _zip_strict(tokens_obj.unbind(0), tokens["input_ids"]):
+ to.prompt = tok
+ result = tensordict.empty(device=self.device).set("tokens", tokens_obj)
return result
def _set_seed(self, seed):
@@ -302,7 +490,12 @@ class DatasetChatEnv(TransformedEnv):
template_kwargs (dict[str, Any] | None, optional): Additional keyword arguments for the template. Defaults to `None`.
apply_template (bool | None, optional): Whether to apply the template to the text. Defaults to `False`.
collate_fn (Callable | None, optional): A custom collate function for data loading. If `None`, a default
- collate function is used. Defaults to `None`.
+ collate function is used that renames the `"text"` key to `"query"` to avoid conflicts with the `"text"` key
+ in the tensordict returned by TorchRL components. Defaults to `None`.
+ input_mode (Literal["history", "text", "tokens"], optional): The mode of input to the environment. Defaults to `"history"`.
+ data_key (str, optional): The key of the data returned by the dataloader (or rather, by its collate_fn).
+ Defaults to `None` (automatically determined based on the input_mode).
+ system_prompt (str | None, optional): The system prompt to use for the environment. Defaults to `None`.
.. seealso:: `DatasetChatEnv` is a thin wrapper around :class:`~torchrl.envs.llm.ChatEnv` bucketed with a
:class:`~torchrl.envs.llm.DataLoadingPrimer` transform. See these two classes for more insight on data format
@@ -331,6 +524,10 @@ def __init__(
template_kwargs: dict[str, Any] | None = None,
apply_template: bool | None = False,
collate_fn: Callable[[Any], Any] | None = None,
+ input_mode: Literal["history", "text", "tokens"] = "history",
+ data_key: str | None = None,
+ primers: Composite | None = None,
+ system_prompt: str | None = None,
):
from datasets import load_dataset
from tensordict import list_to_stack
@@ -343,11 +540,11 @@ def __init__(
batch_size = (num_envs,)
- dataset = load_dataset(dataset, name)
- if split is None and "train" in dataset:
+ dataset_obj = load_dataset(dataset, name)
+ if split is None and "train" in dataset_obj:
split = "train"
if split is not None:
- dataset = dataset[split]
+ dataset_obj = dataset_obj[split]
# Env
if seed is None:
seed = int(torch.empty((), dtype=torch.int64).random_().item())
@@ -355,28 +552,117 @@ def __init__(
generator.manual_seed(seed)
dataloader = DataLoader( # noqa: TOR401
- dataset,
+ dataset_obj,
batch_size=batch_size_dl,
shuffle=shuffle,
- collate_fn=collate_fn,
+ collate_fn=collate_fn if collate_fn is not None else _default_collate_fn,
generator=generator,
)
+ self._from_dataloader(
+ self,
+ dataloader=dataloader,
+ repeats=repeats,
+ device=device,
+ group_repeats=group_repeats,
+ batch_size=batch_size,
+ primers=primers,
+ tokenizer=tokenizer,
+ template_kwargs=template_kwargs,
+ input_mode=input_mode,
+ data_key=data_key,
+ system_prompt=system_prompt,
+ )
+
+ @classmethod
+ def from_dataloader(
+ cls,
+ dataloader: DataLoader,
+ *,
+ repeats: int | None = None,
+ device: torch.device | None = None,
+ group_repeats: bool = False,
+ batch_size: tuple | torch.Size | None = None,
+ primers: Composite | None = None,
+ tokenizer: transformers.AutoTokenizer | None = None, # noqa: F821
+ template_kwargs: dict[str, Any] | None = None,
+ input_mode: Literal["history", "text", "tokens"] = "history",
+ data_key: str | None = None,
+ system_prompt: str | None = None,
+ ):
+ """Create a chat environment from a dataloader.
+
+ Args:
+ dataloader (DataLoader): The dataloader to use.
+
+ Keyword Args:
+ repeats (int | None, optional): The number of times to repeat each sample from the dataset (mainly for Monte-Carlo
+ based value estimation). If `None`, the dataset is not repeated. Defaults to `None`.
+ device (torch.device | None, optional): The device to use for computations. Defaults to None.
+ group_repeats (bool, optional): Whether to group repeated samples together. Defaults to `False`.
+ batch_size (tuple | torch.Size | None, optional): The batch size for data loading. Defaults to `1`.
+ primers (Composite | None, optional): The primers to use for data loading. Defaults to `None`.
+ tokenizer (transformers.AutoTokenizer | None, optional): The tokenizer to use for text processing. Defaults to `None`.
+ template_kwargs (dict[str, Any] | None, optional): Additional keyword arguments for the template. Defaults to `None`.
+ input_mode (Literal["history", "text", "tokens"], optional): The mode of input to the environment. Defaults to `"history"`.
+ data_key (str, optional): The key of the data returned by the dataloader (or rather, by its collate_fn).
+ Defaults to `None` (automatically determined based on the input_mode).
+ system_prompt (str | None, optional): The system prompt to use for the environment. Defaults to `None`.
+ Returns:
+ DatasetChatEnv: The chat environment.
+ """
+ self = cls.__new__(cls)
+ return cls._from_dataloader(
+ self,
+ dataloader,
+ repeats=repeats,
+ device=device,
+ group_repeats=group_repeats,
+ batch_size=batch_size,
+ primers=primers,
+ tokenizer=tokenizer,
+ template_kwargs=template_kwargs,
+ input_mode=input_mode,
+ data_key=data_key,
+ system_prompt=system_prompt,
+ )
+
+ @classmethod
+ def _from_dataloader(
+ cls,
+ self,
+ dataloader,
+ *,
+ repeats: int | None = None,
+ device: torch.device | None = None,
+ group_repeats: bool = False,
+ batch_size: tuple | torch.Size | None = None,
+ primers: Composite | None = None,
+ tokenizer: transformers.AutoTokenizer | None = None, # noqa: F821
+ template_kwargs: dict[str, Any] | None = None,
+ input_mode: Literal["history", "text", "tokens"] = "history",
+ data_key: str | None = None,
+ system_prompt: str | None = None,
+ ):
primer = DataLoadingPrimer(
dataloader=dataloader,
repeats=repeats,
device=device,
group_repeats=group_repeats,
batch_size=batch_size,
+ primers=primers,
)
env_base = ChatEnv(
batch_size=batch_size,
- system_prompt=self.SYSTEM_PROMPT,
+ system_prompt=cls.SYSTEM_PROMPT if system_prompt is None else system_prompt,
tokenizer=tokenizer,
template_kwargs=template_kwargs,
- apply_template=apply_template,
+ input_mode=input_mode,
+ data_key=data_key,
+ device=device,
)
- return super().__init__(env_base, primer)
+ TransformedEnv.__init__(self, env_base, primer)
+ return self
def reset_dataloader(self):
"""Reset the dataloader.
@@ -386,5 +672,6 @@ def reset_dataloader(self):
Returns:
self: The environment itself.
"""
- self.transform[0].reset_dataloader()
+ if hasattr(self.transform, "__getitem__"):
+ self.transform[0].reset_dataloader()
return self
diff --git a/torchrl/envs/llm/datasets/gsm8k.py b/torchrl/envs/llm/datasets/gsm8k.py
index 903f50d75f7..49897f2d53f 100644
--- a/torchrl/envs/llm/datasets/gsm8k.py
+++ b/torchrl/envs/llm/datasets/gsm8k.py
@@ -5,7 +5,7 @@
from __future__ import annotations
import warnings
-from typing import Any, Callable
+from typing import Any, Callable, Literal
import torch
from tensordict import NestedKey, TensorDict, TensorDictBase
@@ -71,7 +71,7 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec
def _collate_fn(batch):
batch = torch.stack([TensorDict.from_dict(_batch) for _batch in batch])
- batch.rename_key_("question", "text")
+ batch.rename_key_("question", "query")
return batch
@@ -123,7 +123,13 @@ def make_gsm8k_env(
env.append_transform(StepCounter(max_steps=1))
if tokenizer is not None:
- env.append_transform(GSM8KRewardParser(tokenizer=tokenizer))
+ env.append_transform(
+ GSM8KRewardParser(
+ tokenizer=tokenizer,
+ input_mode="text",
+ in_keys=["text_response", "answer"],
+ )
+ )
else:
warnings.warn("No tokenizer specified - reward will not be assigned.")
@@ -154,6 +160,7 @@ class GSM8KEnv(DatasetChatEnv):
collate_fn (Callable | None, optional): A custom collate function for data loading. If `None`, a default
collate function is used. Defaults to `None`.
max_steps (int, optional): The maximum number of steps allowed in an episode. Defaults to `1`.
+ input_mode (Literal["history", "text", "tokens"], optional): The mode of input to use. Defaults to `"history"`.
Examples:
>>> import transformers
@@ -304,6 +311,7 @@ def __init__(
compute_reward: bool = True,
collate_fn: Callable | None = None,
max_steps: int = 1,
+ input_mode: Literal["history", "text", "tokens"] = "history",
):
if collate_fn is None:
collate_fn = _collate_fn
@@ -321,6 +329,7 @@ def __init__(
template_kwargs=template_kwargs,
apply_template=apply_template,
collate_fn=collate_fn,
+ input_mode=input_mode,
)
if max_steps:
self.append_transform(StepCounter(max_steps=max_steps))
diff --git a/torchrl/envs/llm/datasets/ifeval.py b/torchrl/envs/llm/datasets/ifeval.py
index 4c3e7e8866e..0e23846f4aa 100644
--- a/torchrl/envs/llm/datasets/ifeval.py
+++ b/torchrl/envs/llm/datasets/ifeval.py
@@ -4,14 +4,14 @@
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
-from typing import Any, Callable
+from typing import Any, Callable, Literal
import torch
-from tensordict import TensorClass, TensorDict
+import transformers
+from tensordict import NonTensorData, NonTensorStack, TensorClass, TensorDict
+from torchrl.data import Composite, NonTensor, Unbounded
from torchrl.envs import StepCounter
-
from torchrl.envs.llm.chat import DatasetChatEnv
-
from torchrl.envs.llm.reward.ifeval import IfEvalScorer
@@ -19,9 +19,10 @@ class IFEvalData(TensorClass["nocast"]):
"""A tensorclass for IFEval dta."""
key: torch.Tensor
- instruction_id_list: str
+ instruction_id_list: list[str]
kwargs: list[dict]
- text: str
+ query: str
+
# Responses and additional fields
response: str | None = None
tokens: torch.Tensor | None = None
@@ -29,11 +30,63 @@ class IFEvalData(TensorClass["nocast"]):
logits: torch.Tensor | None = None
reward: torch.Tensor | None = None
+ @classmethod
+ def default_spec(
+ cls, shape: torch.Size, device: torch.device | None = None
+ ) -> Composite:
+ return Composite(
+ key=Unbounded(shape=shape, dtype=torch.int64, device=device),
+ instruction_id_list=NonTensor(
+ shape=shape,
+ device=device,
+ feature_dims=0,
+ example_data=["punctuation:no_comma"],
+ ),
+ kwargs=NonTensor(
+ shape=shape,
+ device=device,
+ feature_dims=0,
+ example_data={
+ "num_highlights": None,
+ "relation": None,
+ "num_placeholders": None,
+ },
+ ),
+ query=NonTensor(
+ shape=shape,
+ device=device,
+ example_data="Plan a 2 week Europe trip and visit London, Paris, and Rome. Answer in all caps. The response must contain at least 8 placeholders (i.e., [restaurant]).",
+ ),
+ shape=shape,
+ step_mdp_static=True,
+ data_cls=cls,
+ )
+
def _collate_fn(batch):
batch = torch.stack([TensorDict.from_any(_batch) for _batch in batch])
- batch.rename_key_("prompt", "text")
- return IFEvalData.from_tensordict(batch)
+ batch.rename_key_("prompt", "query")
+ # we want instruction_id_list and kwargs to be lists, but not NonTensorStacks
+ instruction_id_list = batch["instruction_id_list"]
+ # instruction_id_list should be a list of lists
+ instruction_id_list = NonTensorStack(
+ *[
+ NonTensorData([item] if not isinstance(item, list) else item)
+ for item in instruction_id_list
+ ]
+ )
+ kwargs = batch["kwargs"]
+ kwargs = NonTensorStack(
+ *[
+ NonTensorData([item] if not isinstance(item, list) else item)
+ for item in kwargs
+ ]
+ )
+ batch.set("instruction_id_list", instruction_id_list)
+ batch.set("kwargs", kwargs)
+ # we don't need a tensorclass here
+ return batch
+ # return IFEvalData.from_tensordict(batch)
class IFEvalEnv(DatasetChatEnv):
@@ -60,6 +113,7 @@ class IFEvalEnv(DatasetChatEnv):
collate_fn (Callable | None, optional): A custom collate function for data loading. If `None`, a default
collate function is used. Defaults to `None`.
max_steps (int, optional): The maximum number of steps allowed in an episode. Defaults to `1`.
+ input_mode (Literal["history", "text", "tokens"], optional): The mode of input to use. Defaults to `"history"`.
Examples:
>>> import transformers
@@ -137,11 +191,29 @@ class IFEvalEnv(DatasetChatEnv):
"""
- SYSTEM_PROMPT = """A conversation between User and Assistant.
-You are tasked with responding to user queries in a very specific format.
-When given a task or question, first think through the problem and provide your thought process between <think> and </think> tags.
-Then, give your final answer or response between <answer> and </answer> tags.
-You will be assessed by the content of the answer block only, so make sure it contains all the required information, and only that."""
+ SYSTEM_PROMPT = """You are a helpful AI assistant that follows instructions extremely well.
+
+IMPORTANT: You must respond in a specific format for every task:
+
+1. First, think through the problem step by step and write your reasoning between <think> and </think> tags
+2. Then, provide your final answer between <answer> and </answer> tags
+
+CRITICAL RULES:
+- ALWAYS use <think>...</think> and <answer>...</answer> tags exactly as shown
+- Do NOT use any other tag variations
+- Your <answer> section will be evaluated, so make it complete and accurate
+- Follow ALL specific requirements in the user's request (formatting, content, etc.)
+- If the user asks for placeholders like [restaurant], include them exactly as requested
+- Pay attention to capitalization, punctuation, and other formatting requirements
+
+Example format:
+<think>
+I need to analyze what the user is asking for...
+[Your reasoning here]
+</think>
+<answer>
+[Your final answer here, following all user requirements]
+</answer>
+ """
def __init__(
self,
@@ -160,6 +232,7 @@ def __init__(
compute_reward: bool = True,
collate_fn: Callable | None = None,
max_steps: int = 1,
+ input_mode: Literal["history", "text", "tokens"] = "history",
):
if collate_fn is None:
collate_fn = _collate_fn
@@ -176,6 +249,9 @@ def __init__(
template_kwargs=template_kwargs,
apply_template=apply_template,
collate_fn=collate_fn,
+ input_mode=input_mode,
+ data_key="query",
+ primers=IFEvalData.default_spec((num_envs,), device),
)
if max_steps:
self.append_transform(StepCounter(max_steps=max_steps))
diff --git a/torchrl/envs/llm/envs.py b/torchrl/envs/llm/envs.py
index 7b560678fd5..fc97f887109 100644
--- a/torchrl/envs/llm/envs.py
+++ b/torchrl/envs/llm/envs.py
@@ -119,6 +119,7 @@ def __init__(
as_llm_data: bool = False,
eos_token_id: int | None = None,
) -> None:
+ self._warn_deprecated()
self.as_llm_data = as_llm_data
if token_key is None:
token_key = self._DEFAULT_TOKEN_KEY
@@ -255,6 +256,13 @@ def __init__(
terminated=Unbounded(shape=(1,), dtype=torch.bool, device=device),
)
+ @classmethod
+ def _warn_deprecated(cls):
+ warnings.warn(
+ "LLMEnv is deprecated. Please use ChatEnv instead.",
+ category=DeprecationWarning,
+ )
+
@classmethod
def from_dataloader(
cls,
@@ -346,6 +354,8 @@ def from_dataloader(
Returns:
LLMEnv: The created LLMEnv instance.
"""
+ cls._warn_deprecated()
+
from torchrl.envs.llm import DataLoadingPrimer, Tokenizer
if str_key is None:
diff --git a/torchrl/envs/llm/reward/gsm8k.py b/torchrl/envs/llm/reward/gsm8k.py
index 2edbc001d8d..041bc1424a1 100644
--- a/torchrl/envs/llm/reward/gsm8k.py
+++ b/torchrl/envs/llm/reward/gsm8k.py
@@ -4,24 +4,33 @@
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
-import torch
-from tensordict import NestedKey, TensorDict, TensorDictBase
-from tensordict.utils import _zip_strict
+from typing import Literal
+import torch
+from tensordict import lazy_stack, NestedKey, TensorDict, TensorDictBase
+from tensordict.utils import _zip_strict, is_non_tensor
from torchrl.data import Composite, Unbounded
from torchrl.envs import Transform
+from torchrl.envs.common import EnvBase
class GSM8KRewardParser(Transform):
"""Reward parser for GSM8KEnv or make_gsm8k_env.
+ This parser automatically detects the input_mode from the parent environment and handles
+ responses accordingly:
+ - "history" mode: response is in ("history", "response") and is a History object
+ - "text" mode: response is in ("text", "response") and is text
+ - "tokens" mode: response is in ("tokens", "response") and is tokens
+
Args:
- tokenizer (AutoTokenizer from transformers): the tokenizer asssociated with the model.
- in_keys (list of NestedKey): the input keys. Defaults to `["text_response", "answer"]`.
+ tokenizer (AutoTokenizer from transformers): the tokenizer associated with the model.
+ in_keys (list of NestedKey): the input keys. If None, will be automatically determined based on parent's input_mode.
out_keys (list of NestedKey): the output keys. Defaults to `[ "reward_answer", "reward_think", "reward_right", "reward_contained", "reward", "success"]`.
eos_token (str): the end of sentence token. Defaults to `tokenizer.eos_token` if not provided.
set_done_if_answer (bool): whether to set the done flag to `True` when an answer is present. Defaults to `True`.
-
+ input_mode (Literal["history", "text", "tokens"]): the input mode of the parent environment.
+ Defaults to `None` (will be automatically determined based on parent's input_mode).
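+
+ Examples:
+ >>> # A minimal sketch: `env` and `tokenizer` are assumed to be a GSM8KEnv-style
+ >>> # environment and a Hugging Face tokenizer. When the transform is appended,
+ >>> # in_keys are resolved automatically from the parent's input_mode.
+ >>> env = env.append_transform(GSM8KRewardParser(tokenizer=tokenizer))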
"""
def __init__(
@@ -31,6 +40,7 @@ def __init__(
out_keys: list[NestedKey] | None = None,
eos_token: str | None = None,
set_done_if_answer: bool = True,
+ input_mode: Literal["history", "text", "tokens"] | None = None,
):
super().__init__()
self.tokenizer = tokenizer
@@ -42,12 +52,8 @@ def __init__(
else None
)
self.set_done_if_answer = set_done_if_answer
- if in_keys is None:
- in_keys = ["text_response", "answer"]
- if not isinstance(in_keys, list) or len(in_keys) != 2:
- raise ValueError(
- f"{type(self).__name__} requires in_keys to be of type list and have 2 elements."
- )
+ self._input_mode = input_mode
+
if out_keys is None:
out_keys = [
"reward_answer",
@@ -57,7 +63,42 @@ def __init__(
"reward",
"success",
]
- super().__init__(in_keys, out_keys)
+ super().__init__()
+ if in_keys is not None:
+ self.in_keys = in_keys
+ self.out_keys = out_keys
+
+ def _maybe_get_in_keys(self):
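+ # Resolve in_keys lazily from the parent env's input_mode ("history", "text" or "tokens") once the transform is attached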
+ if not self.in_keys:
+ parent = getattr(self, "parent", None)
+ if parent is not None:
+ if getattr(parent, "base_env", None) is not None:
+ if getattr(parent.base_env, "input_mode", None) == "history":
+ self.in_keys = [("history", "full"), "answer"]
+ elif getattr(parent.base_env, "input_mode", None) == "text":
+ self.in_keys = [("text", "full"), "answer"]
+ elif getattr(parent.base_env, "input_mode", None) == "tokens":
+ self.in_keys = [("tokens", "full"), "answer"]
+ else:
+ raise ValueError(f"No base env found for {self}")
+
+ def set_container(self, container: Transform | EnvBase) -> None:
+ result = super().set_container(container)
+ self._maybe_get_in_keys()
+ return result
+
+ _input_mode = None
+
+ @property
+ def input_mode(self):
+ if self._input_mode is None:
+ input_mode = (
+ getattr(self.parent, "input_mode", "history")
+ if hasattr(self, "parent") and self.parent is not None
+ else "history"
+ )
+ self._input_mode = input_mode
+ return self._input_mode
def _step(
self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
@@ -72,26 +113,72 @@ def _step(
# did update in place
return next_tensordict
- # Get the completion
+ # Get the completion based on input_mode
+ self._maybe_get_in_keys()
responses = tensordict[self.in_keys[0]] # batch_size, grpo_size, L
- if isinstance(responses, str):
- responses = [responses for _ in range(next_tensordict.batch_size[0])]
+ # Handle different response types based on input_mode
+ input_mode = self.input_mode
+ if input_mode == "history":
+ # responses is a History object, extract the text content
+ responses = lazy_stack([r[..., -1] for r in responses.unbind(0)])
+ if hasattr(responses, "content"):
+ # If it's a History object with content attribute
+ text_completion = responses.content
+ if is_non_tensor(text_completion):
+ text_completion = text_completion.tolist()
+ if not isinstance(text_completion, list):
+ text_completion = [text_completion]
+ elif hasattr(responses, "apply_chat_template"):
+ # If it's a History object, apply chat template to get text
+ text_completion = responses.apply_chat_template(
+ tokenizer=self.tokenizer, add_generation_prompt=False
+ )
+ if not isinstance(text_completion, list):
+ text_completion = [text_completion]
+ else:
+ # Fallback: try to convert to string
+ text_completion = [str(responses)]
+ elif input_mode == "text":
+ # responses is already text
+ if isinstance(responses, str):
+ text_completion = [
+ responses for _ in range(next_tensordict.batch_size[0])
+ ]
+ elif not isinstance(responses, list):
+ text_completion = [responses]
+ else:
+ text_completion = responses
+ elif input_mode == "tokens":
+ # responses is tokens, need to decode
+ if isinstance(responses, torch.Tensor):
+ if responses.ndim == 3:
+ batch_size, grpo_size, _ = responses.shape
+ # decode
+ text_completion = self.tokenizer.decode(
+ responses.flatten(0, 1).tolist()
+ )
+ if not isinstance(text_completion, list):
+ text_completion = [
+ text_completion for _ in range(next_tensordict.batch_size[0])
+ ]
+ else:
+ # Assume it's already a list of token sequences
+ text_completion = []
+ for token_seq in responses:
+ if isinstance(token_seq, torch.Tensor):
+ text_completion.append(
+ self.tokenizer.decode(token_seq.tolist())
+ )
+ else:
+ text_completion.append(str(token_seq))
+ else:
+ raise ValueError(f"Unknown input_mode: {input_mode}")
if self.eos_token is not None:
- responses = [r.removesuffix(self.eos_token) for r in responses]
+ text_completion = [r.removesuffix(self.eos_token) for r in text_completion]
answers = next_tensordict[self.in_keys[1]] # batch_size, grpo_size
- if isinstance(responses, torch.Tensor):
- if responses.ndim == 3:
- batch_size, grpo_size, _ = responses.shape
- # decode
- text_completion = self.tokenizer.decode(responses.flatten(0, 1).tolist())
- else:
- text_completion = responses
- if not isinstance(text_completion, list):
- text_completion = [
- text_completion for _ in range(next_tensordict.batch_size[0])
- ]
+
# Decomposed reward
tds = []
# torchrl_logger.info(f"{answers=}")
@@ -114,10 +201,13 @@ def _step(
# With tensorclass comparison should be easy
cot_orig, answer = answer.split("#### ")
tds.append(
- self._single_shaped_correctness_reward(answer, potential_answer, cot)
+ self._single_shaped_correctness_reward(
+ answer, [potential_answer], [cot]
+ )
)
tds = torch.stack(tds)
if isinstance(responses, torch.Tensor) and responses.ndim == 3:
+ batch_size, grpo_size, _ = responses.shape
tds = tds.reshape(batch_size, grpo_size)
# Rewards need to have shape broadcastable to [batch x tokens x 1]
tds = tds.apply(lambda t: t.unsqueeze(-1).unsqueeze(-1))
@@ -220,7 +310,13 @@ def extract_tags(text: str) -> tuple[str, str]:
except ET.ParseError:
return ("", "")
+ think_elem = root.find("think")
+ answer_elem = root.find("answer")
return (
- root.find("think").text if root.find("think") is not None else "",
- root.find("answer").text if root.find("answer") is not None else "",
+ think_elem.text
+ if think_elem is not None and think_elem.text is not None
+ else "",
+ answer_elem.text
+ if answer_elem is not None and answer_elem.text is not None
+ else "",
)
diff --git a/torchrl/envs/llm/reward/ifeval/_instructions_main.py b/torchrl/envs/llm/reward/ifeval/_instructions_main.py
index d891ae823f8..6ce14ea3175 100644
--- a/torchrl/envs/llm/reward/ifeval/_instructions_main.py
+++ b/torchrl/envs/llm/reward/ifeval/_instructions_main.py
@@ -40,6 +40,10 @@ def _test_instruction_following_strict(
):
"""Tests response to see if instructions are followed."""
instruction_list = inp.instruction_id_list
+ if not isinstance(instruction_list, list):
+ raise ValueError(
+ f"instruction_list must be a list, got {type(instruction_list)}, {instruction_list=}"
+ )
is_following_list = []
for index, instruction_id in enumerate(instruction_list):
diff --git a/torchrl/envs/llm/reward/ifeval/_scorer.py b/torchrl/envs/llm/reward/ifeval/_scorer.py
index 0ebef1a76c7..40830b1dc84 100644
--- a/torchrl/envs/llm/reward/ifeval/_scorer.py
+++ b/torchrl/envs/llm/reward/ifeval/_scorer.py
@@ -20,7 +20,14 @@
from typing import Callable
import torch
-from tensordict import NestedKey, NonTensorData, TensorClass, TensorDict, TensorDictBase
+from tensordict import (
+ lazy_stack,
+ NestedKey,
+ NonTensorData,
+ TensorClass,
+ TensorDict,
+ TensorDictBase,
+)
from tensordict.tensorclass import is_non_tensor
from torchrl._utils import logger as torchrl_logger
@@ -40,6 +47,27 @@ class IFEvalScoreData(TensorClass):
prompt_level_loose_acc: torch.Tensor | None
inst_level_loose_acc: torch.Tensor | None
+ @classmethod
+ def default_spec(
+ cls, shape: torch.Size, device: torch.device | None = None
+ ) -> Composite:
+ return Composite(
+ prompt_level_strict_acc=Unbounded(
+ shape=shape, dtype=torch.bool, device=device
+ ),
+ inst_level_strict_acc=Unbounded(
+ shape=shape, dtype=torch.bool, device=device
+ ),
+ prompt_level_loose_acc=Unbounded(
+ shape=shape, dtype=torch.bool, device=device
+ ),
+ inst_level_loose_acc=Unbounded(
+ shape=shape, dtype=torch.bool, device=device
+ ),
+ data_cls=cls,
+ step_mdp_static=True,
+ )
+
def __post_init__(self):
prompt_level_loose_acc = self.get(
"prompt_level_loose_acc", as_padded_tensor=True
@@ -72,7 +100,10 @@ def __post_init__(self):
def _process_results(
- data: TensorDict, response: str | NonTensorData, verbose: bool = False
+ data: TensorDict,
+ response: str | NonTensorData,
+ verbose: bool = False,
+ prompt: str | None = None,
) -> IFEvalScoreData:
if not _has_langdetect:
raise ImportError("langdetect must be installed to user IFEvalScorer.")
@@ -85,10 +116,13 @@ def _process_results(
_test_instruction_following_strict,
)
+ if prompt is None:
+ prompt = data["text"]
+
inp = _InputExample(
key=data["key"],
instruction_id_list=data["instruction_id_list"],
- prompt=data["text"],
+ prompt=prompt if prompt is not None else "",
kwargs=data["kwargs"],
)
@@ -136,6 +170,7 @@ class IfEvalScorer(Transform):
`prompt_level_loose_acc`, `inst_level_loose_acc`, in that order). Defaults to `[0.4, 0.3, 0.2, 0.1]`.
This is only used if `aggregate_reward` is `True` and the default aggregator is used.
verbose (bool, optional): Whether to print verbose information. Defaults to `False`.
+ set_done_if_answer (bool): whether to set the done flag to `True` when an answer is present. Defaults to `True`.
.. note:: `IFEvalScorer` requires the following libraries to be installed: `langdetect`, `nltk` and `immutabledict`.
@@ -156,9 +191,11 @@ def __init__(
] = True,
format_weights: list[float] | None = None,
verbose: bool = False,
+ set_done_if_answer: bool = True,
):
self.aggregate_reward = aggregate_reward
self.score_key = score_key
+ self.set_done_if_answer = set_done_if_answer
out_keys = [self.score_key]
if aggregate_reward:
out_keys.append("reward")
@@ -193,7 +230,7 @@ def default_reward_aggregator(
answer_blocks: list[str] | None = None,
complete: bool | torch.Tensor | None = None,
) -> torch.Tensor:
- r"""Default reward aggregation function that provides a more nuanced scoring system.
+ r"""Improved reward aggregation function with tiered multiplicative scoring.
Args:
score (IFEvalScoreData): The score data.
@@ -201,45 +238,56 @@ def default_reward_aggregator(
answer_blocks (list[str], optional): The list of answer blocks.
complete (bool, optional): Whether the response is complete (ends with a eos token).
- The reward is composed of three main components:
- 1. Format score (max 1.0):
- - prompt_level_strict_acc: 0.4 (highest weight for strict adherence to all instructions)
- - inst_level_strict_acc: 0.3 (high weight for strict adherence to individual instructions)
- - prompt_level_loose_acc: 0.2 (medium weight for loose adherence to all instructions)
- - inst_level_loose_acc: 0.1 (lowest weight for loose adherence to individual instructions)
- All instruction-level metrics are averaged to ensure balanced contribution.
-
- 2. Structure score (max 1.0):
- - think_block: 0.5 (presence of exactly one think block)
- - answer_block: 0.5 (presence of exactly one answer block)
-
- 3. Completion bonus (max 0.2):
- - complete: 0.2 (response ends with eos token)
+ The reward uses a tiered multiplicative system:
- The overall formula for the reward is:
+ 1. Critical failure check: No answer blocks = 0 reward
+ 2. Base format score (0-1): Weighted average of format metrics
+ 3. Structure multiplier (0.1-1.0): Penalties for missing/multiple blocks
+ 4. Quality bonus (0-0.5): Rewards for high quality and completion
+ 5. Task complexity scaling: More requirements = higher potential rewards
- .. math::
+ The final formula is:
+ reward = (format_score + quality_bonus) * structure_multiplier * complexity_scale
- reward = format\_score + structure\_score + completion\_bonus
+ This provides better learning signals by:
+ - Requiring critical elements (answer tags) for meaningful rewards
+ - Using multiplicative scaling to reward doing everything well
+ - Scaling rewards based on task complexity
+ - Providing clear failure modes and success incentives
- Therefore, the maximum value the reward can take is 2.2, with:
- - 1.0 from format adherence
- - 1.0 from structural elements (think/answer blocks)
- - 0.2 from completion bonus
+ Reward range: 0.0 to ~1.5-2.7 depending on task complexity (more instructions = higher max reward).
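+
+ For instance (illustrative numbers only): a response with format_score = 0.9, a completion
+ bonus of 0.2 plus the high-quality bonus of 0.3 (quality_bonus = 0.5), exactly one think and
+ one answer block (structure_multiplier = 1.0), and two instructions (complexity_scale = 1.2)
+ receives (0.9 + 0.5) * 1.0 * 1.2 = 1.68.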
"""
default_dtype = torch.get_default_dtype()
score = score.to(default_dtype)
- # Format score calculation - using mean for instruction-level metrics
+ # Critical failure check - no answer = no reward
+ if not answer_blocks:
+ return torch.zeros(
+ score.batch_size + (1,), device=score.device, dtype=default_dtype
+ )
+
+ # Base format score calculation (0-1)
format_components = torch.stack(
[
- score.prompt_level_strict_acc.sum(-1, keepdim=True), # Single value
- score.inst_level_strict_acc.mean(
- -1, keepdim=True
+ score.prompt_level_strict_acc.sum(-1, keepdim=True)
+ if score.prompt_level_strict_acc is not None
+ else torch.zeros(
+ score.batch_size + (1,), device=score.device, dtype=default_dtype
+ ), # Single value
+ score.inst_level_strict_acc.mean(-1, keepdim=True)
+ if score.inst_level_strict_acc is not None
+ else torch.zeros(
+ score.batch_size + (1,), device=score.device, dtype=default_dtype
), # Average across instructions
- score.prompt_level_loose_acc.sum(-1, keepdim=True), # Single value
- score.inst_level_loose_acc.mean(
- -1, keepdim=True
+ score.prompt_level_loose_acc.sum(-1, keepdim=True)
+ if score.prompt_level_loose_acc is not None
+ else torch.zeros(
+ score.batch_size + (1,), device=score.device, dtype=default_dtype
+ ), # Single value
+ score.inst_level_loose_acc.mean(-1, keepdim=True)
+ if score.inst_level_loose_acc is not None
+ else torch.zeros(
+ score.batch_size + (1,), device=score.device, dtype=default_dtype
), # Average across instructions
],
-1,
@@ -247,33 +295,55 @@ def default_reward_aggregator(
weights = torch.tensor(
self.format_weights,
device=format_components.device,
- dtype=torch.get_default_dtype(),
+ dtype=default_dtype,
)
format_score = (format_components * weights).sum(dim=-1, keepdim=True)
- # Structure score calculation
- if think_blocks is not None:
- think_score = float(len(think_blocks) == 1) * 0.5
- else:
- think_score = 0.0
+ # Structure multiplier (0.1-1.0)
+ structure_multiplier = 1.0
- if answer_blocks is not None:
- answer_score = float(len(answer_blocks) == 1) * 0.5
- else:
- answer_score = 0.0
+ # Heavy penalty for missing think blocks (but not zero)
+ if not think_blocks:
+ structure_multiplier *= 0.3
+ elif len(think_blocks) > 1:
+ structure_multiplier *= 0.7 # Penalty for multiple think blocks
+
+ # Penalty for multiple answer blocks
+ if len(answer_blocks) > 1:
+ structure_multiplier *= 0.7
- structure_score = think_score + answer_score
+ # Quality bonus (0-0.5)
+ quality_bonus = torch.zeros_like(format_score)
+
+ # Bonus for high quality responses
+ if format_score > 0.8:
+ quality_bonus += 0.3
# Completion bonus
- if complete is None:
- completion_bonus = 0.0
- elif isinstance(complete, torch.Tensor):
- completion_bonus = complete.to(default_dtype) * 0.2
+ if complete is not None:
+ if isinstance(complete, torch.Tensor):
+ completion_bonus = complete.to(default_dtype) * 0.2
+ else:
+ completion_bonus = float(complete) * 0.2
+ quality_bonus += completion_bonus
+
+ # Task complexity scaling based on number of instructions
+ # More instructions = higher potential rewards
+ if (
+ score.inst_level_strict_acc is not None
+ and score.inst_level_strict_acc.numel() > 0
+ ):
+ num_instructions = score.inst_level_strict_acc.shape[-1]
else:
- completion_bonus = float(complete) * 0.2
-
- # Combine all components
- final_reward = format_score + structure_score + completion_bonus
+ num_instructions = 1
+ complexity_scale = (
+ 1.0 + (num_instructions - 1) * 0.2
+ ) # 1.0 for 1 instruction, 1.2 for 2, etc.
+
+ # Final reward: (format + quality) * structure_multiplier * complexity_scale
+ final_reward = (
+ (format_score + quality_bonus) * structure_multiplier * complexity_scale
+ )
final_reward = final_reward.to(default_dtype)
return final_reward
@@ -281,8 +351,11 @@ def default_reward_aggregator(
def _step(
self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
) -> TensorDictBase:
+ if getattr(self.parent.base_env, "input_mode", "history") != "history":
+ raise ValueError("IFEvalScorer only supports history input mode")
+
if tensordict.ndim:
- return torch.stack(
+ return lazy_stack(
[
self._step(td, next_td)
for td, next_td in zip(
@@ -290,7 +363,8 @@ def _step(
)
]
)
- h = next_tensordict["history"][..., -1]
+ h = tensordict["history", "full"][..., -1]
+ prompt = tensordict["history", "prompt"][..., -1].content
response = h.content
complete = h.is_complete
# response = tensordict.get(self.response_key)
@@ -311,6 +385,7 @@ def _step(
tensordict.copy().auto_device_(),
answer_blocks[0] if answer_blocks else "",
verbose=self.verbose,
+ prompt=prompt,
)
next_tensordict.set(
self.score_key,
@@ -327,8 +402,31 @@ def _step(
answer_blocks=answer_blocks,
complete=complete,
)
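+ # Reshape so the reward broadcasts against per-token shapes: [*batch, 1, 1]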
+ reward = reward.view(
+ next_tensordict.batch_size
+ + (
+ 1,
+ 1,
+ )
+ )
next_tensordict.set("reward", reward)
-
+ if self.set_done_if_answer and bool(answer_blocks):
+ next_tensordict.set(
+ "done",
+ torch.ones(
+ next_tensordict.batch_size + (1,),
+ device=next_tensordict.device,
+ dtype=torch.bool,
+ ),
+ )
+ next_tensordict.set(
+ "terminated",
+ torch.ones(
+ next_tensordict.batch_size + (1,),
+ device=next_tensordict.device,
+ dtype=torch.bool,
+ ),
+ )
return next_tensordict
@property
@@ -343,23 +441,14 @@ def expected_keys(self) -> list[str]:
def transform_reward_spec(self, reward_spec: Composite) -> Composite:
reward_spec["reward"] = Unbounded(
- reward_spec.shape + (1,), dtype=torch.get_default_dtype()
+ reward_spec.shape + (1, 1),
+ dtype=torch.get_default_dtype(),
+ device=reward_spec.device,
)
return reward_spec
def transform_observation_spec(self, observation_spec: Composite) -> Composite:
- observation_spec[self.score_key] = Composite(
- prompt_level_strict_acc=Unbounded(
- shape=observation_spec.shape, dtype=torch.bool
- ),
- inst_level_strict_acc=Unbounded(
- shape=observation_spec.shape, dtype=torch.bool
- ),
- prompt_level_loose_acc=Unbounded(
- shape=observation_spec.shape, dtype=torch.bool
- ),
- inst_level_loose_acc=Unbounded(
- shape=observation_spec.shape, dtype=torch.bool
- ),
+ observation_spec[self.score_key] = IFEvalScoreData.default_spec(
+ observation_spec.shape, device=observation_spec.device
)
return observation_spec
diff --git a/torchrl/envs/llm/transforms/__init__.py b/torchrl/envs/llm/transforms/__init__.py
index 7502ba5f131..6e28b1ac18e 100644
--- a/torchrl/envs/llm/transforms/__init__.py
+++ b/torchrl/envs/llm/transforms/__init__.py
@@ -6,7 +6,7 @@
from .browser import BrowserTransform
from .dataloading import as_nested_tensor, as_padded_tensor, DataLoadingPrimer
from .format import TemplateTransform
-from .kl import KLRewardTransform, RetrieveLogProb
+from .kl import KLComputation, KLRewardTransform, RetrieveKL, RetrieveLogProb
from .policy_version import PolicyVersion
from .reason import AddThinkingPrompt
from .tokenizer import Tokenizer
@@ -17,10 +17,12 @@
"DataLoadingPrimer",
"KLRewardTransform",
"RetrieveLogProb",
+ "RetrieveKL",
"MCPToolTransform",
"PolicyVersion",
"PythonInterpreter",
"AddThinkingPrompt",
+ "KLComputation",
"TemplateTransform",
"Tokenizer",
"as_nested_tensor",
diff --git a/torchrl/envs/llm/transforms/dataloading.py b/torchrl/envs/llm/transforms/dataloading.py
index 0d09dc4c992..770948f1e14 100644
--- a/torchrl/envs/llm/transforms/dataloading.py
+++ b/torchrl/envs/llm/transforms/dataloading.py
@@ -464,24 +464,43 @@ def _endless_iter(self, obj):
while True:
yield from obj
+ _device: torch.device | None = None
+
+ @property
+ def device(self) -> torch.device | None:
+ if self._device is None:
+ primers = getattr(self, "primers", None)
+ if primers is not None:
+ device = self.primers.device
+ else:
+ parent = getattr(self, "parent", None)
+ if parent is not None:
+ device = getattr(parent, "device", None)
+ else:
+ device = None
+ self._device = device
+ return self._device
+
+ @device.setter
+ def device(self, device: torch.device | None):
+ self._device = device
+
def _load_from_dataloader(self, reset: torch.Tensor | None = None):
"""Loads a single element from the dataloader, or alternatively from the buffer.
If `reset` is passed, then one element per reset will be loaded.
"""
+ device = self.device
+
if reset is not None:
if not reset.any():
raise RuntimeError("reset must have at least one True value.")
if reset.ndim > 0:
- loaded = [self._load_from_dataloader() for _ in range(reset.sum())]
+ loaded = [
+ self._load_from_dataloader().to(device) for _ in range(reset.sum())
+ ]
return self.stack_method(loaded)
- primers = getattr(self, "primers", None)
- if primers is not None:
- device = self.primers.device
- else:
- device = None
-
if len(self._queue) > 0:
result = self._queue.popleft()
if result.device != device:
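
The new `device` property resolves the target device lazily: it prefers the primers' device, falls back to the parent env, and caches the result. A standalone sketch of that resolution order, using stand-in `primers`/`parent` objects (hypothetical names, for illustration only):

```python
import torch

def resolve_device(primers=None, parent=None):
    """Mirror of the lookup order used above: primers first, then the parent env."""
    if primers is not None:
        return getattr(primers, "device", None)
    if parent is not None:
        return getattr(parent, "device", None)
    return None

class _FakePrimers:
    device = torch.device("cpu")

print(resolve_device(primers=_FakePrimers()))  # cpu
print(resolve_device())                        # None
```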
diff --git a/torchrl/envs/llm/transforms/kl.py b/torchrl/envs/llm/transforms/kl.py
index a8a13798d0f..03fd7470f50 100644
--- a/torchrl/envs/llm/transforms/kl.py
+++ b/torchrl/envs/llm/transforms/kl.py
@@ -4,20 +4,21 @@
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
-import contextlib
-import gc
+import warnings
+from contextlib import nullcontext
from copy import copy
+from typing import Any, Literal
import torch
from tensordict import NestedKey, set_list_to_stack, TensorDictBase, unravel_key
-from tensordict.nn import ProbabilisticTensorDictModule
-from tensordict.utils import _zip_strict, is_seq_of_nested_key
+from tensordict.utils import _zip_strict, is_seq_of_nested_key, logger as torchrl_logger
+from torch.nn.utils.rnn import pad_sequence
from torchrl.data import Composite, Unbounded
-from torchrl.data.llm.chat import History
from torchrl.envs import EnvBase, Transform
+from torchrl.envs.transforms.transforms import Compose
from torchrl.envs.transforms.utils import _set_missing_tolerance
-from torchrl.modules.llm.policies.common import CategoricalSequential
+from torchrl.modules.llm.policies.common import LLMWrapperBase
try:
import transformers
@@ -26,50 +27,65 @@
class KLRewardTransform(Transform):
- """A transform to add a KL[pi_current||pi_0] correction term to the reward.
+ """A legacy transform for computing KL divergence-based rewards.
- This transform is used to constrain the policy to remain close to its original
- configuration which limits overfitting when fine-tuning using RLHF.
+ **Deprecated**: This transform is maintained for backward compatibility but is no longer
+ the recommended approach. Use :class:`~torchrl.envs.llm.transforms.kl.RetrieveKL` instead,
+ which provides better modularity and integration with the new wrapper design.
+
+ **Recent Changes:**
+ - **Legacy Status**: This transform is now considered legacy and may not work optimally
+ with the new modular wrapper design.
+ - **ChatHistory Integration**: Limited support for the new :class:`~torchrl.modules.llm.policies.ChatHistory` objects.
+ - **Input Mode Support**: May not handle all input modes (`"history"`, `"text"`, `"tokens"`) consistently.
+
+ **Recommendation**:
+ Use :class:`~torchrl.envs.llm.transforms.kl.RetrieveKL` for new code, which provides:
+ - Better integration with the new wrapper design
+ - Consistent support for all input modes
+ - Proper handling of ChatHistory objects
+ - More modular and composable architecture
Args:
- actor (ProbabilisticTensorDictModule): a frozen probabilistic actor. It must
- have the following features: it must have a set of input (``in_keys``)
- and output keys (``out_keys``). It must have a ``get_dist`` method
- that outputs the distribution of the action.
- coef (:obj:`float`): the coefficient of the KL term. Defaults to ``1.0``.
- in_keys (str or list of str/tuples of str): the input key where the
- reward should be fetched. Defaults to ``"reward"``.
- out_keys (str or list of str/tuples of str): the output key where the
- reward should be written. Defaults to ``["reward", "kl_penalty", "ref_log_prob"]``.
- add_to_reward (bool): whether to add the reward term to the reward.
- Defaults to ``True``.
-
- .. note:: If the parameters are not differentiable (default), they will *not*
- follow the module when dtype or device casting operations will be called
- (such as :meth:`cuda`, :meth:`to` etc.). When ``requires_grad=True``,
- casting operations will work as expected.
+ ref_model (LLMWrapperBase): the reference model, wrapped so that it computes log-probs instead of generating text.
- Examples:
- TODO
+ Keyword Args:
+ coef (float): the coefficient of the KL term. Defaults to `1.0`.
+ add_to_reward (bool): whether to add the KL correction term to the reward (computed as `reward - coef * kl`). Defaults to `True`.
+ log_prob_key (NestedKey): the key where the current policy log-probs are stored. Defaults to `("log_probs", "full")`.
+ assistant_only (bool): whether to only compute KL on assistant tokens. Defaults to `True`.
+ tokenizer (transformers.AutoTokenizer): the tokenizer to use. Defaults to `None`.
+ device (torch.device): the device to use. Defaults to `None`.
+ padding_side (str): the side of the padding when using pad_sequence. Defaults to `"left"`.
- .. note:: Because the KL formula is not always available and the parameters of the
- original distribution may not have been recorded, we use a stochastic estimate
- of the KL divergence.
+ Examples:
+ >>> # Legacy usage (not recommended for new code)
+ >>> transform = KLRewardTransform(ref_model, coef=0.1)
+ >>>
+ >>> # Recommended approach using RetrieveKL
+ >>> from torchrl.envs.llm.transforms.kl import RetrieveKL
+ >>> transform = RetrieveKL(gen_model, ref_model, assistant_only=True)
+ .. seealso::
+ :class:`~torchrl.envs.llm.transforms.kl.RetrieveKL`: The recommended transform for KL divergence computation.
+ :class:`~torchrl.envs.llm.transforms.kl.RetrieveLogProb`: Base transform for retrieving log-probabilities.
+ :class:`~torchrl.envs.llm.transforms.kl.KLComputation`: Transform for computing KL divergence between log-prob tensors.
"""
DEFAULT_IN_KEYS = ["reward"]
def __init__(
self,
- actor: ProbabilisticTensorDictModule,
+ ref_model: LLMWrapperBase,
+ *,
coef=1.0,
in_keys=None,
out_keys=None,
- log_prob_key: NestedKey = "log_probs",
- action_key: NestedKey | None = None,
+ log_prob_key: NestedKey = ("log_probs", "full"),
device: torch.device | None = None,
add_to_reward: bool = True,
+ tokenizer: transformers.AutoTokenizer | None = None,
+ assistant_only: bool = True,
+ padding_side: str = "left",
):
if in_keys is None:
in_keys = self.DEFAULT_IN_KEYS
@@ -94,31 +110,55 @@ def __init__(
)
self._out_keys = [unravel_key(out_key) for out_key in self._out_keys]
+ if getattr(ref_model, "generate", False):
+ raise ValueError(
+ "The actor is configured to generate text, not compute the log-probs."
+ )
+
# update the in_keys for dispatch etc
- self.in_keys = self.in_keys + actor.in_keys
+ self.in_keys = self.in_keys + ref_model.in_keys
self.in_keys = [unravel_key(in_key) for in_key in self.in_keys]
self.add_to_reward = add_to_reward
# check that the model has parameters
- self.__dict__["actor"] = actor
+ self.__dict__["ref_model"] = ref_model
# self._buffers["actor_params"] = params.clone().detach()
self.device = device
- self.action_key = action_key
# find the sample log-prob key
- self.sample_log_prob_key = log_prob_key
-
- def find_sample_log_prob(module):
- if hasattr(module, "log_prob_key"):
- self.sample_log_prob_key = module.log_prob_key
+ self.log_prob_full_key = log_prob_key
- self.actor.apply(find_sample_log_prob)
+ self._tokenizer = tokenizer
+ self.assistant_only = assistant_only
+ self.padding_side = padding_side
if not isinstance(coef, torch.Tensor):
coef = torch.as_tensor(coef)
self.register_buffer("coef", coef)
+ # sanity check for the ref_model
+ if getattr(ref_model, "input_mode", "tokens") != "tokens":
+ raise ValueError(
+ "The ref_model must be configured to use tokens as input. Please set the `input_mode` argument to `tokens`."
+ )
+
+ @property
+ def pad_output(self):
+ # We need pad_output to match the pad_output of the inference model
+ return self.ref_model.pad_output
+
+ @property
+ def tokenizer(self):
+ tokenizer = self._tokenizer
+ if tokenizer is not None:
+ return tokenizer
+ try:
+ return self.ref_model.tokenizer
+ except AttributeError:
+ raise AttributeError(
+ "The ref_model does not have a tokenizer. Please pass the tokenizer to the constructor."
+ )
def set_container(self, container: Transform | EnvBase) -> None:
result = super().set_container(container)
@@ -141,54 +181,127 @@ def _reset(
tensordict_reset = self._step(tensordict_reset, tensordict_reset)
return tensordict_reset
+ @property
+ def action_key(self) -> NestedKey:
+ # Get the action from the base env (a ChatEnv).
+ if self.parent.base_env.input_mode == "history":
+ return ("history", "full")
+ if self.parent.base_env.input_mode == "text":
+ return ("text", "full")
+ if self.parent.base_env.input_mode == "tokens":
+ return ("tokens", "full")
+ raise ValueError(f"Invalid input mode: {self.parent.base_env.input_mode}")
+
def _step(
self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
) -> TensorDictBase:
- # run the actor on the tensordict
- action_key = self.action_key
- if action_key is None:
- raise ValueError(
- f"action_key is required. Please set a parent for the {type(self).__name__} to recover the action keys automatically, "
- f"or pass the action_key argument directly to {type(self).__name__} constructor."
- )
- response_txt = tensordict.get(action_key, None)
- if response_txt is None:
+ if self.device is not None:
+ tensordict = tensordict.to(self.device)
+ next_tensordict = next_tensordict.to(self.device)
+ # tensordict = self._get_text_response(tensordict, next_tensordict)
+ response = tensordict.get(self.action_key, None)
+ if response is None:
if not self.missing_tolerance:
raise RuntimeError(
- f"Action with key {action_key} not found data {tensordict}"
+ f"Action with key {self.action_key} not found data {tensordict}"
)
# being called after reset or without action, skipping
if self.out_keys[0] != "reward" and self.parent is not None:
next_tensordict.set(self.out_keys[0], self.parent.reward_spec.zero())
return next_tensordict
- if hasattr(self.actor, "log_prob"):
- if self.device is not None and tensordict.device != self.device:
- td_device = tensordict.to(self.device)
- else:
- td_device = tensordict.copy()
- ref_log_prob = self.actor.log_prob(
- td_device, as_nested_tensor=True, layout=torch.strided
+
+ # We use the ("tokens", "full") key to get the log-probs of the reference model
+ with torch.device(self.device) if self.device is not None else nullcontext():
+ td_input = tensordict.copy()
+ ref_log_prob_td = self.ref_model(td_input)
+ if self.pad_output:
+ ref_log_prob_padded = ref_log_prob_td.get(self.log_prob_full_key)
+ else:
+ ref_log_prob_unpadded = ref_log_prob_td.get(
+ self.log_prob_full_key, as_list=True
)
+ if self.assistant_only:
+ # Get the assistant mask
+ mask = tensordict.get(("masks", "all_assistant_mask"))
+ # mask will often be None - fall back on prompt / response separation
+ if mask is None:
+ if self.pad_output:
+ # simple case: just take the prompt length
+ prompt_length = tensordict.get(("tokens", "prompt")).shape[-1]
+ mask = tensordict.get(("masks", "all_attention_mask")).clone()
+ mask[..., :prompt_length] = False
+ else:
+ # simple case: just take the prompt length
+ prompt_length = [
+ t.size(-1)
+ for t in tensordict.get(("tokens", "prompt"), as_list=True)
+ ]
+ mask = tensordict.get(("masks", "all_attention_mask"), as_list=True)
+ for i in range(len(prompt_length)):
+ mask[i] = mask[i].clone()
+ mask[i][..., : prompt_length[i]] = False
+
+ # we want to keep the batch dimension
+ ref_log_prob_list = []
+ if self.pad_output:
+ for i in range(ref_log_prob_padded.size(0)):
+ ref_log_prob_list.append(
+ ref_log_prob_padded[i].masked_fill(~mask[i], 0)
+ )
+ else:
+ for i in range(len(ref_log_prob_unpadded)):
+ ref_log_prob_list.append(
+ ref_log_prob_unpadded[i].masked_fill(~mask[i], 0)
+ )
+ if self.pad_output:
+ ref_log_prob = pad_sequence(
+ ref_log_prob_list,
+ batch_first=True,
+ padding_value=0,
+ padding_side=self.padding_side,
+ )
+ else:
+ ref_log_prob = torch.nested.nested_tensor(
+ ref_log_prob_list, layout=torch.strided
+ )
+
+ # we obtain the current log-probs (already computed) from the current tensordict
+ if self.pad_output:
+ curr_log_prob_padded = tensordict.get(self.log_prob_full_key)
else:
- ref_log_prob_td = self.actor(tensordict)
- ref_log_prob = ref_log_prob_td.get(self.sample_log_prob_key)
+ curr_log_prob_unpadded = tensordict.get(
+ self.log_prob_full_key, as_list=True
+ )
+ if self.assistant_only:
+ # we want to keep the batch dimension
+ curr_log_prob_list = []
+ if self.pad_output:
+ for i in range(curr_log_prob_padded.size(0)):
+ curr_log_prob_list.append(
+ curr_log_prob_padded[i].masked_fill(~mask[i], 0)
+ )
+ else:
+ for i in range(len(curr_log_prob_unpadded)):
+ curr_log_prob_list.append(
+ curr_log_prob_unpadded[i].masked_fill(~mask[i], 0)
+ )
+ if self.pad_output:
+ curr_log_prob = pad_sequence(
+ curr_log_prob_list,
+ batch_first=True,
+ padding_value=0,
+ padding_side=self.padding_side,
+ )
+ else:
+ curr_log_prob = torch.nested.nested_tensor(
+ curr_log_prob_list, layout=torch.strided
+ )
- reward_key = self.in_keys[0]
- reward = next_tensordict.get(reward_key)
- curr_log_prob = tensordict.get(
- self.sample_log_prob_key, as_nested_tensor=True, layout=torch.strided
- )
ref_log_prob = ref_log_prob.to(curr_log_prob.device)
# We want the log-probs to have a similar dim to the reward
curr_log_prob = curr_log_prob.unsqueeze(-1)
ref_log_prob = ref_log_prob.unsqueeze(-1)
- # we use the unbiased consistent estimator of the KL: log_p(x) - log_q(x) when x ~ p(x)
- if not reward.is_nested and ref_log_prob.is_nested:
- reward = torch.nested.nested_tensor(
- [rew.expand(lp.shape) for rew, lp in zip(reward, ref_log_prob)],
- layout=torch.strided,
- )
for i in range(ref_log_prob.size(0)):
if ref_log_prob[i].shape != curr_log_prob[i].shape:
# Don't check shapes if nested
@@ -197,16 +310,25 @@ def _step(
f"One possible reason is that the padding token is identical to the eos token, which means that the eos_token log_prob is truncated from the "
f"reference model output."
)
- if reward is not None and reward.ndim != curr_log_prob.ndim:
- raise ValueError(
- "The number of dimensions of reward must be the same as the number of dimensions of the KL "
- f"term. Got ndim={reward.ndim} and {curr_log_prob.ndim} respectively."
- )
kl = curr_log_prob - ref_log_prob
if self.add_to_reward:
+ reward_key = self.in_keys[0]
+ reward = next_tensordict.get(reward_key)
+ # we use the unbiased consistent estimator of the KL: log_p(x) - log_q(x) when x ~ p(x)
+ if not reward.is_nested and ref_log_prob.is_nested:
+ reward = torch.nested.nested_tensor(
+ [rew.expand(lp.shape) for rew, lp in zip(reward, ref_log_prob)],
+ layout=torch.strided,
+ )
+ if reward is not None and reward.ndim != curr_log_prob.ndim:
+ raise ValueError(
+ "The number of dimensions of reward must be the same as the number of dimensions of the KL "
+ f"term. Got ndim={reward.ndim} and {curr_log_prob.ndim} respectively."
+ )
if reward is None:
reward = 0
- next_tensordict.set(self.out_keys[0], reward - self.coef * kl)
+ reward = reward - self.coef * kl
+ next_tensordict.set(self.out_keys[0], reward)
next_tensordict.set(self.out_keys[1], kl)
next_tensordict.set(self.out_keys[2], ref_log_prob)
return next_tensordict
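
The reward shaping above relies on the single-sample KL estimator `log pi(x) - log pi_ref(x)` applied per token. A small self-contained sketch of that computation with made-up padded log-prob tensors (shapes and values are illustrative only):

```python
import torch

# Per-token log-probs under the current policy and the frozen reference model.
curr_log_prob = torch.tensor([[-0.2, -1.0, -0.5]])
ref_log_prob = torch.tensor([[-0.3, -0.8, -0.7]])
coef = 0.1

kl = curr_log_prob - ref_log_prob            # estimator of KL(pi || pi_ref), shape (batch, seq)
reward = torch.zeros(1, 3, 1)                # per-token reward with a trailing singleton dim
shaped = reward - coef * kl.unsqueeze(-1)    # the penalty is subtracted from the reward
print(kl)                                    # ≈ tensor([[ 0.1, -0.2,  0.2]])
print(shaped.squeeze(-1))                    # ≈ tensor([[-0.01,  0.02, -0.02]])
```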
@@ -282,36 +404,44 @@ def transform_output_spec(self, output_spec: Composite) -> Composite:
class RetrieveLogProb(Transform):
- """A transform to retrieve the log-probs of a text given a reference model.
+ """A transform to retrieve log-probabilities from a model for KL divergence computation.
+
+ This transform computes log-probabilities from a reference model, which can then be used
+ to compute KL divergence with another model's log-probabilities. It's designed to work
+ with the :class:`~torchrl.envs.llm.transforms.kl.RetrieveKL` and :class:`~torchrl.envs.llm.transforms.kl.KLComputation` transforms.
Args:
- actor (CategoricalSequential): the reference model.
+ model (LLMWrapperBase): the model to use to compute the log-probs.
Keyword Args:
- history_key (NestedKey): the key where the history is stored. Defaults to `"history"`.
- log_prob_key (NestedKey): the key where the log-probs are stored. Defaults to `"ref_log_prob"`.
+ log_probs_full_key (NestedKey): the key where the log-probs are stored.
+ If not provided, the key will be retrieved from the model's `log_probs_key` attribute
+ (i.e., `(model.log_probs_key, "full")`).
assistant_only (bool): whether to only retrieve the log-probs of the assistant tokens (i.e., steps of history
- where the role is `"assistant"`). Defaults to `False`.
+ where the role is `"assistant"`). Defaults to `True`.
- .. note:: The template must accommodate the `return_assistant_tokens_mask` keyword argument.
- This may not be the case for all templates. In this case, you can pass a custom template to the `apply_chat_template` method
- via the `tokenizer_kwargs` argument: `tokenizer_kwargs = {"chat_template_name": "qwen"}` or `tokenizer_kwargs = {"chat_template": my_template}.
+ .. note:: When `assistant_only=True`, the model must have `input_mode='history'` to properly identify
+ assistant tokens. For other input modes (`"text"` or `"tokens"`), set `assistant_only=False`.
+ This ensures users are conscious of the limitation that assistant token identification requires
+ structured conversation history.
tokenizer_kwargs (dict): the keyword arguments to pass to the tokenizer to be used to apply the chat template to the history when `assistant_only` is `True`.
- To control the tokenization in the actor, pass the tokenizer kwargs to the actor constructor.
- Defaults to `{"return_assistant_tokens_mask": True, "tokenize": True, "return_tensors": "pt", "padding": True, "add_generation_prompt": False}`.
- tokenizer (transformers.AutoTokenizer): the tokenizer to be used to tokenize the input and compute the assitant mask. If not provided, the tokenizer will be inferred from the `actor`.
+ To control the tokenization in the ref_model, pass the tokenizer kwargs to the ref_model constructor.
+ Defaults to `{"return_assistant_tokens_mask": True, "tokenize": True, "return_dict": True, "padding": False, "add_generation_prompt": False}`.
+ tokenizer (transformers.AutoTokenizer): the tokenizer to be used to tokenize the input and compute the assistant mask. If not provided, the tokenizer will be inferred from the `ref_model`.
detach (bool): whether to exclude the log-probs from the gradient computation. Defaults to `True`.
device (torch.device): the device to use for tensor creation. Defaults to `None`.
+ padding_side (str): the side of the padding when using pad_sequence. Defaults to `"left"`.
Examples:
- >>> from torchrl.data.llm.chat import History, _CHAT_TEMPLATES
+ >>> from torchrl.data.llm import History
>>> from torchrl.modules.llm import TransformersWrapper
- >>> from torchrl.objectives.llm.sft import SFTLoss
+ >>> from torchrl.modules.llm.policies import ChatHistory
>>> from transformers import AutoTokenizer, OPTConfig, OPTForCausalLM
- >>> from tensordict import TensorDict, lazy_stack, set_list_to_stack
+ >>> from tensordict import TensorDict, set_list_to_stack
>>> import torch
>>>
+ >>> # Set up list to stack for History
>>> set_list_to_stack(True).set()
>>>
>>> # Create chat data
@@ -334,174 +464,747 @@ class RetrieveLogProb(Transform):
>>> # Setup tokenizer and model
>>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
>>> tokenizer.pad_token = tokenizer.eos_token
- >>> tokenizer.chat_template = _CHAT_TEMPLATES["chatml_format"]
>>> model = OPTForCausalLM(OPTConfig()).eval()
>>>
- >>> # Create training and reference policies
- >>> policy_train = TransformersWrapper(
- ... model,
- ... tokenizer=tokenizer,
- ... generate=False,
- ... from_text=True,
- ... chat_template_name="qwen",
- ... )
- >>> policy_ref = TransformersWrapper(
+ >>> # Create reference model
+ >>> ref_model = TransformersWrapper(
... model,
... tokenizer=tokenizer,
+ ... input_mode="history",
... generate=False,
- ... from_text=True,
... return_log_probs=True,
- ... chat_template_name="qwen",
+ ... pad_output=True,
... )
>>>
>>> # Create the RetrieveLogProb transform
>>> transform = RetrieveLogProb(
- ... policy_ref,
+ ... ref_model,
... assistant_only=True,
- ... tokenizer_kwargs={"chat_template_name": "qwen"},
... tokenizer=tokenizer,
... )
>>>
- >>> # Prepare data
- >>> text = history[:, :-1].apply_chat_template(
- ... tokenizer=tokenizer, chat_template_name="qwen", add_generation_prompt=True
- ... )
- >>> text_response = history.apply_chat_template(
- ... tokenizer=tokenizer, chat_template_name="qwen", add_generation_prompt=False
- ... )
- >>> text_response = [
- ... txt[len(txt_start):] for txt, txt_start in zip(text_response, text)
- ... ]
- >>> td = TensorDict(
- ... text=text,
- ... text_response=text_response,
- ... history=history,
- ... next=TensorDict(
- ... reward=torch.randn(2, 1),
- ... done=torch.zeros(2, dtype=torch.bool),
- ... history=history,
- ... ),
- ... batch_size=(2,),
- ... )
- >>> data = lazy_stack(list(td.unbind(0)))
+ >>> # Prepare data using ChatHistory
+ >>> chat_history = ChatHistory(full=history)
+ >>> data = TensorDict(history=chat_history, batch_size=(2,))
>>>
>>> # Apply the transform to get reference log probabilities
- >>> data = transform(data)
- >>> # You can get a padded tensor for batching:
- >>> ref_log_probs = data.get(("next", "ref_log_prob"), as_padded_tensor=True)
- >>> print(f"Type: {type(ref_log_probs)}, Length: {len(ref_log_probs)}")
- Type: , Length: 2
- >>> print(f"Example shapes: {[x.shape for x in ref_log_probs]}")
- Example shapes: [torch.Size([35]), torch.Size([35])]
- >>> print(ref_log_probs.shape) # (batch, max_seq_len)
- torch.Size([2, 35])
- >>>
- >>> # Use with SFTLoss for KL regularization
- >>> loss = SFTLoss(
- ... actor_network=policy_train,
- ... tokenizer=tokenizer,
- ... reduction="mean",
- ... normalize_by_seq_length=True,
- ... kl_to_ref_coeff=0.1,
- ... tokenizer_kwargs={"chat_template_name": "qwen"},
- ... )
- >>> loss_vals = loss(data)
- >>> print(f"SFT Loss: {loss_vals.loss_sft.item():.4f}")
- SFT Loss: 10.7856
- >>> print(f"KL to Reference Loss: {loss_vals.loss_kl_to_ref.item():.4f}")
- KL to Reference Loss: 0.0000
- >>> print(f"Total Loss: {loss_vals.sum(reduce=True).item():.4f}")
- Total Loss: 10.7856
+ >>> result = transform(data)
+ >>> log_probs_key = (ref_model.log_probs_key, "full")
+ >>> ref_log_probs = result.get(log_probs_key)
+ >>> print(f"Log-probs shape: {ref_log_probs.shape}")
+ Log-probs shape: torch.Size([2, 26])
- Note:
+ .. note::
By default, the log-probabilities are stored as a list of tensors (one per sample, with variable length).
Use `as_padded_tensor=True` in `.get()` to obtain a batchable tensor (with padding).
The reference log probabilities are computed only for assistant tokens when `assistant_only=True`.
+ **Input Mode Compatibility:**
+ - When `assistant_only=True` (default), the model must have `input_mode='history'` to properly identify assistant tokens.
+ - When `assistant_only=False`, the transform works with any input mode (`"history"`, `"text"`, or `"tokens"`).
+ - This design ensures users are conscious of the limitation that assistant token identification requires structured conversation history.
+
+ .. seealso::
+ :class:`~torchrl.envs.llm.transforms.kl.RetrieveKL`: A higher-level transform that combines two `RetrieveLogProb` instances with `KLComputation`.
+ :class:`~torchrl.envs.llm.transforms.kl.KLComputation`: A transform that computes KL divergence between two log-prob tensors.
+ :class:`~torchrl.envs.llm.transforms.kl.KLRewardTransform`: A legacy transform for KL reward computation (use `RetrieveKL` instead).
"""
def __init__(
self,
- actor: CategoricalSequential,
+ model: LLMWrapperBase,
*,
- history_key: NestedKey | None = None,
- log_prob_key: NestedKey = "ref_log_prob",
- assistant_only: bool = False,
+ log_probs_full_key: NestedKey | None = None,
+ assistant_only: bool = True,
tokenizer_kwargs: dict | None = None,
detach: bool = True,
device: torch.device | None = None,
tokenizer: transformers.AutoTokenizer | None = None,
+ padding_side: str = "left",
):
- if history_key is None:
- history_key = "history"
- self.history_key = history_key
- self.log_prob_key = log_prob_key
- super().__init__(in_keys=[history_key], out_keys=[log_prob_key])
- self.actor = actor
- if not getattr(actor, "return_log_probs", True):
- raise ValueError(
- "The actor must have `return_log_probs=True` to use the `AssistantLogProb` transform."
- )
- if getattr(actor, "generate", True):
- raise ValueError(
- "The actor must have `generate=False` to use the `AssistantLogProb` transform."
- )
- if not getattr(actor, "from_text", False):
- raise ValueError(
- "The actor must have `from_text=True` to use the `AssistantLogProb` transform. If `from_text=False` is required, please file an issue on GitHub."
+ # Set up keys
+ if log_probs_full_key is None:
+ log_probs_full_key = (model.log_probs_key, "full")
+ elif (
+ not isinstance(log_probs_full_key, tuple)
+ or log_probs_full_key[-1] != "full"
+ ):
+ warnings.warn(
+ f"The log_probs_full_key {log_probs_full_key} is not a tuple or does not end with 'full'. "
+ "This may cause issues with the KL computation. "
+ "Please use a tuple with the log_probs_key and 'full' as the last element."
)
- # if getattr(self.actor, "tokenizer_kwargs", {}).get("add_generation_prompt", True):
- # raise ValueError("The actor must have `tokenizer_kwargs['add_generation_prompt']=False` to use the `AssistantLogProb` transform.")
+ self.log_probs_full_key = log_probs_full_key
+
+ # Set up input/output keys
+ in_keys = list(model.in_keys)
+ out_keys = [self.log_probs_full_key]
+ super().__init__(in_keys=in_keys, out_keys=out_keys)
+
+ # Store model and configuration
+ self.model = model
self.assistant_only = assistant_only
+ self.detach = detach
+ self.device = device
+ self.tokenizer = tokenizer
+ self.padding_side = padding_side
+
+ # Set up tokenizer kwargs
if tokenizer_kwargs is None:
tokenizer_kwargs = {}
tokenizer_kwargs.setdefault("return_assistant_tokens_mask", True)
tokenizer_kwargs.setdefault("tokenize", True)
- tokenizer_kwargs.setdefault("return_tensors", "pt")
+ tokenizer_kwargs.setdefault("return_dict", True)
tokenizer_kwargs.setdefault("padding", False)
tokenizer_kwargs.setdefault("add_generation_prompt", False)
self.tokenizer_kwargs = tokenizer_kwargs
- self.tokenizer = tokenizer
- self.detach = detach
- self.device = device
+
+ # Validate model configuration (after setting assistant_only)
+ self._validate_model_config(model)
+
+ def _validate_model_config(self, model: LLMWrapperBase):
+ """Validate model configuration."""
+ if not getattr(model, "return_log_probs", True):
+ raise ValueError(
+ "The model must have `return_log_probs=True` to use the `RetrieveLogProb` transform."
+ )
+ if getattr(model, "generate", True):
+ raise ValueError(
+ "The model must have `generate=False` to use the `RetrieveLogProb` transform."
+ )
+
+ # Check input mode compatibility with assistant_only
+ input_mode = getattr(model, "input_mode", "history")
+ if self.assistant_only and input_mode != "history":
+ raise ValueError(
+ f"The model must have `input_mode='history'` when `assistant_only=True`. "
+ f"Current input_mode is '{input_mode}'. "
+ f"To use input_mode '{input_mode}', set `assistant_only=False`."
+ )
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
- next_td = self._step(tensordict, tensordict.get("next"))
- return tensordict.set("next", next_td)
+ next_td = tensordict.get("next")
+ next_is_none = False
+ if next_td is None:
+ next_is_none = True
+ next_td = tensordict
+ output = self._step(tensordict, next_td)
+ if next_is_none:
+ return output
+ return tensordict.set("next", output)
+
+ def _mask_assistant_tokens(
+ self, td: TensorDictBase, lp_key: NestedKey
+ ) -> torch.Tensor:
+ """Mask log-probs to only include assistant tokens.
+
+ Args:
+ td: TensorDict containing the data
+ lp_key: Key for log-probs in the TensorDict
+
+ Returns:
+ Masked log-probs tensor
+ """
+ with torch.device(self.device) if self.device is not None else nullcontext():
+ # Get assistant mask
+ assistant_masks = td.get(("masks", "all_assistant_mask"), as_list=True)
+ log_probs = td.get(lp_key, as_list=True)
+ log_probs = [
+ lp[mask.bool()] for lp, mask in _zip_strict(log_probs, assistant_masks)
+ ]
+ if self.model.pad_output:
+ log_probs = pad_sequence(
+ log_probs,
+ batch_first=True,
+ padding_value=0.0,
+ padding_side=self.padding_side,
+ )
+ else:
+ log_probs = torch.nested.as_nested_tensor(
+ log_probs, layout=self.model.layout
+ )
+ return log_probs
@set_list_to_stack(True)
def _step(
self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
) -> TensorDictBase:
- td = next_tensordict.select(self.history_key)
- with torch.device(
- self.device
- ) if self.device is not None else contextlib.nullcontext(), torch.no_grad() if self.detach else contextlib.nullcontext():
- result = self.actor(td.select(self.history_key))
- td.update(result.select(getattr(self.actor, "log_prob_key", "log_probs")))
- td.rename_key_(
- getattr(self.actor, "log_prob_key", "log_probs"), self.log_prob_key
- )
- if torch.cuda.is_available():
- gc.collect()
- torch.cuda.empty_cache()
+ # Compute log-probs using the model
+ # Use tensordict since we want to process the "full" entry
+ ref_td = self.model(tensordict.copy())
+ tmp_log_probs_key = (self.model.log_probs_key, "full")
+
+ # Apply assistant masking if requested
if self.assistant_only:
- with torch.device(
- self.device
- ) if self.device is not None else contextlib.nullcontext():
- # Get assistant mask
- history: History = td.get(self.history_key)
- proc = history.apply_chat_template(
- tokenizer=self.actor.tokenizer
- if self.tokenizer is None
- else self.tokenizer,
- **self.tokenizer_kwargs,
+ log_probs = self._mask_assistant_tokens(ref_td, tmp_log_probs_key)
+ ref_td.set(tmp_log_probs_key, log_probs)
+
+ # Rename and store the log-probs
+ if tmp_log_probs_key != self.log_probs_full_key:
+ ref_td.rename_key_(tmp_log_probs_key, self.log_probs_full_key)
+ next_tensordict.update(ref_td, keys_to_update=(self.log_probs_full_key,))
+
+ return next_tensordict
+
+ def transform_observation_spec(self, observation_spec: Composite) -> Composite:
+ # Register the log-probs entry written by this transform
+ observation_spec[self.log_probs_full_key] = Unbounded(
+ device=observation_spec.device,
+ shape=observation_spec.shape,
+ )
+ return observation_spec
+
+
+class RetrieveKL(Compose):
+ """A transform to retrieve the KL divergence between two models' log-probabilities.
+
+ This transform combines two :class:`~torchrl.envs.llm.transforms.kl.RetrieveLogProb` instances
+ with a :class:`~torchrl.envs.llm.transforms.kl.KLComputation` to compute KL divergence
+ between a generation model and a reference model.
+
+ .. note::
+ Both gen_model and ref_model must use the same pad_output value (True or False), otherwise KL computation will fail.
+
+ Args:
+ gen_model (LLMWrapperBase): the generation model, wrapped in such a way that it does not generate but computes the log-probs.
+ In cases where the transform is used within a :class:`~torchrl.collectors.llm.LLMCollector` run on a remote worker, the
+ policy may not be available ahead of time. In this case, the `gen_model` can be set to `"from_collector"` (default) to retrieve the
+ policy from the collector. See :meth:`~torchrl.modules.llm.policies.LLMWrapperBase.get_new_version` for more details
+ about generating a new version of the policy to gather the log-probs.
+ ref_model (LLMWrapperBase): the reference model, wrapped in such a way that it does not generate but computes the log-probs.
+
+ Keyword Args:
+ assistant_only (bool): whether to only retrieve the log-probs of the assistant tokens (i.e., steps of history
+ where the role is `"assistant"`). Defaults to `True`.
+
+ .. note:: When `assistant_only=True`, both models must have `input_mode='history'` to properly identify assistant tokens.
+ For other input modes (`"text"` or `"tokens"`), set `assistant_only=False`.
+ This ensures users are conscious of the limitation that assistant token identification requires structured conversation history.
+
+ gen_log_probs_full_key (NestedKey): the key where the log-probs of the generation model are stored. Defaults to `("log_probs", "full")`.
+ ref_log_probs_full_key (NestedKey): the key where the log-probs of the reference model are stored. Defaults to `("ref_log_probs", "full")`.
+ history_key (str): the key where the history is stored. Defaults to `"history"`.
+ tokenizer_kwargs (dict): the keyword arguments to pass to the tokenizer to be used to apply the chat template to the history when `assistant_only` is `True`.
+ To control the tokenization in the models, pass the tokenizer kwargs to the model constructors.
+ Defaults to `{"return_assistant_tokens_mask": True, "tokenize": True, "return_dict": True, "padding": False, "add_generation_prompt": False}`.
+ detach (bool): whether to exclude the log-probs from the gradient computation. Defaults to `True`.
+ device (torch.device): the device to use for tensor creation. Defaults to `None`.
+ tokenizer (transformers.AutoTokenizer): the tokenizer to be used to tokenize the input and compute the assistant mask. If not provided, the tokenizer will be inferred from the models.
+ padding_side (str): the side of the padding when using pad_sequence. Defaults to `"left"`.
+ kl_key (NestedKey): the key where the KL divergence is stored. Defaults to `"kl_penalty"`.
+ add_to_reward (bool): whether to add the KL divergence to the reward. Defaults to `True`.
+ coeff (float): the coefficient for the KL term when adding to reward. Defaults to `1.0`.
+ **kwargs: additional arguments to pass to the `RetrieveLogProb` transform.
+
+ Examples:
+ >>> from torchrl.data.llm import History
+ >>> from torchrl.modules.llm import TransformersWrapper
+ >>> from torchrl.modules.llm.policies import ChatHistory
+ >>> from transformers import AutoTokenizer, OPTConfig, OPTForCausalLM
+ >>> from tensordict import TensorDict, set_list_to_stack
+ >>> import torch
+ >>>
+ >>> # Set up list to stack for History
+ >>> set_list_to_stack(True).set()
+ >>>
+ >>> # Create chat data
+ >>> chats = [
+ ... [
+ ... {"role": "system", "content": "You are a helpful assistant."},
+ ... {"role": "user", "content": "Hello, how are you?"},
+ ... {"role": "assistant", "content": "I'm doing well, thank you!"},
+ ... ],
+ ... [
+ ... {"role": "system", "content": "You are a helpful assistant."},
+ ... {"role": "user", "content": "What's the weather like?"},
+ ... {"role": "assistant", "content": "I can't check the weather for you."},
+ ... ],
+ ... ]
+ >>> history = History.from_chats(chats)
+ >>> print(f"Created history with shape: {history.shape}")
+ Created history with shape: torch.Size([2, 3])
+ >>>
+ >>> # Setup tokenizer and model
+ >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
+ >>> tokenizer.pad_token = tokenizer.eos_token
+ >>> model = OPTForCausalLM(OPTConfig()).eval()
+ >>>
+ >>> # Create generation and reference models
+ >>> gen_model = TransformersWrapper(
+ ... model,
+ ... tokenizer=tokenizer,
+ ... input_mode="history",
+ ... generate=False,
+ ... return_log_probs=True,
+ ... pad_output=True,
+ ... log_probs_key="gen_log_probs",
+ ... )
+ >>> ref_model = TransformersWrapper(
+ ... model,
+ ... tokenizer=tokenizer,
+ ... input_mode="history",
+ ... generate=False,
+ ... return_log_probs=True,
+ ... pad_output=True,
+ ... log_probs_key="ref_log_probs",
+ ... )
+ >>>
+ >>> # Create RetrieveKL transform
+ >>> transform = RetrieveKL(
+ ... gen_model=gen_model,
+ ... ref_model=ref_model,
+ ... assistant_only=True,
+ ... tokenizer=tokenizer,
+ ... )
+ >>>
+ >>> # Prepare data with next tensordict using ChatHistory
+ >>> chat_history = ChatHistory(full=history)
+ >>> next_td = TensorDict(history=chat_history, batch_size=(2,))
+ >>> data = TensorDict(history=chat_history, next=next_td, batch_size=(2,))
+ >>>
+ >>> # Apply transform
+ >>> result = transform(data)
+ >>> kl = result["next"].get("kl_penalty")
+ >>> print(f"KL shape: {kl.shape}")
+ KL shape: torch.Size([2, 26])
+
+ Note:
+ **Input Mode Compatibility:**
+ - When `assistant_only=True`, both models must have `input_mode='history'` to properly identify assistant tokens.
+ - When `assistant_only=False`, the transform works with any input mode (`"history"`, `"text"`, or `"tokens"`).
+ - This design ensures users are conscious of the limitation that assistant token identification requires structured conversation history.
+
+ .. seealso::
+ :class:`~torchrl.envs.llm.transforms.kl.RetrieveLogProb`: The base transform for retrieving log-probabilities from a single model.
+ :class:`~torchrl.envs.llm.transforms.kl.KLComputation`: The transform that computes KL divergence between two log-prob tensors.
+ :class:`~torchrl.envs.llm.transforms.kl.KLRewardTransform`: A legacy transform for KL reward computation (use `RetrieveKL` instead).
+ """
+
+ def __init__(
+ self,
+ gen_model: LLMWrapperBase | Literal["from_collector"] = "from_collector",
+ ref_model: LLMWrapperBase | None = None,
+ *,
+ assistant_only: bool | None = True,
+ history_key: str = "history",
+ tokenizer_kwargs: dict[str, Any] | None = None,
+ detach: bool = True,
+ device: torch.device | None = None,
+ tokenizer: transformers.AutoTokenizer | None = None,
+ padding_side: str = "left",
+ gen_log_probs_full_key: NestedKey = ("log_probs", "full"),
+ ref_log_probs_full_key: NestedKey = ("ref_log_probs", "full"),
+ kl_key: NestedKey = "kl_penalty",
+ add_to_reward: bool = True,
+ coeff: float = 1.0,
+ **kwargs,
+ ):
+ if isinstance(gen_model, str) and gen_model == "from_collector":
+ # Lazy init
+ self._initialized = False
+ self._init_params = {
+ "ref_model": ref_model,
+ "assistant_only": assistant_only,
+ "history_key": history_key,
+ "tokenizer_kwargs": tokenizer_kwargs,
+ "detach": detach,
+ "device": device,
+ "tokenizer": tokenizer,
+ "gen_log_probs_full_key": gen_log_probs_full_key,
+ "ref_log_probs_full_key": ref_log_probs_full_key,
+ "kl_key": kl_key,
+ "add_to_reward": add_to_reward,
+ "coeff": coeff,
+ "padding_side": padding_side,
+ **kwargs,
+ }
+ super().__init__()
+ return
+
+ self._initialized = True
+
+ # Check pad_output consistency if both models are provided
+ if hasattr(gen_model, "pad_output") and hasattr(ref_model, "pad_output"):
+ if gen_model.pad_output != ref_model.pad_output:
+ raise ValueError(
+ f"pad_output mismatch: gen_model.pad_output={gen_model.pad_output}, "
+ f"ref_model.pad_output={ref_model.pad_output}. "
+ "Both models must use the same padding strategy for KL computation."
)
- assistant_masks = proc.get("assistant_masks", as_list=True)
- log_probs = td.get(self.log_prob_key, as_list=True)
- log_probs = [
- lp[mask.bool()]
- for lp, mask in _zip_strict(log_probs, assistant_masks)
- ]
- td = td.set(self.log_prob_key, log_probs)
- return next_tensordict.update(td)
+
+ if not getattr(gen_model, "return_log_probs", True):
+ raise ValueError(
+ "The generation model must have `return_log_probs=True` to use the `RetrieveKL` transform."
+ )
+ elif getattr(gen_model, "generate", False):
+ raise ValueError(
+ "The generation model must have `generate=False` to use the `RetrieveKL` transform."
+ )
+
+ if not getattr(ref_model, "return_log_probs", True):
+ raise ValueError(
+ "The reference model must have `return_log_probs=True` to use the `RetrieveKL` transform."
+ )
+ elif getattr(ref_model, "generate", False):
+ raise ValueError(
+ "The reference model must have `generate=False` to use the `RetrieveKL` transform."
+ )
+ if getattr(gen_model, "log_probs_key", "gen_log_probs") == getattr(
+ ref_model, "log_probs_key", "log_probs"
+ ):
+ raise ValueError(
+ "The generation and reference models must have different `log_prob_key` values to use the `RetrieveKL` transform."
+ )
+ t1 = RetrieveLogProb(
+ gen_model,
+ log_probs_full_key=gen_log_probs_full_key,
+ assistant_only=assistant_only,
+ tokenizer_kwargs=tokenizer_kwargs,
+ detach=detach,
+ device=device,
+ tokenizer=tokenizer,
+ padding_side=padding_side,
+ **kwargs,
+ )
+ t2 = RetrieveLogProb(
+ ref_model,
+ log_probs_full_key=ref_log_probs_full_key,
+ assistant_only=assistant_only,
+ tokenizer_kwargs=tokenizer_kwargs,
+ detach=detach,
+ device=device,
+ tokenizer=tokenizer,
+ padding_side=padding_side,
+ **kwargs,
+ )
+ t3 = KLComputation(
+ gen_log_probs_full_key=gen_log_probs_full_key,
+ ref_log_probs_full_key=ref_log_probs_full_key,
+ kl_key=kl_key,
+ add_to_reward=add_to_reward,
+ coeff=coeff,
+ )
+ super().__init__(t1, t2, t3)
+
+ def _init_deferred(self):
+ torchrl_logger.info("Initializing RetrieveKL transform")
+ container = self.container
+ if container is None:
+ # also logging, since this will be sometimes hidden within the AttributeError
+ torchrl_logger.warning(
+ "The container is not set. Please set the container before calling this method."
+ )
+ raise ValueError(
+ "The container is not set. Please set the container before calling this method."
+ )
+ container.empty_cache()
+ self.empty_cache()
+ collector = self.collector
+ if collector is None:
+ # also logging, since this will be sometimes hidden within the AttributeError
+ torchrl_logger.warning(
+ "The collector is not set. Please set the collector before calling this method."
+ )
+ raise ValueError(
+ "The collector is not set. Please set the collector before calling this method."
+ )
+ ref_model = self._init_params["ref_model"]
+ pad_output = getattr(ref_model, "pad_output", None)
+ gen_log_probs_full_key = self._init_params["gen_log_probs_full_key"]
+ if (
+ not isinstance(gen_log_probs_full_key, tuple)
+ or gen_log_probs_full_key[-1] != "full"
+ ):
+ raise ValueError(
+ f"The gen_log_probs_full_key {gen_log_probs_full_key} is not a tuple or does not end with 'full'. "
+ "This may cause issues with the KL computation. "
+ "Please use a tuple with the log_probs_key and 'full' as the last element."
+ )
+ log_probs_key = gen_log_probs_full_key[:-1]
+ gen_model = collector.policy.get_new_version(
+ generate=False,
+ return_log_probs=True,
+ log_probs_key=log_probs_key,
+ input_mode=ref_model.input_mode,
+ input_key=(ref_model.input_mode, "full"),
+ pad_output=pad_output, # Pass pad_output from ref_model
+ )
+ # Create the transforms manually instead of calling __init__
+ t1 = RetrieveLogProb(
+ gen_model,
+ log_probs_full_key=gen_log_probs_full_key,
+ assistant_only=self._init_params["assistant_only"],
+ tokenizer_kwargs=self._init_params["tokenizer_kwargs"],
+ detach=self._init_params["detach"],
+ device=self._init_params["device"],
+ tokenizer=self._init_params["tokenizer"],
+ padding_side=self._init_params["padding_side"],
+ )
+ ref_log_probs_full_key = self._init_params["ref_log_probs_full_key"]
+ if (
+ not isinstance(ref_log_probs_full_key, tuple)
+ or ref_log_probs_full_key[-1] != "full"
+ ):
+ raise ValueError(
+ f"The ref_log_probs_full_key {ref_log_probs_full_key} is not a tuple or does not end with 'full'. "
+ "This may cause issues with the KL computation. "
+ "Please use a tuple with the log_probs_key and 'full' as the last element."
+ )
+ t2 = RetrieveLogProb(
+ ref_model,
+ log_probs_full_key=ref_log_probs_full_key,
+ assistant_only=self._init_params["assistant_only"],
+ tokenizer_kwargs=self._init_params["tokenizer_kwargs"],
+ detach=self._init_params["detach"],
+ device=self._init_params["device"],
+ tokenizer=self._init_params["tokenizer"],
+ padding_side=self._init_params["padding_side"],
+ )
+ t3 = KLComputation(
+ gen_log_probs_full_key=gen_log_probs_full_key,
+ ref_log_probs_full_key=ref_log_probs_full_key,
+ kl_key=self._init_params["kl_key"],
+ add_to_reward=self._init_params["add_to_reward"],
+ coeff=self._init_params["coeff"],
+ )
+ # Populate the (initially empty) Compose with the three transforms
+ self.transforms.extend([t1, t2, t3])
+ del self._init_params
+ self._initialized = True
+ torchrl_logger.info("Successfully initialized")
+
+ def _step(
+ self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
+ ) -> TensorDictBase:
+ if not self._initialized:
+ self._init_deferred()
+ return super()._step(tensordict, next_tensordict)
+
+ def _reset(
+ self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
+ ) -> TensorDictBase:
+ if not self._initialized:
+ self._init_deferred()
+ return super()._reset(tensordict, tensordict_reset)
+
+ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
+ if not self._initialized:
+ self._init_deferred()
+ return super().forward(tensordict)
+
+ def transform_observation_spec(self, observation_spec: Composite) -> Composite:
+ if not self._initialized:
+ self._init_deferred()
+ return super().transform_observation_spec(observation_spec)
+
+ def transform_reward_spec(self, reward_spec: Composite) -> Composite:
+ if not self._initialized:
+ self._init_deferred()
+ return super().transform_reward_spec(reward_spec)
+
+ def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
+ if not self._initialized:
+ self._init_deferred()
+ return super()._inv_call(tensordict)
+
+ def transform_action_spec(self, action_spec: Composite) -> Composite:
+ if not self._initialized:
+ self._init_deferred()
+ return super().transform_action_spec(action_spec)
+
+ def transform_input_spec(self, input_spec: Composite) -> Composite:
+ if not self._initialized:
+ self._init_deferred()
+ return super().transform_input_spec(input_spec)
+
+ def transform_output_spec(self, output_spec: Composite) -> Composite:
+ if not self._initialized:
+ self._init_deferred()
+ return super().transform_output_spec(output_spec)
+
+ def transform_state_spec(self, state_spec: Composite) -> Composite:
+ if not self._initialized:
+ self._init_deferred()
+ return super().transform_state_spec(state_spec)
+
+
+class KLComputation(Transform):
+ """A transform to compute KL divergence between two log-prob tensors and optionally add it to the reward.
+
+ This transform computes KL divergence between generation and reference log-probabilities
+ and can optionally subtract it from the reward (for KL penalty). It's designed to work
+ with the :class:`~torchrl.envs.llm.transforms.kl.RetrieveLogProb` and :class:`~torchrl.envs.llm.transforms.kl.RetrieveKL` transforms.
+
+ .. note::
+ Both input log-prob tensors must use the same padding strategy (pad_output) for correct KL computation.
+
+ Args:
+ gen_log_probs_full_key (NestedKey): the key where the generation model log-probs are stored.
+ Defaults to `("log_probs", "full")`.
+ ref_log_probs_full_key (NestedKey): the key where the reference model log-probs are stored.
+ Defaults to `("ref_log_probs", "full")`.
+ kl_key (NestedKey): the key where the KL divergence is stored. Defaults to `"kl_penalty"`.
+ add_to_reward (bool): whether to add the KL divergence to the reward. Defaults to `True`.
+ coeff (float): the coefficient for the KL term when adding to reward. Defaults to `1.0`.
+ padding_side (str): the side of the padding when using pad_sequence. Defaults to `"left"`.
+
+ Examples:
+ >>> from tensordict import TensorDict
+ >>> import torch
+ >>>
+ >>> # Create sample log-probs
+ >>> gen_log_probs = torch.randn(2, 10) # 2 samples, 10 tokens each
+ >>> ref_log_probs = torch.randn(2, 10)
+ >>>
+ >>> # Create data with next tensordict
+ >>> next_td = TensorDict(
+ ... {
+ ... ("gen_log_probs", "full"): gen_log_probs,
+ ... ("ref_log_probs", "full"): ref_log_probs,
+ ... "reward": torch.randn(2, 10, 1),
+ ... },
+ ... batch_size=(2,)
+ ... )
+ >>> data = TensorDict(next=next_td, batch_size=(2,))
+ >>>
+ >>> # Create KLComputation transform
+ >>> kl_transform = KLComputation(
+ ... gen_log_probs_key=("gen_log_probs", "full"),
+ ... ref_log_probs_key=("ref_log_probs", "full"),
+ ... kl_key="kl_penalty",
+ ... add_to_reward=True,
+ ... coef=1.0,
+ ... )
+ >>>
+ >>> # Apply transform
+ >>> result = kl_transform(data)
+ >>> kl = result["next"].get("kl_penalty")
+ >>> print(f"KL shape: {kl.shape}")
+ KL shape: torch.Size([2, 10])
+
+ .. seealso::
+ :class:`~torchrl.envs.llm.transforms.kl.RetrieveLogProb`: The base transform for retrieving log-probabilities from a single model.
+ :class:`~torchrl.envs.llm.transforms.kl.RetrieveKL`: A higher-level transform that combines two `RetrieveLogProb` instances with `KLComputation`.
+ :class:`~torchrl.envs.llm.transforms.kl.KLRewardTransform`: A legacy transform for KL reward computation (use `RetrieveKL` instead).
+
+ """
+
+ def __init__(
+ self,
+ gen_log_probs_full_key: NestedKey = ("log_probs", "full"),
+ ref_log_probs_full_key: NestedKey = ("ref_log_probs", "full"),
+ *,
+ kl_key: NestedKey = "kl_penalty",
+ add_to_reward: bool = True,
+ coeff: float = 1.0,
+ padding_side: str = "left",
+ ):
+ in_keys = [gen_log_probs_full_key, ref_log_probs_full_key]
+ if add_to_reward:
+ in_keys.append("reward")
+ out_keys = [kl_key]
+ if add_to_reward:
+ out_keys.append("reward")
+ super().__init__(in_keys=in_keys, out_keys=out_keys)
+
+ self.gen_log_probs_full_key = gen_log_probs_full_key
+ self.ref_log_probs_full_key = ref_log_probs_full_key
+ self.kl_key = kl_key
+ self.add_to_reward = add_to_reward
+ self.coeff = coeff
+ self.padding_side = padding_side
+
+ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
+ next_td = tensordict.get("next")
+ has_next_td = True
+ if next_td is None:
+ next_td = tensordict
+ has_next_td = False
+ next_td = self._step(tensordict, next_td)
+ if has_next_td:
+ return tensordict.set("next", next_td)
+ return next_td
+
+ def _step(
+ self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
+ ) -> TensorDictBase:
+ # Get log-probs
+ gen_log_probs = next_tensordict.get(self.gen_log_probs_full_key, as_list=True)
+ ref_log_probs = next_tensordict.get(self.ref_log_probs_full_key, as_list=True)
+
+ if gen_log_probs is None or ref_log_probs is None:
+ raise ValueError(
+ f"Log-probs not found. Expected keys: {self.gen_log_probs_key}, {self.ref_log_probs_key}"
+ )
+
+ # Debug: Check lengths and shapes
+ if len(gen_log_probs) != len(ref_log_probs):
+ raise ValueError(
+ f"Batch size mismatch: gen_log_probs has {len(gen_log_probs)} samples, ref_log_probs has {len(ref_log_probs)} samples"
+ )
+
+ # Check individual sequence lengths
+ for i, (gen_lp, ref_lp) in enumerate(_zip_strict(gen_log_probs, ref_log_probs)):
+ if gen_lp.shape != ref_lp.shape:
+ raise ValueError(
+ f"Sample {i} has different shapes: gen_log_probs[{i}].shape={gen_lp.shape}, ref_log_probs[{i}].shape={ref_lp.shape}"
+ )
+
+ # Compute KL divergence: KL(p||q) = E_p[log p - log q]
+ # Here gen_log_probs = log p, ref_log_probs = log q
+ kl = [
+ gen_lp - ref_lp
+ for gen_lp, ref_lp in _zip_strict(gen_log_probs, ref_log_probs)
+ ]
+
+ kl = torch.nested.as_nested_tensor(kl, layout=torch.strided)
+
+ next_tensordict.set(self.kl_key, kl)
+
+ # Add to reward if requested
+ if self.add_to_reward:
+ reward = next_tensordict.get("reward", as_list=True)
+ if reward is not None:
+ if isinstance(reward, list):
+ if reward[0].ndim != kl[0].ndim + 1:
+ raise ValueError(
+ f"The rewards have shape {reward[0].shape} but the kl has shape {kl[0].shape}. "
+ f"The rewards should have one more dimension than the KL."
+ )
+ reward = [
+ r - self.coeff * k.unsqueeze(-1)
+ for r, k in _zip_strict(reward, kl)
+ ]
+ next_tensordict.set(
+ "reward",
+ torch.nested.as_nested_tensor(reward, layout=torch.strided),
+ )
+ else:
+ if reward.ndim != kl.ndim + 1:
+ raise ValueError(
+ f"The rewards have shape {reward.shape} but the kl has shape {kl.shape}. "
+ f"The rewards should have one more dimension than the KL."
+ )
+ reward = reward - self.coeff * kl.unsqueeze(-1)
+ next_tensordict.set("reward", reward)
+
+ return next_tensordict
+
+ def transform_observation_spec(self, observation_spec: Composite) -> Composite:
+ # Add kl to observation spec
+ observation_spec[self.kl_key] = Unbounded(
+ device=observation_spec.device,
+ shape=observation_spec.shape,
+ )
+ return observation_spec
+
+ def transform_reward_spec(self, reward_spec: Composite) -> Composite:
+ # Optionally adjust reward spec if KL is added to reward
+ if self.add_to_reward:
+ shape = reward_spec["reward"].shape
+ # For LLMs, the shape of the reward is (batch, -1, 1)
+ shape = (*shape, -1, 1)
+ reward_spec["reward"] = reward_spec["reward"].clone()
+ reward_spec["reward"].shape = torch.Size(shape)
+ return reward_spec
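
To summarize how the three new transforms fit together: two `RetrieveLogProb` instances write the generation and reference log-probs to distinct keys, and `KLComputation` turns them into a penalty. A hedged sketch of that composition (it assumes an existing `env`, plus `gen_model`/`ref_model` wrapped with `generate=False` and `return_log_probs=True`, as in the docstring examples above):

```python
from torchrl.envs.llm.transforms import KLComputation, RetrieveLogProb

env = env.append_transform(
    RetrieveLogProb(gen_model, log_probs_full_key=("log_probs", "full"))
)
env = env.append_transform(
    RetrieveLogProb(ref_model, log_probs_full_key=("ref_log_probs", "full"))
)
env = env.append_transform(
    KLComputation(
        gen_log_probs_full_key=("log_probs", "full"),
        ref_log_probs_full_key=("ref_log_probs", "full"),
        kl_key="kl_penalty",
        add_to_reward=True,
        coeff=0.1,
    )
)
# RetrieveKL(gen_model, ref_model, coeff=0.1) builds the equivalent Compose in one call.
```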
diff --git a/torchrl/envs/llm/transforms/policy_version.py b/torchrl/envs/llm/transforms/policy_version.py
index 711326be410..493b630780c 100644
--- a/torchrl/envs/llm/transforms/policy_version.py
+++ b/torchrl/envs/llm/transforms/policy_version.py
@@ -178,10 +178,12 @@ def transform_observation_spec(self, spec: Composite) -> Composite:
"""
if self.version_type in (str, "uuid"):
spec["policy_version"] = NonTensor(
- example_data=uuid.uuid4(), shape=spec.shape
+ example_data=uuid.uuid4(), shape=spec.shape, device=spec.device
)
elif self.version_type in (int, "int"):
- spec["policy_version"] = Unbounded(shape=spec.shape, dtype=torch.int64)
+ spec["policy_version"] = Unbounded(
+ shape=spec.shape, dtype=torch.int64, device=spec.device
+ )
else:
raise ValueError(f"Invalid version type: {self.version_type}")
return spec
diff --git a/torchrl/envs/llm/transforms/reason.py b/torchrl/envs/llm/transforms/reason.py
index 6890d45b80e..fad26cfa689 100644
--- a/torchrl/envs/llm/transforms/reason.py
+++ b/torchrl/envs/llm/transforms/reason.py
@@ -9,8 +9,9 @@
from typing import Callable, Literal
from tensordict import lazy_stack, TensorDictBase
+from torchrl._utils import logger as torchrl_logger
-from torchrl.data.llm.chat import History
+from torchrl.data.llm.history import History
from torchrl.envs import Transform
from torchrl.envs.common import EnvBase
@@ -161,12 +162,24 @@ def _step(
next_tensordict.update(lazy_stack(ntds))
return next_tensordict
+ # Check that base_env is on history mode
+ parent = self.parent
+ if parent is None:
+ raise RuntimeError("AddThinkingPrompt must be used with a ChatEnv")
+ base_env = parent.base_env
+ if base_env.input_mode != "history":
+ raise RuntimeError(
+ "AddThinkingPrompt must be used with a ChatEnv in history mode"
+ )
+
# Check if we should add the thinking prompt
if self.cond(next_tensordict):
- history: History = next_tensordict["history"]
+ torchrl_logger.info("Adding thinking prompt.")
+ history: History = next_tensordict["history"].prompt
last_turn = history[..., -1]
if self.edit_last_turn:
+
# Edit the last assistant response
content = last_turn.content
modified_content = self._replace_answer_with_prompt(content)
@@ -181,14 +194,14 @@ def _step(
# Replace the last turn in history
history = history[..., :-1].append(new_turn)
- next_tensordict["history"] = history
+ next_tensordict["history"].prompt = history
else:
# Add a new message
prompt = self.prompt
history = history.append(History(role=self.role, content=prompt))
- next_tensordict["history"] = history
+ next_tensordict["history"].prompt = history
if self.undo_done:
parent: EnvBase = self.parent
@@ -208,38 +221,66 @@ def _step(
reward = next_tensordict.get(key)
if reward is not None:
next_tensordict.set(key, reward.zero_())
+ else:
+ torchrl_logger.info("Not adding thinking prompt.")
return next_tensordict
def _replace_answer_with_prompt(self, content: str) -> str:
- """Replace the answer section with a thinking prompt.
+ """Replace the last answer section with a thinking prompt.
- This method uses regex to find and replace the <answer>...</answer> section
+ This method uses regex to find and replace the last <answer>...</answer> section
with the thinking prompt, preserving any content before the answer tag.
+ Only the last answer block is replaced to avoid interfering with earlier
+ examples or instructions that might contain answer tags.
Args:
content: The original content string
Returns:
- The modified content with the answer replaced by the thinking prompt
+ The modified content with the last answer replaced by the thinking prompt
"""
# Pattern to match <answer>...</answer> with optional EOS token
+ # Use non-greedy matching and be more specific about the end
answer_pattern = r"<answer>.*?</answer>(?:\s*<\|im_end\|>)?"
# Check if there's an answer tag
if "<answer>" in content:
- # Replace the answer section with the thinking prompt
- prompt = self.prompt
+ # Find all matches to get the last one
+ matches = list(re.finditer(answer_pattern, content, flags=re.DOTALL))
- # Replace the answer section
- modified_content = re.sub(answer_pattern, prompt, content, flags=re.DOTALL)
+ if matches:
+ # Get the last match
+ last_match = matches[-1]
+ start, end = last_match.span()
- # Clean up any trailing whitespace
- modified_content = modified_content.rstrip()
+ # Replace only the last answer section with the thinking prompt
+ prompt = self.prompt
+ modified_content = content[:start] + prompt + content[end:]
+
+ # Clean up any trailing whitespace
+ modified_content = modified_content.rstrip()
+
+ # Ensure we end with the EOS token if the original content had it
+ if content.endswith("<|im_end|>"):
+ modified_content = modified_content.rstrip() + "<|im_end|>"
+
+ # Ensure proper spacing around the prompt
+ if not modified_content.endswith(prompt):
+ # If the prompt wasn't properly inserted, append it
+ modified_content = content.rstrip()
+ if modified_content.endswith("<|im_end|>"):
+ modified_content = modified_content[
+ : -len("<|im_end|>")
+ ].rstrip()
+ modified_content = modified_content + "\n\n" + prompt + "<|im_end|>"
+ else:
+ # No matches found, just append the prompt
+ prompt = self.prompt
+ modified_content = content.rstrip() + "\n\n" + prompt
else:
# No answer tag found, just append the prompt
prompt = self.prompt
-
modified_content = content.rstrip() + "\n\n" + prompt
return modified_content
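A minimal, self-contained sketch of the replace-only-the-last-answer logic above (the prompt string is a stand-in for `self.prompt`):

```python
import re

THINK_PROMPT = "Wait, let me reconsider."   # stand-in for self.prompt
answer_pattern = r"<answer>.*?</answer>(?:\s*<\|im_end\|>)?"

content = "<answer>4</answer> reasoning... <answer>42</answer><|im_end|>"
matches = list(re.finditer(answer_pattern, content, flags=re.DOTALL))
if matches:
    # Replace only the last answer block; keep everything before it.
    start, end = matches[-1].span()
    content = (content[:start] + THINK_PROMPT + content[end:]).rstrip()
print(content)
# <answer>4</answer> reasoning... Wait, let me reconsider.
```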
diff --git a/torchrl/envs/llm/transforms/tools.py b/torchrl/envs/llm/transforms/tools.py
index 3492cbb69b5..a9fd7e28434 100644
--- a/torchrl/envs/llm/transforms/tools.py
+++ b/torchrl/envs/llm/transforms/tools.py
@@ -508,11 +508,15 @@ def _process_llm_response(self, response: str, i: int) -> list[str]:
return results
- def _call(self, next_tensordict: TensorDictBase) -> TensorDictBase:
+ def _step(
+ self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
+ ) -> TensorDictBase:
if next_tensordict.batch_dims > 1:
- with next_tensordict.view(-1) as next_tensordict_flat:
+ with next_tensordict.view(-1) as next_tensordict_flat, tensordict.view(
+ -1
+ ) as tensordict_flat:
# Call the transform on the flattened tensordict
- next_tensordict_flat = self._call(next_tensordict_flat)
+ next_tensordict_flat = self._step(tensordict_flat, next_tensordict_flat)
return next_tensordict
# Ensure we have enough processes for the batch
@@ -520,7 +524,7 @@ def _call(self, next_tensordict: TensorDictBase) -> TensorDictBase:
self._ensure_processes(len(next_tensordict))
# Convert text to a history
- history = next_tensordict["history"]
+ history = next_tensordict["history"].prompt
# Isolate last element, which should be our action
local_history = history[..., -1]
@@ -555,7 +559,7 @@ def fill_procs(proc: list[History], max_len: int) -> list[History]:
# Procs has the shape of the batch-size. We can cat along dim=-1
procs = lazy_stack([lazy_stack(p) for p in procs])
history.extend(procs, dim=-1)
- next_tensordict["history"] = history
+ next_tensordict["history"].prompt = history
return next_tensordict
def __del__(self):
@@ -765,8 +769,18 @@ def _call(self, next_tensordict: TensorDictBase) -> TensorDictBase:
next_tensordict_flat = self._call(next_tensordict_flat)
return next_tensordict
+ # Check that base_env is on history mode
+ parent = self.parent
+ if parent is None:
+ raise RuntimeError("MCPToolTransform must be used with a ChatEnv")
+ base_env = parent.base_env
+ if base_env.input_mode != "history":
+ raise RuntimeError(
+ "MCPToolTransform must be used with a ChatEnv in history mode"
+ )
+
# Convert text to a history
- history = next_tensordict["history"]
+ history = next_tensordict["history"].prompt
# Isolate last element, which should be our action
local_history = history[..., -1]
@@ -801,7 +815,7 @@ def fill_procs(proc: list[History], max_len: int) -> list[History]:
# Procs has the shape of the batch-size. We can cat along dim=-1
procs = lazy_stack([lazy_stack(p) for p in procs])
history.extend(procs, dim=-1)
- next_tensordict["history"] = history
+ next_tensordict["history"].prompt = history
return next_tensordict
def _reset(
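The repeated `next_tensordict["history"].prompt` edits reflect the new structured `ChatHistory` entry. A hedged sketch of the access pattern (batch sizes and message content are illustrative):

```python
from tensordict import TensorDict
from torchrl.data.llm import History
from torchrl.modules.llm.policies import ChatHistory

history = History.from_chats([[{"role": "user", "content": "Hi"}]])
chat_history = ChatHistory(prompt=history, batch_size=(1,))
td = TensorDict(history=chat_history, batch_size=(1,))

prompt = td["history"].prompt      # History stack up to the current turn
td["history"].prompt = prompt      # write back after appending/extending it
```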
diff --git a/torchrl/envs/transforms/rb_transforms.py b/torchrl/envs/transforms/rb_transforms.py
index 8507ce6d8f3..a2ff42522d2 100644
--- a/torchrl/envs/transforms/rb_transforms.py
+++ b/torchrl/envs/transforms/rb_transforms.py
@@ -192,7 +192,7 @@ def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
self._validate()
total_cat = self._append_tensordict(tensordict)
- if total_cat.shape[-1] >= self.n_steps:
+ if total_cat.shape[-1] > self.n_steps:
out = _multi_step_func(
total_cat,
done_key=self.done_key,
diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py
index a00747d7e02..6857d74545a 100644
--- a/torchrl/envs/transforms/transforms.py
+++ b/torchrl/envs/transforms/transforms.py
@@ -308,6 +308,19 @@ def out_keys_inv(self, value):
value = [unravel_key(val) for val in value]
self._out_keys_inv = value
+ @property
+ def collector(self) -> DataCollectorBase | None: # noqa: F821 # type: ignore
+ """Returns the collector associated with the container, if it exists.
+
+ This can be used whenever the transform needs to be made aware of the collector or the policy associated with it.
+
+ Make sure to call this property only on transforms that are not nested in sub-processes.
+ The collector reference will not be passed to the workers of a :class:`~torchrl.envs.ParallelEnv` or
+ similar batched environments.
+
+ """
+ return self.container.collector
+
def _reset(
self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
) -> TensorDictBase:
@@ -687,7 +700,7 @@ def clone(self) -> T:
return self_copy
@property
- def container(self):
+ def container(self) -> EnvBase | None:
"""Returns the env containing the transform.
Examples:
@@ -952,6 +965,13 @@ def add_truncated_keys(self) -> TransformedEnv:
self.empty_cache()
return self
+ # def _post_step_mdp_hooks(self, tensordict: TensorDictBase) -> TensorDictBase:
+ # """Allows modification of the tensordict after the step_mdp."""
+ # if type(self.base_env)._post_step_mdp_hooks is not None:
+ # If the base env has a _post_step_mdp_hooks, we call it
+ # tensordict = self.base_env._post_step_mdp_hooks(tensordict)
+ # return tensordict
+
def _set_env(self, env: EnvBase, device) -> None:
if device != env.device:
env = env.to(device)
@@ -1178,6 +1198,7 @@ def _reset(self, tensordict: TensorDictBase | None = None, **kwargs):
if tensordict is not None:
# We must avoid modifying the original tensordict so a shallow copy is necessary.
# We just select the input data and reset signal, which is all we need.
+ self.transform.transform_input_spec(self.base_env.input_spec.unlock_())
tensordict = tensordict.select(
*self.reset_keys, *self.state_spec.keys(True, True), strict=False
)
@@ -6502,13 +6523,16 @@ def _reset_func(
if self.single_default_value and callable(self.default_value):
if not _reset.all():
# FIXME: use masked op
- tensordict_reset = tensordict_reset.clone()
+ # tensordict_reset = tensordict_reset.clone()
reset_val = self.default_value(reset=_reset)
- # This is safe because env.reset calls _update_during_reset which will discard the new data
- tensordict_reset = (
- self.container.full_observation_spec.zero().select(
- *reset_val.keys(True)
- )
+ # This is safe because env.reset calls _update_during_reset which will discard the new data
+ # tensordict_reset = (
+ # self.container.full_observation_spec.zero().select(
+ # *reset_val.keys(True)
+ # )
+ # )
+ tensordict_reset = reset_val.new_zeros(
+ _reset.shape, empty_lazy=True
)
tensordict_reset[_reset] = reset_val
else:
diff --git a/torchrl/envs/transforms/utils.py b/torchrl/envs/transforms/utils.py
index 8ef96c04ce0..b5302329bb2 100644
--- a/torchrl/envs/transforms/utils.py
+++ b/torchrl/envs/transforms/utils.py
@@ -24,7 +24,31 @@ def new_fun(self, *args, **kwargs):
class _set_missing_tolerance:
- """Context manager to change the transform tolerance to missing values."""
+ """Context manager to change the transform tolerance to missing values.
+
+ If a transform has a missing_tolerance of True, it will not raise an error if a key is missing during reset.
+
+ This is implemented via :meth:`~torchrl.envs.transforms.Transform.set_missing_tolerance`.
+
+ The way this is handled is that, if `_reset` calls the default `_call` method, it will not raise an error if an input key is missing.
+
+ For custom `_reset` methods, you should implement this yourself:
+
+ Examples:
+ >>> def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase:
+ ... with _set_missing_tolerance(self, True):
+ ... tensordict_reset = self.foo(tensordict, tensordict_reset)
+ ... return tensordict_reset
+ >>> def foo(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase:
+ ... if self.input_keys[0] not in tensordict_reset and self.missing_tolerance:
+ ... return tensordict_reset
+ ... else:
+ ... # your code here
+
+ Because `missing_tolerance` will be turned off during calls to `_step`, you can be sure that an appropriate KeyError will be raised
+ if the input key is missing at that time.
+
+ """
def __init__(self, transform, mode):
self.transform = transform
diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py
index 0346a25935e..81485ff8e4e 100644
--- a/torchrl/envs/utils.py
+++ b/torchrl/envs/utils.py
@@ -93,11 +93,11 @@ def __init__(
exclude_done: bool = False,
exclude_action: bool = True,
):
- action_keys = env.action_keys
- done_keys = env.done_keys
- reward_keys = env.reward_keys
- observation_keys = env.full_observation_spec.keys(True, True)
- state_keys = env.full_state_spec.keys(True, True)
+ action_keys = env._action_keys_step_mdp
+ done_keys = env._done_keys_step_mdp
+ reward_keys = env._reward_keys_step_mdp
+ observation_keys = env._observation_keys_step_mdp
+ state_keys = env._state_keys_step_mdp
self.action_keys = [unravel_key(key) for key in action_keys]
self.done_keys = [unravel_key(key) for key in done_keys]
self.observation_keys = list(observation_keys)
@@ -245,6 +245,8 @@ def _grab_and_place(
else:
if is_non_tensor(val):
val = val.clone()
+ if is_tensor_collection(val):
+ val = val.copy()
data_out._set_str(
key, val, validated=True, inplace=False, non_blocking=False
)
@@ -957,6 +959,7 @@ def make_shape(shape):
# Assume all the non-tensors have the same datatype
example_data=tensor.view(-1)[0].data,
device=tensor.device,
+ feature_dims=len(tensor.shape) - len(data.shape),
)
if is_non_tensor(tensor)
else Unbounded(
@@ -1463,7 +1466,9 @@ def _update_during_reset(
reset = reset.any(-1)
reset = reset.reshape(node.shape)
# node.update(node.where(~reset, other=node_reset, pad=0))
- node.where(~reset, other=node_reset, out=node, pad=0)
+ node.where(
+ ~reset, other=node_reset, out=node, pad=0, update_batch_size=True
+ )
# node = node.clone()
# idx = reset.nonzero(as_tuple=True)[0]
# node[idx].update(node_reset[idx])
diff --git a/torchrl/modules/__init__.py b/torchrl/modules/__init__.py
index e80d5b427dc..a349aba6635 100644
--- a/torchrl/modules/__init__.py
+++ b/torchrl/modules/__init__.py
@@ -140,6 +140,7 @@
"MaskedOneHotCategorical",
"MultiAgentConvNet",
"MultiAgentMLP",
+ "LLMMaskedCategorical",
"MultiAgentNetBase",
"MultiStepActorWrapper",
"NoisyLazyLinear",
diff --git a/torchrl/modules/distributions/__init__.py b/torchrl/modules/distributions/__init__.py
index 17d7bef7085..1102637e26c 100644
--- a/torchrl/modules/distributions/__init__.py
+++ b/torchrl/modules/distributions/__init__.py
@@ -58,6 +58,7 @@
"distributions",
"Delta",
"IndependentNormal",
+ "LLMMaskedCategorical",
"NormalParamWrapper",
"TanhDelta",
"TanhNormal",
diff --git a/torchrl/modules/distributions/discrete.py b/torchrl/modules/distributions/discrete.py
index 930b90ddbba..068ca34bb2f 100644
--- a/torchrl/modules/distributions/discrete.py
+++ b/torchrl/modules/distributions/discrete.py
@@ -823,6 +823,7 @@ def log_prob(self, value: torch.Tensor) -> torch.Tensor:
# For token-level masking, we need to check if specific tokens are masked
logits = self._original_logits
+ value = value.masked_fill(~self._mask, self.ignore_index)
if value.ndim > 1:
# Reshape for cross_entropy: (batch, seq_len, vocab) -> (batch*seq_len, vocab)
logits_flat = logits.reshape(-1, logits.size(-1))
diff --git a/torchrl/modules/llm/__init__.py b/torchrl/modules/llm/__init__.py
index 735715866ff..3ec911506ca 100644
--- a/torchrl/modules/llm/__init__.py
+++ b/torchrl/modules/llm/__init__.py
@@ -11,14 +11,28 @@
vLLMWorker,
)
-from .policies import CategoricalSequential, TransformersWrapper, vLLMWrapper
+from .policies import (
+ ChatHistory,
+ LLMWrapperBase,
+ LogProbs,
+ Masks,
+ Text,
+ Tokens,
+ TransformersWrapper,
+ vLLMWrapper,
+)
__all__ = [
- "CategoricalSequential",
+ "LLMWrapperBase",
"LLMOnDevice",
"TransformersWrapper",
"make_vllm_worker",
+ "ChatHistory",
"stateless_init_process_group",
"vLLMWorker",
"vLLMWrapper",
+ "Text",
+ "LogProbs",
+ "Masks",
+ "Tokens",
]
diff --git a/torchrl/modules/llm/policies/__init__.py b/torchrl/modules/llm/policies/__init__.py
index e91ec9901cf..1bdf27e0db1 100644
--- a/torchrl/modules/llm/policies/__init__.py
+++ b/torchrl/modules/llm/policies/__init__.py
@@ -5,9 +5,18 @@
from __future__ import annotations
-from .common import CategoricalSequential
+from .common import ChatHistory, LLMWrapperBase, LogProbs, Masks, Text, Tokens
from .transformers_wrapper import TransformersWrapper
from .vllm_wrapper import vLLMWrapper
-__all__ = ["TransformersWrapper", "vLLMWrapper", "CategoricalSequential"]
+__all__ = [
+ "TransformersWrapper",
+ "vLLMWrapper",
+ "LLMWrapperBase",
+ "Text",
+ "LogProbs",
+ "Masks",
+ "Tokens",
+ "ChatHistory",
+]
diff --git a/torchrl/modules/llm/policies/common.py b/torchrl/modules/llm/policies/common.py
index 2021638406c..1f42c889391 100644
--- a/torchrl/modules/llm/policies/common.py
+++ b/torchrl/modules/llm/policies/common.py
@@ -4,63 +4,835 @@
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
+import weakref
+from typing import Any, Literal, overload
+
import torch
from tensordict import NestedKey, TensorDictBase
from tensordict.nn import TensorDictModuleBase, TensorDictSequential
+from tensordict.tensorclass import TensorClass
+from tensordict.utils import _zip_strict
from torch import distributions as D
from torch.distributions import Categorical
-from torchrl.modules import MaskedCategorical
+from torch.nn.utils.rnn import pad_sequence
+from torchrl.data.llm import History
+from torchrl.data.tensor_specs import Unbounded
+from torchrl.modules.distributions.discrete import LLMMaskedCategorical
+
+# TODOs:
+# - [ ] Remove the useless view(-1) calls when num_samples is not > 1
+# - [ ] Remove as_list=True and use a context manager to handle that
+# - [ ] Make sure tensordict can handle nested lazy tds that have a get(key, as_list=True) - I think it breaks atm
+# - [ ] Handle packing
+
+
+class Tokens(TensorClass["nocast"]):
+ """A Tokens container.
+
+ Args:
+ prompt (torch.Tensor | None): The prompt tokens.
+ response (torch.Tensor | None): The response tokens.
+ full (torch.Tensor | None): The tokens across prompt and response.
+ padded (bool | None): Whether the tokens are padded.
+
+ Shapes:
+ - prompt: (batch_size, prompt_length). If padded, padded on the left.
+ - response: (batch_size, response_length). If padded, padded on the right.
+ - full: (batch_size, prompt_length + response_length). If padded, padded on the left and/or right.
+ - padded: bool.
+
+ """
+
+ prompt: torch.Tensor | None = None
+ response: torch.Tensor | None = None
+ full: torch.Tensor | None = None
+ padded: bool | None = None
+
+ @classmethod
+ def default_spec(
+ cls,
+ shape=(-1,),
+ keys: list[Literal["prompt", "response", "full"]] | None = None,
+ ):
+ """A default spec to use in transforms / envs that return Tokens objects."""
+ from torchrl.data import Composite, NonTensor
+
+ if keys is None:
+ keys = ["prompt", "response", "full"]
+
+ defaults = {k: Unbounded(shape=shape + (-1,)) for k in keys}
+ defaults["padded"] = NonTensor(shape=shape, example_data=False)
+ return Composite(defaults, shape=shape[:-1], data_cls=cls, step_mdp_static=True)
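A short usage sketch for the `Tokens` container (all shapes are illustrative):

```python
import torch
from torchrl.modules.llm.policies import Tokens

prompt = torch.randint(0, 100, (2, 5))     # (batch, prompt_length)
response = torch.randint(0, 100, (2, 3))   # (batch, response_length)
tokens = Tokens(
    prompt=prompt,
    response=response,
    full=torch.cat([prompt, response], dim=-1),
    padded=True,
    batch_size=(2,),
)
print(tokens.full.shape)                   # torch.Size([2, 8])
```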
-class CategoricalSequential(TensorDictModuleBase):
- """A ProbabilisticTensorDictSequential subclass meant to work with LLMs.
- .. seealso:: :class:`~tensordict.nn.ProbabilisticTensorDictSequential` class.
+class Masks(TensorClass["nocast"]):
+ """A Masks container.
+
+ Args:
+ all_attention_mask (torch.Tensor | None): The attention mask across all tokens. The attention mask represents
+ the tokens that are not masked and that the model can attend to.
+ all_assistant_mask (torch.Tensor | None): The assistant mask across all tokens, i.e. the tokens that
+ are produced by the assistant.
+ This is recovered from the `assistant_masks` output of :meth:`~torchrl.data.llm.History.apply_chat_template`,
+ if the chat template supports it.
+ padded (bool | None): Whether the masks are padded.
+
+ The masks always have the same shape as the `full` tensor in :class:`~torchrl.modules.llm.policies.common.Tokens`,
+ and :class:`~torchrl.modules.llm.policies.common.LogProbs`.
"""
+ all_attention_mask: torch.Tensor | None = None
+ all_assistant_mask: torch.Tensor | None = None
+ padded: bool | None = None
+
+ @classmethod
+ def default_spec(
+ cls,
+ shape=(-1,),
+ keys: list[Literal["all_attention_mask", "all_assistant_mask"]] | None = None,
+ ):
+ """A default spec to use in transforms / envs that return Masks objects."""
+ from torchrl.data import Composite, NonTensor
+
+ if keys is None:
+ keys = ["all_attention_mask", "all_assistant_mask"]
+
+ defaults = {k: Unbounded(shape=shape + (-1,)) for k in keys}
+ defaults["padded"] = NonTensor(shape=shape, example_data=False)
+
+ return Composite(defaults, shape=shape[:-1], data_cls=cls, step_mdp_static=True)
+
+
+class ChatHistory(TensorClass["nocast"]):
+ """A chat history container for managing conversation data in LLM environments.
+
+ This class serves as a structured container for chat history data, similar to how
+ :class:`~torchrl.modules.llm.policies.Text` and :class:`~torchrl.modules.llm.policies.Tokens`
+ are used for text and token data respectively.
+
+ **Recent Changes:**
+ - **Modular Design**: ChatHistory is now used consistently across LLM wrappers and environments
+ to represent conversation state in a structured way.
+ - **Integration with Wrappers**: Both vLLMWrapper and TransformersWrapper now use ChatHistory
+ objects when `input_mode="history"` is specified.
+ - **Environment Support**: ChatEnv and related environments use ChatHistory for state management.
+
+ Args:
+ prompt (History | None): The prompt history stack containing the conversation up to the current point.
+ response (History | None): The response history items (typically generated by the LLM).
+ full (History | None): The complete history across prompt and response.
+
+ Example:
+ >>> from torchrl.data.llm import History
+ >>> from torchrl.modules.llm.policies import ChatHistory
+ >>>
+ >>> # Create a conversation history
+ >>> history = History.from_chats([[
+ ... {"role": "user", "content": "Hello"},
+ ... {"role": "assistant", "content": "Hi there!"}
+ ... ]])
+ >>>
+ >>> # Create ChatHistory object for LLM wrapper input
+ >>> chat_history = ChatHistory(prompt=history)
+ >>>
+ >>> # Use with LLM wrapper
+ >>> result = wrapper(TensorDict(history=chat_history, batch_size=(1,)))
+ >>> print(result["history"].response) # New response from LLM
+ >>> print(result["history"].full) # Complete conversation
+
+ .. seealso::
+ :class:`~torchrl.modules.llm.policies.Text`: Container for text data.
+ :class:`~torchrl.modules.llm.policies.Tokens`: Container for token data.
+ :class:`~torchrl.data.llm.History`: The underlying History class for conversation data.
+ """
+
+ prompt: History | None = None
+ response: History | None = None
+ full: History | None = None
+
+ @classmethod
+ def default_spec(
+ cls,
+ shape=(-1,),
+ keys: list[Literal["prompt", "response", "full"]] | None = None,
+ ):
+ """A default spec to use in transforms / envs that return ChatHistory objects."""
+ from torchrl.data import Composite
+
+ if keys is None:
+ keys = ["prompt", "response", "full"]
+ return Composite(
+ {k: History.default_spec(shape=shape + (-1,)) for k in keys},
+ shape=shape[:-1],
+ data_cls=cls,
+ step_mdp_static=True,
+ )
+
+
+class LogProbs(TensorClass["nocast"]):
+ """A log-probability container.
+
+ Args:
+ prompt (torch.Tensor | None): The prompt log-probabilities.
+ response (torch.Tensor | None): The response log-probabilities.
+ full (torch.Tensor | None): The log-probabilities across prompt and response.
+ padded (bool | None): Whether the log-probabilities are padded.
+
+ Shapes:
+ - prompt: (batch_size, prompt_length). If padded, padded on the left.
+ - response: (batch_size, response_length). If padded, padded on the right.
+ - full: (batch_size, prompt_length + response_length). If padded, padded on the left and/or right.
+ - padded: bool.
+
+ """
+
+ prompt: torch.Tensor | None = None
+ response: torch.Tensor | None = None
+ full: torch.Tensor | None = None
+ padded: bool | None = None
+
+ @classmethod
+ def default_spec(
+ cls,
+ shape=(-1,),
+ keys: list[Literal["prompt", "response", "full"]] | None = None,
+ ):
+ """A default spec to use in transforms / envs that return LogProbs objects."""
+ from torchrl.data import Composite, NonTensor
+
+ if keys is None:
+ keys = ["prompt", "response", "full"]
+
+ defaults = {k: Unbounded(shape=shape + (-1,)) for k in keys}
+ defaults["padded"] = NonTensor(shape=shape, example_data=False)
+
+ return Composite(defaults, shape=shape[:-1], data_cls=cls, step_mdp_static=True)
+
+
+class Text(TensorClass["nocast"]):
+ """A text container.
+
+ Args:
+ prompt (str | None): The prompt text.
+ response (str | None): The response text.
+ full (str | None): The text across prompt and response.
+ """
+
+ prompt: str | None = None
+ response: str | None = None
+ full: str | None = None
+
+ @classmethod
+ def default_spec(
+ cls,
+ shape=(-1,),
+ keys: list[Literal["prompt", "response", "full"]] | None = None,
+ ):
+ """A default spec to use in transforms / envs that return Text objects."""
+ from torchrl.data import Composite, NonTensor
+
+ if keys is None:
+ keys = ["prompt", "response", "full"]
+
+ defaults = {k: NonTensor(shape=shape, example_data="a string") for k in keys}
+
+ return Composite(defaults, shape=shape[:-1], data_cls=cls, step_mdp_static=True)
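A hedged sketch of how a transform or env might call `Text.default_spec` to declare its observation entries (the batch shape is an assumption):

```python
from torchrl.modules.llm.policies import Text

# Spec for a batch of 4 environments; each entry is a NonTensor string spec.
spec = Text.default_spec(shape=(4,))
print(spec)   # Composite with "prompt"/"response"/"full" entries, data_cls=Text
```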
+
+
+class LogProbDistribution(D.Distribution):
+ """A distribution that works directly with log-probabilities.
+
+ This is useful when we have pre-computed log-probabilities (e.g., from vLLM)
+ and want to compute log_prob() without having access to the original logits.
+ """
+
+ def __init__(self, log_probs: torch.Tensor, mask: torch.Tensor | None = None):
+ """Initialize with log-probabilities.
+
+ Args:
+ log_probs: Tensor of shape [batch, seq_len] containing log-probabilities
+ mask: Optional mask of shape [batch, seq_len] indicating valid positions
+ """
+ self.log_probs = log_probs
+ self.mask = mask
+ batch_shape = log_probs.shape[:-1] if log_probs.dim() > 1 else log_probs.shape
+ event_shape = log_probs.shape[-1:] if log_probs.dim() > 1 else torch.Size([])
+ super().__init__(batch_shape=batch_shape, event_shape=event_shape)
+
+ def log_prob(self, value: torch.Tensor) -> torch.Tensor:
+ """Compute log-probability for the given tokens.
+
+ Args:
+ value: Tensor of shape [batch, seq_len] containing token indices
+
+ Returns:
+ Tensor of shape [batch, seq_len] containing log-probabilities
+ """
+ # For log-prob distributions, we just return the pre-computed log-probs
+ # at the positions specified by the value tensor
+ if value.shape != self.log_probs.shape:
+ raise ValueError(
+ f"Value shape {value.shape} must match log_probs shape {self.log_probs.shape}"
+ )
+
+ result = self.log_probs.clone()
+
+ # Apply mask if provided
+ if self.mask is not None:
+ result = torch.where(
+ self.mask,
+ result,
+ torch.tensor(0.0, device=result.device, dtype=result.dtype),
+ )
+
+ return result
+
+ def sample(self, sample_shape: tuple | torch.Size | None = None) -> torch.Tensor:
+ """Sample from the distribution.
+
+ Note: This is not implemented for log-prob distributions since we don't have
+ the full probability distribution, only the log-probs for specific tokens.
+ """
+ raise NotImplementedError("Sampling not supported for LogProbDistribution")
+
+ def entropy(self) -> torch.Tensor:
+ """Compute entropy.
+
+ Note: This is not implemented for log-prob distributions since we don't have
+ the full probability distribution.
+ """
+ raise NotImplementedError("Entropy not supported for LogProbDistribution")
+
+
+class LLMWrapperBase(TensorDictModuleBase):
+ r"""A LLM wrapper base class.
+
+ This class provides a consistent interface for LLM wrappers with the following features:
+ - Support for different input modalities (history, text, tokens)
+ - Consistent output structure using TensorClass objects (Text, Tokens, Masks, LogProbs)
+ - Configurable generation and log-probability computation
+
+ Args:
+ model: The underlying model to wrap.
+
+ Keyword Args:
+ tokenizer: The tokenizer to use for encoding and decoding text.
+ input_mode: The input modality to use. Must be one of "history", "text", or "tokens".
+ input_key: The key for the input data. If None, defaults to the input_mode name.
+ attention_mask_key: The key for attention masks (used in "tokens" mode).
+ generate: Whether to enable text generation.
+ generate_kwargs: Additional arguments to pass to the model's generate method.
+ tokenizer_kwargs: Additional arguments to pass to the tokenizer.
+ pad_output: Whether to pad the output sequences to a uniform length.
+ inplace: Determines how the module should handle in-place operations.
+ device: The device to use for computation.
+ layout: The layout to use for the output tensors when pad_output=False.
+ num_samples: The number of samples to generate.
+ log_probs_key (NestedKey | None, optional): The key for the log probabilities :class:`~torchrl.modules.llm.policies.LogProbs` object. Defaults to `"log_probs"`.
+ text_key (NestedKey | None, optional): The key for the action :class:`~torchrl.modules.llm.policies.Text` object. Defaults to `"text"`.
+ tokens_key (NestedKey | None, optional): The key for the action :class:`~torchrl.modules.llm.policies.Tokens` object. Defaults to `"tokens"`.
+ masks_key (NestedKey | None, optional): The key for the action :class:`~torchrl.modules.llm.policies.Masks` object. Defaults to `"masks"`.
+
+ Attributes:
+ collector: The collector associated with the module, if it exists.
+
+ .. seealso::
+ - :class:`~torchrl.modules.llm.policies.TransformersWrapper` (see :ref:`ref_transformers_wrapper`)
+ - :class:`~torchrl.modules.llm.policies.vLLMWrapper` (see :ref:`ref_vllm_wrapper`)
+ """
+
generate: bool
+ pad_output: bool
+ text_key: NestedKey
+ tokens_key: NestedKey
+ masks_key: NestedKey
+ log_probs_key: NestedKey
+ in_keys: list[NestedKey]
+ out_keys: list[NestedKey]
+ inplace: bool
+ device: torch.device | None
+ layout: torch.layout | None
+ num_samples: int | None
+
+ @overload
+ def __init__(
+ self,
+ model: Any | str,
+ *,
+ tokenizer: callable | str | None = None, # type: ignore
+ input_mode: str = "history",
+ input_key: NestedKey | None = None,
+ attention_mask_key: str = "attention_mask",
+ generate: bool = True,
+ generate_kwargs: dict | None = None,
+ tokenizer_kwargs: dict | None = None,
+ pad_output: bool = False,
+ inplace: Literal[True, False, "empty"] | None = None,
+ device: torch.device | None = None,
+ layout: torch.layout | None = None,
+ num_samples: int | None = None,
+ chat_template_name: Literal["chatml_format", "qwen"] | None = None,
+ chat_template: str | None = None,
+ return_log_probs: bool | None = None,
+ history_key: NestedKey | None = "history",
+ text_key: NestedKey | None = "text",
+ tokens_key: NestedKey | None = "tokens",
+ masks_key: NestedKey | None = "masks",
+ log_probs_key: NestedKey | None = "log_probs",
+ ):
+ ...
+
+ def __init__(self, *args, **kwargs):
+ super().__init__()
+
+ def get_new_version(self, **kwargs):
+ """Returns a new version of the module with altered parameters.
+
+ For instance, the generate parameter can be altered to enable text generation or log-probabilities computation.
+ This is especially useful when one wants to avoid re-initializing the module with a new set of parameters, when the
+ same parameters could be used to gather log-probs.
+
+ Positional arguments are not supported.
+
+ See the class constructor for more details about the parameters.
+ """
+ raise NotImplementedError
+
+ _collector: weakref.ReferenceType[
+ LLMCollector # noqa: F821 # type: ignore
+ ] | None = None
+
+ def register_collector(self, collector: LLMCollector): # noqa: F821 # type: ignore
+ """Registers a weak reference to the container collector.
+
+ This is automatically called by the :class:`~torchrl.collectors.llm.LLMCollector` class.
+ """
+ self._collector = weakref.ref(collector)
+
+ @property
+ def collector(self) -> LLMCollector | None: # noqa: F821 # type: ignore
+ """Returns the collector associated with the module, if it exists."""
+ return self._collector() if self._collector is not None else None
def get_dist(
self,
tensordict: TensorDictBase,
tensordict_out: TensorDictBase | None = None,
+ logits_key: NestedKey = "logits",
+ mask_key: NestedKey | None = None,
as_padded_tensor: bool | None = None,
as_nested_tensor: bool | None = None,
padding_value: float | None = None,
- padding_side: str = "right",
+ padding_side: str = "left",
layout: torch.layout | None = None,
**kwargs,
) -> D.Distribution:
+ """Get distribution from logits/log-probs with optional masking.
+
+ Args:
+ tensordict: Input tensordict
+ tensordict_out: Output tensordict (optional)
+ logits_key: Key for logits/log-probs
+ mask_key: Key for mask (optional).
+ as_padded_tensor: Whether to return a padded tensor. Defaults to True unless as_nested_tensor=True.
+ as_nested_tensor: Whether to return a nested tensor. Defaults to False.
+ padding_value: Value for padding. Default is 0.0 for logits and False for masks.
+ padding_side: Side for padding. Default is left by convention.
+ layout: Tensor layout
+ **kwargs: Additional arguments
+
+ Returns:
+ Distribution (Categorical or LLMMaskedCategorical)
+ """
+ if self.generate:
+ raise NotImplementedError(
+ "get_dist is not implemented for generate=True. "
+ "You can create a new version of this wrapper using the `get_new_version` method."
+ )
+
td_out = self(tensordict.copy())
- # By default, pad and use masked categorical
+
+ # Get logits/log-probs
if as_padded_tensor is None:
as_padded_tensor = as_nested_tensor is not True
if padding_value is None:
padding_value = 0.0
if as_nested_tensor is None:
as_nested_tensor = False
+
logits = td_out.get(
- "logits",
+ logits_key,
as_padded_tensor=as_padded_tensor,
as_nested_tensor=as_nested_tensor,
padding_value=padding_value,
padding_side=padding_side,
layout=layout,
)
- if as_padded_tensor:
- # We can use MaskedCategorical
- dist = MaskedCategorical(
+
+ # Get mask if provided
+ mask = None
+ if mask_key is not None:
+ mask = td_out.get(
+ mask_key,
+ as_padded_tensor=as_padded_tensor,
+ as_nested_tensor=as_nested_tensor,
+ padding_value=False,
+ padding_side=padding_side,
+ layout=layout,
+ )
+ elif as_padded_tensor:
+ # Default mask for padded tensors
+ mask = logits != padding_value
+
+ if mask is not None:
+ dist = LLMMaskedCategorical(
logits=logits,
- mask=logits != padding_value,
- use_cross_entropy=True,
+ mask=mask,
)
+ if not dist._position_level_masking:
+ raise ValueError(
+ "Mask is not a position-level mask. "
+ "This is likely because the mask is not a position-level mask."
+ )
return dist
return Categorical(logits)
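A hedged sketch of the position-level masking that `get_dist` builds when a mask is available (shapes and vocabulary size are assumptions):

```python
import torch
from torchrl.modules import LLMMaskedCategorical

logits = torch.randn(2, 6, 32)                  # (batch, seq, vocab)
mask = torch.ones(2, 6, dtype=torch.bool)
mask[0, 4:] = False                             # last two positions of sample 0 are padding
dist = LLMMaskedCategorical(logits=logits, mask=mask)

tokens = torch.randint(0, 32, (2, 6))
log_probs = dist.log_prob(tokens)               # per-token log-probs; masked positions are ignored
```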
+ def _get_dist_with_prompt_mask(
+ self,
+ tensordict: TensorDictBase,
+ tokens_key: NestedKey = ("tokens", "prompt"),
+ logits_key: NestedKey = "logits",
+ # TODO: add a prompt_mask and response_mask in Masks
+ assistant_mask_key: NestedKey = ("masks", "all_assistant_mask"),
+ attention_mask_key: NestedKey = ("masks", "all_attention_mask"),
+ padding_side: str = "left",
+ **kwargs,
+ ) -> D.Distribution:
+ """Get distribution masked to only include response tokens (exclude prompt).
+
+ This is suitable for single-turn scenarios where we want to compute loss
+ only on the generated response, not the input prompt.
+
+ Note: If prompt tokens are not available (e.g., when using history input),
+ this method falls back to using the assistant mask.
+
+ Padding side is left by convention.
+
+ This is a provisional method that will be replaced by the `get_dist` method once we have a better masking strategy.
+ """
+ if self.generate:
+ raise NotImplementedError(
+ "get_dist_with_prompt_mask is not implemented for generate=True. "
+ "You can create a new version of this wrapper using the `get_new_version` method."
+ )
+ td_out = self(tensordict.copy())
+
+ # Try to get prompt tokens first
+ if self.pad_output:
+ prompt_tokens = tensordict.get(
+ tokens_key,
+ as_padded_tensor=True,
+ padding_value=-100,
+ padding_side=padding_side,
+ )
+ logits = td_out.get(
+ logits_key,
+ as_padded_tensor=True,
+ padding_value=0.0,
+ padding_side=padding_side,
+ )
+ attention_mask = tensordict.get(
+ attention_mask_key,
+ as_padded_tensor=True,
+ padding_value=False,
+ padding_side=padding_side,
+ )
+ assistant_mask = tensordict.get(
+ assistant_mask_key,
+ as_padded_tensor=True,
+ padding_value=False,
+ padding_side=padding_side,
+ )
+ else:
+ prompt_tokens = tensordict.get(tokens_key, as_list=True)
+ logits = td_out.get(logits_key, as_list=True)
+ attention_mask = td_out.get(attention_mask_key, as_list=True)
+ assistant_mask = td_out.get(assistant_mask_key, as_list=True)
+
+ if prompt_tokens is None:
+ if assistant_mask is None:
+ raise ValueError(
+ f"Assistant mask not found in tensordict at key {assistant_mask_key} (keys: {td_out.keys()})"
+ )
+ if self.pad_output:
+ response_mask = assistant_mask.clone()
+ else:
+ response_mask = [am.clone() for am in assistant_mask]
+ else:
+ if self.pad_output:
+ response_mask = attention_mask.clone()
+ response_mask[..., : prompt_tokens.shape[-1]] = False
+ else:
+ response_mask = []
+ for am, p in _zip_strict(attention_mask, prompt_tokens):
+ am = am.clone()
+ am[..., : p.size(-1)] = False
+ response_mask.append(am)
+
+ if logits is None:
+ raise ValueError(
+ f"Logits not found in tensordict at key {logits_key} (keys: {td_out.keys()})"
+ )
+
+ # Make the response mask using prompt tokens
+ if not self.pad_output:
+ # Check that the lengths of the mask is the same as the logits
+ for m, lg in _zip_strict(response_mask, logits):
+ if m.shape[-1] != lg.shape[-2]:
+ raise ValueError(
+ f"Mask and logits have different lengths: {m.shape[-1]} != {lg.shape[-2]}.\n"
+ f"All the logits shapes: {[lg.shape for lg in logits]}, all the mask shapes: {[m.shape for m in response_mask]}"
+ )
+ logits = pad_sequence(
+ logits, batch_first=True, padding_value=0.0, padding_side=padding_side
+ )
+ response_mask = pad_sequence(
+ response_mask,
+ batch_first=True,
+ padding_value=False,
+ padding_side=padding_side,
+ )
+
+ dist = LLMMaskedCategorical(
+ logits=logits,
+ mask=response_mask.bool(),
+ )
+ if not dist._position_level_masking:
+ raise ValueError(
+ "Mask is not a position-level mask. "
+ "This is likely because the mask is not a position-level mask."
+ )
+ return dist
+
+ def _get_dist_with_assistant_mask(
+ self,
+ tensordict: TensorDictBase,
+ assistant_mask_key: NestedKey = ("masks", "all_assistant_mask"),
+ logits_key: NestedKey = "logits",
+ padding_side: str = "left",
+ **kwargs,
+ ) -> D.Distribution:
+ """Get distribution masked to only include assistant tokens.
+
+ This is suitable for multi-turn scenarios where we want to compute loss
+ only on assistant-generated tokens across the entire conversation.
+
+ Padding side is left by convention.
+
+ This is a provisional method that will be replaced by the `get_dist` method once we have a better masking strategy.
+ """
+ if self.generate:
+ raise NotImplementedError(
+ "get_dist_with_assistant_mask is not implemented for generate=True. "
+ "You can create a new version of this wrapper using the `get_new_version` method."
+ )
+ td_out = self(tensordict.copy())
+ # Update the tokens key to reflect the tokenized history when querying the log-probs
+ tensordict.update(
+ td_out,
+ keys_to_update=[
+ ("tokens", "full"),
+ ],
+ )
+
+ if self.pad_output:
+ logits = td_out.get(logits_key)
+ assistant_mask = td_out.get(assistant_mask_key)
+ else:
+ logits = td_out.get(
+ logits_key,
+ as_padded_tensor=True,
+ padding_value=0.0,
+ padding_side=padding_side,
+ )
+ assistant_mask = td_out.get(
+ assistant_mask_key,
+ as_padded_tensor=True,
+ padding_value=False,
+ padding_side=padding_side,
+ )
+ if logits is None:
+ raise ValueError(f"Logits not found in tensordict at key {logits_key}")
+ if assistant_mask is None:
+ if self.input_mode != "history":
+ post_msg = "This is likely because the input_mode is not 'history'."
+ else:
+ post_msg = ""
+ raise ValueError(
+ f"Assistant mask not found in tensordict at key {assistant_mask_key}. {post_msg}"
+ )
+
+ dist = LLMMaskedCategorical(
+ logits=logits,
+ mask=assistant_mask,
+ )
+ if not dist._position_level_masking:
+ raise ValueError(
+ "Assistant mask is not a position-level mask. "
+ "This is likely because the assistant mask is not a position-level mask."
+ )
+ return dist
+
+ def _get_dist_with_attention_mask(
+ self,
+ tensordict: TensorDictBase,
+ attention_mask_key: NestedKey = ("masks", "all_attention_mask"),
+ logits_key: NestedKey = "logits",
+ padding_side: str = "left",
+ **kwargs,
+ ) -> D.Distribution:
+ """Get distribution masked using attention mask.
+
+ This is suitable for generic scenarios where we want to compute loss
+ on all valid tokens (non-padding tokens).
+
+ Padding side is left by convention.
+
+ This is a provisional method that will be replaced by the `get_dist` method once we have a better masking strategy.
+ """
+ if self.generate:
+ raise NotImplementedError(
+ "get_dist_with_attention_mask is not implemented for generate=True. "
+ "You can create a new version of this wrapper using the `get_new_version` method."
+ )
+ td_out = self(tensordict.copy())
+ if self.pad_output:
+ logits = td_out.get(logits_key)
+ attention_mask = td_out.get(attention_mask_key)
+ else:
+ logits = td_out.get(
+ logits_key,
+ as_padded_tensor=True,
+ padding_value=0.0,
+ padding_side=padding_side,
+ )
+ attention_mask = td_out.get(
+ attention_mask_key,
+ as_padded_tensor=True,
+ padding_value=False,
+ padding_side=padding_side,
+ )
+
+ if logits is None:
+ raise ValueError(f"Logits not found in tensordict at key {logits_key}")
+ if attention_mask is None:
+ raise ValueError(
+ f"Attention mask not found in tensordict at key {attention_mask_key}"
+ )
+
+ dist = LLMMaskedCategorical(
+ logits=logits,
+ mask=attention_mask,
+ )
+ if not dist._position_level_masking:
+ raise ValueError(
+ "Attention mask is not a position-level mask. "
+ "This is likely because the attention mask is not a position-level mask."
+ )
+ return dist
+
+ def _get_dist_with_custom_mask(
+ self,
+ tensordict: TensorDictBase,
+ mask: torch.Tensor,
+ logits_key: NestedKey = "logits",
+ padding_side: str = "left",
+ **kwargs,
+ ) -> D.Distribution:
+ """Get distribution with custom mask.
+
+ This allows for completely custom masking logic.
+
+ Padding side is left by convention.
+
+ This is a provisional method that will be replaced by the `get_dist` method once we have a better masking strategy.
+ """
+ if self.generate:
+ raise NotImplementedError(
+ "get_dist_with_custom_mask is not implemented for generate=True. "
+ "You can create a new version of this wrapper using the `get_new_version` method."
+ )
+ td_out = self(tensordict.copy())
+ if self.pad_output:
+ logits = td_out.get(logits_key)
+ else:
+ logits = td_out.get(
+ logits_key,
+ as_padded_tensor=True,
+ padding_value=0.0,
+ padding_side=padding_side,
+ )
+
+ if logits is None:
+ raise ValueError(f"Logits not found in tensordict at key {logits_key}")
+
+ dist = LLMMaskedCategorical(
+ logits=logits,
+ mask=mask,
+ )
+ if not dist._position_level_masking:
+ raise ValueError(
+ "Custom mask is not a position-level mask. "
+ "This is likely because the custom mask is not a position-level mask."
+ )
+ return dist
+
+ # Convenience methods for common LLM training scenarios
+ def _get_sft_dist(self, tensordict: TensorDictBase, **kwargs) -> D.Distribution:
+ """Get distribution suitable for SFT loss (response tokens only).
+
+ This is a provisional method that will be replaced by the `get_dist` method once we have a better masking strategy.
+ """
+ return self._get_dist_with_prompt_mask(tensordict, **kwargs)
+
+ def _get_rlhf_dist(self, tensordict: TensorDictBase, **kwargs) -> D.Distribution:
+ """Get distribution suitable for RLHF loss (assistant tokens only).
+
+ This is a provisional method that will be replaced by the `get_dist` method once we have a better masking strategy.
+ """
+ return self._get_dist_with_assistant_mask(tensordict, **kwargs)
+
+ def _get_generic_dist(self, tensordict: TensorDictBase, **kwargs) -> D.Distribution:
+ """Get distribution suitable for generic losses (all tokens).
+
+ This is a provisional method that will be replaced by the `get_dist` method once we have a better masking strategy.
+ """
+ return self._get_dist_with_attention_mask(tensordict, **kwargs)
+
# Sampling is taken care of by the sub-modules
forward = TensorDictSequential.forward
+ def _check_padded(self, val: torch.Tensor) -> torch.Tensor:
+ """Check that a value is a padded tensor."""
+ if not isinstance(val, torch.Tensor):
+ raise ValueError(f"Expected a padded torch.Tensor, got {type(val)}")
+ return val
+
+ def _check_not_padded(
+ self, val: list[torch.Tensor] | torch.Tensor
+ ) -> list[torch.Tensor] | torch.Tensor:
+ """Check that a value is not a padded tensor (i.e., a list of tensors)."""
+ if isinstance(val, torch.Tensor):
+ raise ValueError("Expected a list of tensors - not padded, got a tensor")
+ return val
+
@property
def log_prob_keys(self) -> list[NestedKey]:
return getattr(self, "_log_prob_keys", ["log_probs"])
@@ -69,14 +841,6 @@ def log_prob_keys(self) -> list[NestedKey]:
def log_prob_keys(self, value: list[NestedKey]):
self._log_prob_keys = value
- @property
- def log_prob_key(self) -> NestedKey:
- return self.log_prob_keys[0]
-
- @log_prob_key.setter
- def log_prob_key(self, value: NestedKey) -> None:
- self.log_prob_keys[0] = value
-
@property
def dist_params_keys(self) -> list[NestedKey]:
raise NotImplementedError
@@ -88,5 +852,5 @@ def dist_sample_keys(self) -> list[NestedKey]:
def log_prob(self, data: TensorDictBase, **get_kwargs) -> TensorDictBase:
if not self.generate:
data = self(data)
- return data.get(self.log_prob_key, **get_kwargs)
+ return data.get((self.log_prob_key, "response"), **get_kwargs)
raise RuntimeError("log_prob not callable when generate=True.")
diff --git a/torchrl/modules/llm/policies/transformers_wrapper.py b/torchrl/modules/llm/policies/transformers_wrapper.py
index 98b9f8aae64..c71cf0616dc 100644
--- a/torchrl/modules/llm/policies/transformers_wrapper.py
+++ b/torchrl/modules/llm/policies/transformers_wrapper.py
@@ -4,180 +4,290 @@
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
-from copy import copy
+import contextlib
+from contextlib import nullcontext
+from copy import copy
from typing import Literal
import torch
from tensordict import (
lazy_stack,
- LazyStackedTensorDict,
- NestedKey,
+ MetaData,
+ NonTensorStack,
set_list_to_stack,
TensorDict,
TensorDictBase,
)
-from tensordict.utils import _zip_strict
+from tensordict.utils import _zip_strict, NestedKey
+from torch import distributions as D
from torch.nn.utils.rnn import pad_sequence
-from torchrl.modules.llm.policies.common import CategoricalSequential
+from torchrl.modules.llm.policies.common import (
+ ChatHistory,
+ LLMWrapperBase,
+ LogProbs,
+ Masks,
+ Text,
+ Tokens,
+)
from torchrl.modules.utils.utils import _unpad_tensors
-class TransformersWrapper(CategoricalSequential):
+class TransformersWrapper(LLMWrapperBase):
"""A wrapper class for Hugging Face Transformers models, providing a consistent interface for text generation and log probability computation.
- This class handles both text and token inputs, enabling text generation and log probability computation based on
- the specified configuration. Unlike vLLM, Transformers require padded tensors for input and output sequences.
+ This class is a subclass of :class:`~torchrl.modules.llm.policies.LLMWrapperBase` and provides a unified API for handling different input modalities
+ (history, text, tokens) with consistent output structure using :class:`~tensordict.TensorClass` objects.
Args:
- model (transformers.LLM): The Hugging Face Transformers model to wrap.
+ model (transformers.AutoModelForCausalLM | str): The Hugging Face Transformers model to wrap.
+ If a string, it will be passed to `transformers.AutoModelForCausalLM.from_pretrained`.
Keyword Args:
- return_log_probs (bool | None, optional): Whether to return log probabilities of the generated tokens.
- Defaults to `None`.
- tokenizer (transformers.tokenization_utils.PreTrainedTokenizer | None, optional): The tokenizer to use for
- encoding and decoding text. If `None`, the tokenizer associated with the model will be used. Defaults to
- `None`.
- from_text (bool, optional): Indicates whether the input is in text format. If `True`, the input is expected to
- be text that will be tokenized. If `False`, the input is expected to be token sequences. Defaults to `True`.
-
- .. note:: If `from_text` is `True`, the input text can be provided in the `"text"` key or in the `"history"` key.
- If using the `"history"` key, the history will be parsed from a :class:`~torchrl.data.llm.History` object to a
- text string using the tokenizer.
-
- device (torch.device | None, optional): The device to use for computation. If `None`, the default device will
- be used. Defaults to `None`.
- generate (bool, optional): Whether to enable text generation. If `True`, the model will generate text based on
- the input. If `False`, only log probabilities will be computed for the response tokens/text. Defaults to `True`.
- generate_kwargs (dict | None, optional): Additional arguments to pass to the model's generate method. These
- arguments can control aspects of the generation process, such as temperature and top-k sampling. Defaults
- to `None`.
-
- .. note:: Sampling params can be overwritten at runtime using the kwargs of the forward method.
- See `the full list of accepted keyword arguments here `__.
-
- tokenizer_kwargs (dict | None, optional): Additional arguments to pass to the tokenizer. These arguments can
- control aspects of the tokenization process, such as padding and truncation. Defaults to `None`.
- pad_output (bool, optional): Whether to pad the output sequences to a uniform length. Transformers require
- `pad_output=True`, and the output sequences will be padded and represented as tensors. Defaults to `True`.
- inplace (Literal[True, False, "empty"] | None, optional): Determines how the module should handle in-place
- operations. If `True`, operations will be performed in-place. If `False`, a new TensorDict instance will be
- created. If `"empty"`, the output data structure will be initialized with `input.empty()` (i.e., it will
- conserve type, batch-size, and device). Defaults to `True`.
- chat_template_name (Literal["chatml_format", "qwen"] | None, optional): The name of the chat template to use when
- applying the chat template to the history. Defaults to `None`.
+ tokenizer (transformers.tokenization_utils.PreTrainedTokenizer | str | None, optional): The tokenizer to use for
+ encoding and decoding text. If `None`, the tokenizer associated with the model will be used.
+ If a string, it will be passed to `transformers.AutoTokenizer.from_pretrained`. Defaults to `None`.
+ input_mode (str, optional): The input modality to use. Must be one of `"history"`, `"text"`, or `"tokens"`.
+ Defaults to `"history"`.
+ input_key (str | None, optional): The key for the input data. If `None`, defaults to
+ - `("history", "prompt")` for `"history"` when `generate=True`, `("history", "full")` for `"history"` when `generate=False`
+ - `("text", "prompt")` for `"text"` when `generate=True`, `("text", "full")` for `"text"` when `generate=False`
+ - `("tokens", "prompt")` for `"tokens"` when `generate=True`, `("tokens", "full")` for `"tokens"` when `generate=False`
+ attention_mask_key (str, optional): The key for attention masks (used in `"tokens"` mode). Defaults to `"attention_mask"`.
+
+ .. warning:: This argument is under development and may change in the future.
+
+ generate (bool, optional): Whether to enable text generation. If `True`, the model will generate text based on the input.
+ If `False`, only log probabilities will be computed. Defaults to `True`.
+ return_log_probs (bool, optional): Whether to return log probabilities. Defaults to `False`.
+ generate_kwargs (dict | None, optional): Additional arguments to pass to the model's generate method. Defaults to `None`.
+ tokenizer_kwargs (dict | None, optional): Additional arguments to pass to the tokenizer. Defaults to `None`.
+ pad_output (bool, optional): Whether to pad the output sequences to a uniform length. If `True`, the output
+ sequences will be padded and represented as tensors; if `False`, they are kept unpadded (see `layout`). Defaults to `False`.
+ inplace (Literal[True, False, "empty"] | None, optional): Determines how the module should handle in-place operations. Defaults to `True`.
+ device (torch.device | None, optional): The device to use for computation. Defaults to `None`.
+ layout (torch.layout | None, optional): The layout to use for the output tensors when `pad_output=False`. Defaults to `torch.strided`.
+ num_samples (int | None, optional): The number of samples to generate. Defaults to `None` (one sample, and no batch-dimension for it).
+ Can also be set via the `generate_kwargs["num_return_sequences"] = value` argument. Requires the "do_sample" argument to be set to `True` in `generate_kwargs`.
+ chat_template_name (Literal["chatml_format", "qwen"] | None, optional): The name of the chat template to use when applying the chat
+ template to the history. Defaults to `None`. For `input_mode="history"` only.
chat_template (str | None, optional): The chat template to use when applying the chat template to the history.
- Defaults to `None`.
-
- .. note:: The tokenizer is used when `from_text` is `True` to convert input text into token sequences. It is also
- required (or retrieved) when `pad_output` is `True` or when using text inputs with `generate=False` to ensure proper
- tokenization and padding.
+ Defaults to `None`. For `input_mode="history"` only.
+ log_probs_key (NestedKey | None, optional): The key for the log probabilities :class:`~torchrl.modules.llm.policies.LogProbs` object. Defaults to `"log_probs"`.
+ text_key (NestedKey | None, optional): The key for the action :class:`~torchrl.modules.llm.policies.Text` object. Defaults to `"text"`.
+ tokens_key (NestedKey | None, optional): The key for the action :class:`~torchrl.modules.llm.policies.Tokens` object. Defaults to `"tokens"`.
+ masks_key (NestedKey | None, optional): The key for the action :class:`~torchrl.modules.llm.policies.Masks` object. Defaults to `"masks"`.
+ history_key (NestedKey | None, optional): The key for the action :class:`~torchrl.modules.llm.policies.ChatHistory` object. Defaults to `"history"`.
Input Keys:
-
- - If `from_text` is `True`:
-
- - `"text"`: The input text to be tokenized.
- - `"text_response"`: the response text (if `generate=False` as the log probabilities are computed for the response.)
-
- - If `from_text` is `False`:
-
- - "tokens": The input token sequences.
- - "attention_mask": The attention mask for the tokens.
- - "tokens_response": The response token sequences (if `generate=False` as the log probabilities are
- computed for the response.)
+ The input key depends on both `input_mode` and `generate`:
+ - If `input_mode="history"` and `generate=True`: `input_key` (defaults to `("history", "prompt")`)
+ - If `input_mode="history"` and `generate=False`: `input_key` (defaults to `("history", "full")`)
+ - If `input_mode="text"` and `generate=True`: `input_key` (defaults to `("text", "prompt")`)
+ - If `input_mode="text"` and `generate=False`: `input_key` (defaults to `("text", "full")`)
+ - If `input_mode="tokens"` and `generate=True`: `input_key` (defaults to `("tokens", "prompt")`)
+ - If `input_mode="tokens"` and `generate=False`: `input_key` (defaults to `("tokens", "full")`)
Output Keys:
-
- - `"tokens_response"`: The generated token sequences.
- - `"log_probs"`: The log probabilities of the generated tokens (if `return_log_probs` is `True`).
- - `"text_response"`: The generated text (if `from_text` is `True` and `generate` is `True`).
+ The output keys are automatically determined based on the input_mode:
+ - **Tokens**: Always returned (`tokens_key`, defaults to `"tokens"`)
+ - **Text**: Returned for `"text"` and `"history"` modes (`text_key`, defaults to `"text"`)
+ - **History**: Returned only for `"history"` mode (`history_key`, defaults to `"history"`)
+ - **Masks**: Always returned (`masks_key`, defaults to `"masks"`)
+ - **Log Probs**: Returned when `return_log_probs=True` (`log_probs_key`, defaults to `"log_probs"`)
+
+ Example output structure for `input_mode="history"`:
+ ```
+ TensorDict(
+ text=Text(prompt=..., response=..., full=...),
+ masks=Masks(all_attention_mask=..., all_assistant_mask=...),
+ tokens=Tokens(prompt=..., response=..., full=...),
+ log_probs=LogProbs(prompt=..., response=..., full=...),
+ history=ChatHistory(prompt=..., response=..., full=...)
+ )
+ ```
Example:
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
+ >>> from torchrl.data.llm import History
+ >>> from torchrl.modules.llm.policies import ChatHistory
+ >>>
>>> model = AutoModelForCausalLM.from_pretrained("gpt2")
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
+ >>>
+ >>> # History input (recommended for RL environments)
>>> wrapper = TransformersWrapper(
... model,
... tokenizer=tokenizer,
- ... from_text=True,
- ... generate=True
+ ... input_mode="history",
+ ... generate=True,
+ ... return_log_probs=True
... )
- >>> input_data = TensorDict({"text": ["Hello, world!", "This is another text"]}, batch_size=1)
- >>> output_data = wrapper(input_data)
- >>> print(output_data["text_response"])
-
- .. seealso:: :func:`~torchrl.modules.vLLMWrapper` for a similar interface using vLLM.
-
+ >>>
+ >>> history = History.from_chats([[
+ ... {"role": "user", "content": "Hello"},
+ ... {"role": "assistant", "content": "Hi there!"}
+ ... ]])
+ >>> chat_history = ChatHistory(prompt=history)
+ >>> result = wrapper(TensorDict(history=chat_history, batch_size=(1,)))
+ >>> print(result["text"].response) # Generated text
+ >>> print(result["log_probs"].response) # Log probabilities
+ >>> print(result["history"].response) # History with response
+
+ Attributes:
+ collector: The collector associated with the module, if it exists.
+
+ .. seealso::
+ - :class:`~torchrl.modules.llm.policies.LLMWrapperBase` (see :ref:`ref_categorical_sequential`)
+ - :class:`~torchrl.modules.llm.policies.vLLMWrapper` (see :ref:`ref_vllm_wrapper`)
"""
- text_key: NestedKey = ("text",)
- history_key: NestedKey = ("history",)
- token_key: NestedKey = ("tokens",)
- token_response_key: NestedKey = ("tokens_response",)
- text_response_key: NestedKey = ("text_response",)
- attention_mask_key: NestedKey = ("attention_mask",)
-
def __init__(
self,
- model: transformers.LLM, # noqa
- # noqa
+ model,
*,
- return_log_probs: bool | None = None,
- tokenizer: transformers.tokenization_utils.PreTrainedTokenizer # noqa
- | None = None,
- # noqa
- from_text: bool = True,
- device: torch.device | None = None,
+ tokenizer=None,
+ input_mode: str = "history",
+ input_key: str | None = None,
+ attention_mask_key: str = "attention_mask",
generate: bool = True,
generate_kwargs: dict | None = None,
tokenizer_kwargs: dict | None = None,
- pad_output: bool = True,
- inplace: Literal[True, False, "empty"] | None = True,
+ pad_output: bool = False,
+ inplace: Literal[True, False, "empty"] | None = None,
+ device: torch.device | None = None,
+ layout: torch.layout | None = None,
+ num_samples: int | None = None,
chat_template_name: Literal["chatml_format", "qwen"] | None = None,
chat_template: str | None = None,
+ return_log_probs: bool | None = None,
+ history_key: NestedKey | None = "history",
+ text_key: NestedKey | None = "text",
+ tokens_key: NestedKey | None = "tokens",
+ masks_key: NestedKey | None = "masks",
+ log_probs_key: NestedKey | None = "log_probs",
):
super().__init__()
+ if isinstance(model, str):
+ from transformers import AutoModelForCausalLM
+
+ model = AutoModelForCausalLM.from_pretrained(model)
+
+ if isinstance(tokenizer, str):
+ from transformers import AutoTokenizer
+
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer)
+
+ # Validate input_mode
+ if input_mode not in ["history", "text", "tokens"]:
+ raise ValueError(
+ f"input_mode must be one of 'history', 'text', 'tokens'. Got '{input_mode}'"
+ )
+
self.model = model
- self.from_text = from_text
- if device is not None:
- device = torch.device(device)
- self._device = device
+ self.input_mode = input_mode
+ self.attention_mask_key = attention_mask_key
self.generate = generate
- self.inplace = inplace
+
+ # Auto-determine what to return based on input mode
+ self.return_history = input_mode in ("history",)
+ self.return_text = input_mode in ("text", "history")
+ self.return_tokens = input_mode in ("tokens", "history", "text")
+ self.return_masks = True
+ if return_log_probs is False and not generate:
+ raise ValueError("return_log_probs must be True when generate=False.")
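+        # Log-probs default to on; they are forced to True when generate=False (log-prob-only mode)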
+ return_log_probs = (
+ True
+ if (return_log_probs is None and generate) or (not generate)
+ else bool(return_log_probs)
+ )
+ self.return_log_probs = return_log_probs
+
+ self.history_key = history_key
+ self.text_key = text_key
+ self.tokens_key = tokens_key
+ self.masks_key = masks_key
+ self.log_probs_key = log_probs_key
+ if not isinstance(pad_output, bool):
+ raise ValueError("pad_output must be a boolean")
self.pad_output = pad_output
+ self._device = device
+ if not pad_output and layout is None:
+ layout = torch.strided
+ self.layout = layout
padding_value = None
- self.chat_template_name = chat_template_name
- self.chat_template = chat_template
+
+        # Set the input key based on input_mode and generate; auto-determined when input_key is not provided
+ if input_mode == "history":
+ if generate:
+ self.in_keys = [
+ ("history", "prompt") if input_key is None else input_key
+ ]
+ else:
+ self.in_keys = [("history", "full") if input_key is None else input_key]
+ elif input_mode == "text":
+ if generate:
+ self.in_keys = [("text", "prompt") if input_key is None else input_key]
+ else:
+ self.in_keys = [("text", "full") if input_key is None else input_key]
+ elif input_mode == "tokens":
+ if generate:
+ self.in_keys = [
+ ("tokens", "prompt") if input_key is None else input_key
+ ]
+ else:
+ self.in_keys = [("tokens", "full") if input_key is None else input_key]
+ self.input_key = self.in_keys[0]
+
+ # Set output keys based on auto-determined return flags
+ self.out_keys = []
+ if self.return_text:
+ self.out_keys.append(self.text_key)
+ if self.return_masks:
+ self.out_keys.append(self.masks_key)
+ if self.return_tokens:
+ self.out_keys.append(self.tokens_key)
+ if self.return_log_probs:
+ self.out_keys.append(self.log_probs_key)
+ if self.return_history:
+ self.out_keys.append(self.history_key)
+
+ # Tokenizer setup
if not tokenizer_kwargs:
tokenizer_kwargs = {}
+ else:
+ tokenizer_kwargs = dict(tokenizer_kwargs)
if not tokenizer_kwargs.setdefault("return_attention_mask", True):
- raise RuntimeError
-
- # If we don't pad, we use lists
- if not self.pad_output:
- raise NotImplementedError("transformers requires `pad_output=True`.")
- if tokenizer_kwargs.setdefault("return_tensors", "pt") != "pt":
- raise RuntimeError
- if tokenizer_kwargs.setdefault("padding", self.pad_output) not in (
- self.pad_output,
- ):
- raise RuntimeError
+ raise RuntimeError("return_attention_mask must be True")
+
+ # We always pad, so we always return tensors
+ return_tensors = "pt"
+ tokenizer_kwargs.setdefault("padding", True)
+ if return_tensors:
+ if (
+ tokenizer_kwargs.setdefault("return_tensors", return_tensors)
+ != return_tensors
+ ):
+                raise RuntimeError("tokenizer_kwargs['return_tensors'] must be 'pt'.")
+
+        # Left padding is required for decoder-only generation
if tokenizer_kwargs.setdefault("padding_side", "left") != "left":
raise RuntimeError
self.tokenizer_kwargs = tokenizer_kwargs
- if (pad_output or (from_text and not generate)) and tokenizer is None:
- # We need a tokenizer if we pad or when using text inputs with generate=False
- # The latter case is due to the fact that we want the log-probs for the response only,
- # but if the response is presented as a text we have to tokenize the whole prompt + response and
- # identify where the prompt ends and where the response starts.
+
+ # Get tokenizer if needed
+ if (
+ pad_output or (input_mode in ["text", "history"] and not generate)
+ ) and tokenizer is None:
tokenizer = model.get_tokenizer()
self.tokenizer = tokenizer
- if tokenizer is not None and (
+
+ if self.tokenizer is not None and (
not hasattr(self.tokenizer, "pad_token") or self.tokenizer.pad_token is None
):
self.tokenizer.pad_token = self.tokenizer.eos_token
@@ -185,40 +295,173 @@ def __init__(
padding_value = self.tokenizer(self.tokenizer.pad_token)["input_ids"][0]
self.padding_value = padding_value
+ # Generate kwargs setup
if generate_kwargs is None:
generate_kwargs = {}
else:
generate_kwargs = dict(generate_kwargs)
- if not generate:
- # TODO
- if return_log_probs in (None, True):
- return_log_probs = True
- else:
+ self.num_samples = num_samples
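+        # Multi-sample generation adds a sample dimension to the output, so in-place updates are disallowed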
+ if (
+ generate_kwargs.get("num_return_sequences", 1) > 1
+ or num_samples is not None
+ ):
+ if inplace in (True, "empty"):
raise ValueError(
- "return_log_probs must be True or None when generate=False."
+ "inplace must be False (or None) when generating more than one sample."
)
- elif return_log_probs in (None, False):
- return_log_probs = False
- self.return_log_probs = return_log_probs
+ if inplace is None:
+ inplace = False
+ if (
+ generate_kwargs.get("num_return_sequences", 1) > 1
+ and num_samples is not None
+ and generate_kwargs.get("num_return_sequences", 1) != num_samples
+ ):
+            raise ValueError(
+                "num_samples differs from generate_kwargs['num_return_sequences']."
+            )
+ elif num_samples is None:
+ self.num_samples = generate_kwargs.get("num_return_sequences", 1)
+ generate_kwargs["num_return_sequences"] = self.num_samples
+ elif inplace is None:
+ inplace = True
+
+ self.inplace = inplace
+
+ if not generate:
+            # We only want the log-probs: generate a single token (discarded afterwards)
+            # and retrieve the log-probs of the prompt
+ generate_kwargs["max_tokens"] = 1
generate_kwargs.setdefault("tokenizer", self.tokenizer)
generate_kwargs.setdefault("output_logits", self.return_log_probs)
generate_kwargs.setdefault("return_dict_in_generate", True)
- if not generate:
- generate_kwargs.setdefault("return_dict_in_generate", True)
self.generate_kwargs = generate_kwargs
- if from_text:
- self.in_keys = [self.text_key]
- else:
- self.in_keys = [self.token_key, self.attention_mask_key]
- self.out_keys = [self.token_response_key]
- if from_text:
- self.out_keys += [self.text_response_key, self.token_key]
- if self.return_log_probs:
- self.out_keys += [self.log_prob_key, "logits"]
+ # Additional transformers-specific settings
+ self.chat_template_name = chat_template_name
+ self.chat_template = chat_template
+
+ # Flag to track when we're in a get_dist call
+ self._in_get_dist_call = False
+
+ def get_new_version(self, **kwargs):
+ """Returns a new version of the module with altered parameters.
+
+        For instance, the `generate` parameter can be toggled to switch between text generation and
+        log-probability computation. This avoids re-initializing the module when the same parameters can be
+        reused to gather log-probs.
+
+ Positional arguments are not supported.
+
+ See the class constructor for more details about the parameters.
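+
+        Example:
+            A minimal sketch, assuming `wrapper` was built with `generate=True`:
+
+            >>> # Reuse the same weights to compute log-probabilities instead of generating
+            >>> lp_wrapper = wrapper.get_new_version(generate=False)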
+ """
+ # Build the constructor arguments by using current values for missing parameters
+ constructor_kwargs = {}
+
+ # Model is always required
+ constructor_kwargs["model"] = kwargs.get("model", self.model)
+
+ # Check for each parameter and use current value if not provided
+ if "tokenizer" in kwargs:
+ constructor_kwargs["tokenizer"] = kwargs["tokenizer"]
+ elif hasattr(self, "tokenizer"):
+ constructor_kwargs["tokenizer"] = self.tokenizer
+
+ if "input_mode" in kwargs:
+ constructor_kwargs["input_mode"] = kwargs["input_mode"]
+ elif hasattr(self, "input_mode"):
+ constructor_kwargs["input_mode"] = self.input_mode
+
+ if "input_key" in kwargs:
+ constructor_kwargs["input_key"] = kwargs["input_key"]
+ elif hasattr(self, "input_key"):
+ constructor_kwargs["input_key"] = self.input_key
+
+ if "attention_mask_key" in kwargs:
+ constructor_kwargs["attention_mask_key"] = kwargs["attention_mask_key"]
+ elif hasattr(self, "attention_mask_key"):
+ constructor_kwargs["attention_mask_key"] = self.attention_mask_key
+
+ if "generate" in kwargs:
+ constructor_kwargs["generate"] = kwargs["generate"]
+ elif hasattr(self, "generate"):
+ constructor_kwargs["generate"] = self.generate
+
+ if "generate_kwargs" in kwargs:
+ constructor_kwargs["generate_kwargs"] = kwargs["generate_kwargs"]
+ elif hasattr(self, "generate_kwargs"):
+ constructor_kwargs["generate_kwargs"] = self.generate_kwargs
+
+ if "pad_output" in kwargs:
+ constructor_kwargs["pad_output"] = kwargs["pad_output"]
+ elif hasattr(self, "pad_output"):
+ constructor_kwargs["pad_output"] = self.pad_output
+
+ if "tokenizer_kwargs" in kwargs:
+ constructor_kwargs["tokenizer_kwargs"] = kwargs["tokenizer_kwargs"]
+ elif hasattr(self, "tokenizer_kwargs"):
+ constructor_kwargs["tokenizer_kwargs"] = self.tokenizer_kwargs
+ if (
+ "pad_output" in kwargs
+ and kwargs.get("pad_output")
+ != constructor_kwargs["tokenizer_kwargs"]["padding"]
+ ):
+ constructor_kwargs["tokenizer_kwargs"]["padding"] = kwargs.get(
+ "pad_output"
+ )
+
+ if "inplace" in kwargs:
+ constructor_kwargs["inplace"] = kwargs["inplace"]
+ elif hasattr(self, "inplace"):
+ constructor_kwargs["inplace"] = self.inplace
+
+ if "device" in kwargs:
+ constructor_kwargs["device"] = kwargs["device"]
+ elif hasattr(self, "_device"):
+ constructor_kwargs["device"] = self._device
+
+ if "layout" in kwargs:
+ constructor_kwargs["layout"] = kwargs["layout"]
+ elif hasattr(self, "layout"):
+ constructor_kwargs["layout"] = self.layout
+
+ if "num_samples" in kwargs:
+ constructor_kwargs["num_samples"] = kwargs["num_samples"]
+ elif hasattr(self, "num_samples"):
+ constructor_kwargs["num_samples"] = self.num_samples
+
+ if "chat_template_name" in kwargs:
+ constructor_kwargs["chat_template_name"] = kwargs["chat_template_name"]
+ elif hasattr(self, "chat_template_name"):
+ constructor_kwargs["chat_template_name"] = self.chat_template_name
+
+ if "chat_template" in kwargs:
+ constructor_kwargs["chat_template"] = kwargs["chat_template"]
+ elif hasattr(self, "chat_template"):
+ constructor_kwargs["chat_template"] = self.chat_template
+
+ if "text_key" in kwargs:
+ constructor_kwargs["text_key"] = kwargs["text_key"]
+ elif hasattr(self, "text_key"):
+ constructor_kwargs["text_key"] = self.text_key
+
+ if "tokens_key" in kwargs:
+ constructor_kwargs["tokens_key"] = kwargs["tokens_key"]
+ elif hasattr(self, "tokens_key"):
+ constructor_kwargs["tokens_key"] = self.tokens_key
+
+ if "masks_key" in kwargs:
+ constructor_kwargs["masks_key"] = kwargs["masks_key"]
+ elif hasattr(self, "masks_key"):
+ constructor_kwargs["masks_key"] = self.masks_key
+
+ if "log_probs_key" in kwargs:
+ constructor_kwargs["log_probs_key"] = kwargs["log_probs_key"]
+ elif hasattr(self, "log_probs_key"):
+ constructor_kwargs["log_probs_key"] = self.log_probs_key
+
+ # Create and return new instance
+ return type(self)(**constructor_kwargs)
@set_list_to_stack(True)
def forward(
@@ -251,223 +494,645 @@ def forward(
else:
cfg = None
- out = LazyStackedTensorDict(
- *[
+ if self.num_samples is not None:
+ out = (
TensorDict(
- device=tensordict.device, batch_size=tensordict.batch_size[1:]
+ device=tensordict.device,
+ batch_size=(
+ tensordict.batch_size[0],
+ self.num_samples,
+ *tensordict.batch_size[1:],
+ ),
)
- for _ in range(tensordict.shape[0])
- ]
- )
- if self.from_text:
+ .to_lazystack(1)
+ .to_lazystack(0)
+ )
+ else:
+ out = TensorDict(
+ device=tensordict.device, batch_size=tensordict.batch_size
+ ).to_lazystack(0)
+
+ if self.input_mode == "history":
if self.generate:
- out = self._from_transformers_generate_text(
- tensordict, out=out, cfg=cfg
- )
+ out = self._from_transformers_generate_history(tensordict, cfg, out)
else:
- out = self._from_transformers_logprobs_text(
- tensordict, out=out, cfg=cfg
- )
- else:
+ out = self._from_transformers_logprobs_history(tensordict, cfg, out)
+ elif self.input_mode == "text":
if self.generate:
- out = self._from_transformers_generate_tokens(
- tensordict, out=out, cfg=cfg
- )
+ out = self._from_transformers_generate_text(tensordict, cfg, out)
else:
- out = self._from_transformers_logprobs_tokens(
- tensordict, out=out, cfg=cfg
- )
+ out = self._from_transformers_logprobs_text(tensordict, cfg, out)
+ elif self.input_mode == "tokens":
+ if self.generate:
+ out = self._from_transformers_generate_tokens(tensordict, cfg, out)
+ else:
+ out = self._from_transformers_logprobs_tokens(tensordict, cfg, out)
+
if _source_device:
out = out.to(_source_device)
if tensordict_out is None:
if self.inplace is True:
+ # The output is the input
tensordict_out = tensordict
elif self.inplace is False:
- tensordict_out = TensorDict()
+ # The output is the new structure
+ tensordict_out = out
elif self.inplace == "empty":
+ # The output is empty
tensordict_out = tensordict.empty()
- if tensordict_out is not None:
- result = tensordict_out
+ if tensordict_out is not None and tensordict_out is not out:
+ result = tensordict_out.exclude(*self.out_keys, inplace=True)
result.update(out, keys_to_update=self.out_keys)
- else:
+ elif tensordict_out is out:
+ result = out.select(*self.out_keys)
+ elif self.inplace:
result = out
keys = list(set(self.out_keys + list(tensordict.keys(True, True))))
- return tensordict.update(result, keys_to_update=keys)
+ result = tensordict.exclude(*self.out_keys, inplace=True).update(
+ result, keys_to_update=keys
+ )
+ else:
+ result = out
return result
- def _from_transformers_generate_text(self, td, out, cfg=None):
- pad_val = self.tokenizer.pad_token_id
+ def _from_transformers_generate_history(self, td, cfg, out) -> TensorDictBase:
+ """Generate text from history input."""
+ from torchrl.data.llm import History
- text = td.get(self.text_key)
- if text is None:
- # Fallback on history parsing
- history = td.get(self.history_key)
- if history is None:
- raise ValueError(
- "No text or history provided to the TransformersWrapper."
+ # Validate input
+ if self.input_key not in td:
+ raise ValueError(
+ f"Expected '{self.input_key}' key for history input mode, "
+ f"but found keys: {list(td.keys())}"
+ )
+
+ history = td.get(self.input_key)
+ if not isinstance(history, History):
+ raise TypeError(
+ f"Expected History object for '{self.input_key}', got {type(history)}"
+ )
+
+ # Apply chat template
+ tokenizer_kwargs = {}
+ if self.chat_template_name is not None:
+ tokenizer_kwargs.setdefault("chat_template_name", self.chat_template_name)
+ if self.chat_template is not None:
+ tokenizer_kwargs.setdefault("chat_template", self.chat_template)
+ tokenizer_kwargs.setdefault("add_generation_prompt", True)
+ text_prompt = history.apply_chat_template(
+ tokenizer=self.tokenizer, **tokenizer_kwargs
+ )
+ if not isinstance(text_prompt, list):
+ raise ValueError(
+ f"Expected list of text for history input, got {type(text_prompt)}"
+ )
+ tokenizer_kwargs.setdefault("return_assistant_tokens_mask", False)
+ tokenizer_kwargs.setdefault("tokenize", True)
+ tokenizer_kwargs.setdefault("padding", False)
+ tokenizer_kwargs.setdefault("return_dict", True)
+ response_struct = history.apply_chat_template(
+ tokenizer=self.tokenizer, **tokenizer_kwargs
+ )
+ tokens_prompt_padded = response_struct.get(
+ "input_ids",
+ as_padded_tensor=True,
+ padding_value=self.padding_value,
+ padding_side="left",
+ )
+ attention_mask_prompt_padded = response_struct.get(
+ "attention_mask",
+ as_padded_tensor=True,
+ padding_value=0,
+ padding_side="left",
+ )
+
+ if attention_mask_prompt_padded is None:
+ attention_mask_prompt_padded = (
+ tokens_prompt_padded != self.tokenizer.pad_token_id
+ )
+
+ result = self._generate_from_tokens(
+ tokens_prompt_padded, attention_mask_prompt_padded, cfg, out
+ )
+
+ # Generate using text path
+ if self.pad_output:
+ result[(self.tokens_key, "prompt")] = (
+ tokens_prompt_padded
+ if not self.num_samples
+ else tokens_prompt_padded.unsqueeze(1).repeat(1, self.num_samples, 1)
+ )
+ else:
+ tokens_prompt_unpadded = response_struct.get(
+ "input_ids",
+ as_nested_tensor=True,
+ )
+ if not self.num_samples:
+ result[(self.tokens_key, "prompt")] = tokens_prompt_unpadded
+ else:
+ for r in result.unbind(1):
+ r[(self.tokens_key, "prompt")] = tokens_prompt_unpadded
+
+ text_result = Text._from_tensordict(result.empty())
+ result.set(self.text_key, text_result)
+ if not self.num_samples:
+ text_result.prompt = text_prompt
+ else:
+ for r in result.unbind(1):
+ r[self.text_key, "prompt"] = text_prompt
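+        # Decode the full sequences and recover the response text as the suffix that follows the prompt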
+ with result.view(-1) as result_flat:
+ if self.pad_output:
+ tokens_full_padded = result_flat.get(
+ (self.tokens_key, "full"),
+ as_padded_tensor=True,
+ padding_side="right",
+ padding_value=self.padding_value,
)
- tokenizer_kwargs = {}
- if self.chat_template_name is not None:
- tokenizer_kwargs.setdefault(
- "chat_template_name", self.chat_template_name
+ if tokens_full_padded is None:
+ raise ValueError("tokens_full_padded is None")
+ text_full = self.tokenizer.batch_decode(
+ tokens_full_padded, skip_special_tokens=False
+ )
+ else:
+ tokens_full_unpadded = result_flat.get(
+ (self.tokens_key, "full"), as_list=True
+ )
+ if tokens_full_unpadded is None:
+ raise ValueError("tokens_full_unpadded is None")
+ text_full = self.tokenizer.batch_decode(
+ tokens_full_unpadded, skip_special_tokens=False
)
- if self.chat_template is not None:
- tokenizer_kwargs.setdefault("chat_template", self.chat_template)
- tokenizer_kwargs.setdefault("add_generation_prompt", False)
- text = history.apply_chat_template(
+ text_prompt = result_flat[self.text_key, "prompt"]
+ text_response = [
+ txt[len(prompt) :]
+ for txt, prompt in _zip_strict(text_full, text_prompt)
+ ]
+ result_flat.set((self.text_key, "full"), text_full)
+ result_flat.set((self.text_key, "response"), text_response)
+ # Now parse the full text back to a history object, and use the extra history objects
+ # as response
+ history_chat = ChatHistory._from_tensordict(result.empty())
+ if self.num_samples is None:
+ history_chat.prompt = history
+ else:
+ for h in history_chat.unbind(1):
+ h.prompt = history
+ with history_chat.view(-1) as history_chat_flat:
+ history_chat_flat.full = full_histories = History.from_text(text_full)
+ prompt_histories = history_chat_flat.prompt
+ # iterate over batch
+ h_responses = []
+ for h_full, h_prompt in _zip_strict(
+ full_histories.unbind(0), prompt_histories.unbind(0)
+ ):
+ if h_full.shape[0] <= h_prompt.shape[0]:
+                    raise RuntimeError("The full history must be longer than the prompt history.")
+ h_responses.append(h_full[h_prompt.shape[0] :])
+ history_chat_flat.response = torch.stack(h_responses)
+ result.set(self.history_key, history_chat)
+ return result
+
+ def _from_transformers_logprobs_history(self, td, cfg, out):
+ """Compute log-probs from history input."""
+ from torchrl.data.llm import History
+
+ # Validate input
+ if self.input_key not in td:
+ raise ValueError(
+ f"Expected '{self.input_key}' key for history input mode, "
+ f"but found keys: {list(td.keys())}"
+ )
+
+ history = td.get(self.input_key)
+ if not isinstance(history, History):
+ raise TypeError(
+ f"Expected History object for '{self.input_key}', got {type(history)}"
+ )
+
+ # Apply chat template
+ tokenizer_kwargs = {}
+ if self.chat_template_name is not None:
+ tokenizer_kwargs.setdefault("chat_template_name", self.chat_template_name)
+ if self.chat_template is not None:
+ tokenizer_kwargs.setdefault("chat_template", self.chat_template)
+ tokenizer_kwargs.setdefault("add_generation_prompt", False)
+ text_full = history.apply_chat_template(
+ tokenizer=self.tokenizer, **tokenizer_kwargs
+ )
+
+ tokenizer_kwargs.setdefault("return_assistant_tokens_mask", True)
+ tokenizer_kwargs.setdefault("tokenize", True)
+ tokenizer_kwargs.setdefault("padding", False)
+ tokenizer_kwargs.setdefault("return_dict", True)
+
+ with torch.device(self._device) if self._device is not None else nullcontext():
+ response_tokens = history.apply_chat_template(
tokenizer=self.tokenizer, **tokenizer_kwargs
)
- if not isinstance(text, (list, str)):
+ if not isinstance(response_tokens, TensorDictBase):
+ raise ValueError(
+ f"Expected TensorDictBase for history input, got {type(response_tokens)}"
+ )
+ result = self._logprobs_from_history_tokens(response_tokens, cfg, out)
+ text_result = Text._from_tensordict(result.empty())
+ result.set(self.text_key, text_result)
+ result[self.text_key, "full"] = text_full
+ result.set(self.history_key, ChatHistory(full=history))
+ return result
+
+ def _cat_text(self, text, response_text):
+ """Concatenate text and response text."""
+ if isinstance(text, list):
+ return [self._cat_text(t, t_) for t, t_ in _zip_strict(text, response_text)]
+ else:
+ return text + response_text
+
+ def _generate_from_text(self, text, cfg, out) -> TensorDictBase:
+ """Generate text from text input."""
+ pad_val = self.tokenizer.pad_token_id
+
+ # Convert text to list format
+ if isinstance(text, str):
+ text = [text]
+ elif not isinstance(text, list):
text = text.tolist()
- tokens_in = self.tokenizer(text, **self.tokenizer_kwargs)
+
+ tokenizer_kwargs = dict(self.tokenizer_kwargs)
+ tokenizer_kwargs.setdefault("padding", True)
+
+ with torch.device(
+ self._device
+ ) if self._device is not None else contextlib.nullcontext():
+ tokens_in = self.tokenizer(text, **tokenizer_kwargs)
if self._device is not None:
tokens_in = tokens_in.to(self._device)
- input_ids = tokens_in["input_ids"]
- attention_mask = tokens_in["attention_mask"]
+        # Map tokens_in to a TensorDict so that padding/unpadding is easy to handle downstream
+ tokens_in = dict(tokens_in)
+ for k, v in dict(tokens_in).items():
+ if isinstance(v, list):
+ if isinstance(v[0], torch.Tensor):
+ v = torch.nested.nested_tensor(v)
+ else:
+ v = torch.nested.nested_tensor([torch.tensor(t) for t in v])
+ tokens_in[k] = v
+ tokens_in = (
+ TensorDict(batch_size=tokens_in["input_ids"].size(0))
+ .to_lazystack(0)
+ .update(tokens_in)
+ )
+ tokens_prompt_padded = tokens_in.get(
+ "input_ids",
+ as_padded_tensor=True,
+ padding_side="left",
+ padding_value=pad_val,
+ )
+ attention_mask_prompt_padded = tokens_in.get(
+ "attention_mask",
+ as_padded_tensor=True,
+ padding_side="left",
+ padding_value=0,
+ )
+
if cfg is not None:
kwargs = copy(self.generate_kwargs)
kwargs["generation_config"] = cfg
else:
kwargs = self.generate_kwargs
+
tokens_out = self.model.generate(
- input_ids=input_ids, attention_mask=attention_mask, **kwargs
+ input_ids=tokens_prompt_padded,
+ attention_mask=attention_mask_prompt_padded,
+ **kwargs,
+ )
+ tokens_full_padded = tokens_out["sequences"]
+ tokens_response_padded = tokens_full_padded[
+ ..., tokens_prompt_padded.shape[-1] :
+ ]
+
+ attention_mask_response_padded = tokens_response_padded != pad_val
+ if self.num_samples:
+ attention_mask_full_padded = torch.cat(
+ [
+ attention_mask_prompt_padded.repeat_interleave(
+ self.num_samples, dim=0
+ ),
+ attention_mask_response_padded,
+ ],
+ dim=-1,
+ )
+ else:
+ attention_mask_full_padded = torch.cat(
+ [attention_mask_prompt_padded, attention_mask_response_padded], dim=-1
+ )
+ tokens_response_unpadded = _unpad_tensors(
+ tokens_response_padded, attention_mask_response_padded, as_nested=False
)
- sequences = tokens_out["sequences"]
- sequences = sequences[..., input_ids.shape[-1] :]
- mask_sequences = sequences != pad_val
- sequences = _unpad_tensors(sequences, mask_sequences, as_nested=False)
if self.return_log_probs:
+ # These are only for the new tokens, not for the prompt - to get that, we'd need to run the forward pass again
logits = torch.stack(list(tokens_out["logits"]), 1)
- logits = _unpad_tensors(logits, mask_sequences, as_nested=False)
log_probs, logits = self._log_probs_generate(
- sequences, logits, pad_val=-100
+ tokens_response_padded, logits, pad_val=-100, pad=False
)
+
response_text = self.tokenizer.batch_decode(
- sequences, skip_special_tokens=False
+ tokens_response_unpadded, skip_special_tokens=False
)
- out.set(self.token_response_key, sequences)
- out.set(
- self.token_key, _unpad_tensors(input_ids, attention_mask, as_nested=False)
- )
- out.set(self.text_response_key, list(response_text))
- out.set(
- self.attention_mask_key,
- _unpad_tensors(attention_mask, attention_mask, as_nested=False),
- )
- if self.return_log_probs:
- out.set(
- self.log_prob_key,
- _unpad_tensors(log_probs, mask_sequences, as_nested=False),
+
+ # Build output TensorClass objects
+ if self.num_samples is not None:
+ text = [txt for txt in text for _ in range(self.num_samples)]
+ text_obj = Text._from_tensordict(out.empty())
+ with text_obj.view(-1) as text_obj_flat:
+ text_obj_flat.prompt = text
+ text_obj_flat.response = response_text
+ text_obj_flat.full = self._cat_text(text, response_text)
+ out.set(self.text_key, text_obj)
+
+ tokens_obj = Tokens._from_tensordict(out.empty())
+ if self.pad_output:
+ prompt = tokens_prompt_padded
+ else:
+ prompt = _unpad_tensors(
+ tokens_prompt_padded, attention_mask_prompt_padded, as_nested=False
+ )
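+        # With multiple samples, the same prompt is written in every sample slot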
+ if tokens_obj.ndim == 2:
+ for i in range(self.num_samples):
+ tokens_obj[:, i].prompt = prompt
+ else:
+ tokens_obj.prompt = prompt
+ with tokens_obj.view(-1) as tokens_obj_flat:
+ if not self.pad_output:
+ tokens_obj_flat.response = tokens_response_unpadded
+ tokens_full_unpadded = _unpad_tensors(
+ tokens_full_padded, attention_mask_full_padded, as_nested=False
+ )
+ tokens_obj_flat.full = tokens_full_unpadded
+ else:
+ tokens_obj_flat.response = tokens_response_padded
+ tokens_obj_flat.full = tokens_full_padded
+ tokens_obj.padded = MetaData(self.pad_output)
+ out.set(self.tokens_key, tokens_obj)
+
+ masks_obj = Masks._from_tensordict(out.empty())
+ if out.ndim == 2:
+ attention_mask_full_padded = attention_mask_full_padded.unflatten(
+ 0, (-1, self.num_samples)
)
- out.set("logits", _unpad_tensors(logits, mask_sequences, as_nested=False))
+ if self.pad_output:
+ masks_obj.all_attention_mask = attention_mask_full_padded.bool()
+ else:
+ if out.ndim == 2:
+ with tokens_obj.view(-1) as tokens_obj_flat, masks_obj.view(
+ -1
+ ) as masks_obj_flat:
+ attention_mask_full_unpadded = attention_mask_full_padded.flatten(
+ 0, 1
+ )
+ attention_mask_full_unpadded = _unpad_tensors(
+ attention_mask_full_unpadded.bool(),
+ attention_mask_full_padded.flatten(0, 1),
+ as_nested=False,
+ )
+ masks_obj_flat.all_attention_mask = attention_mask_full_unpadded
+ else:
+ attention_mask_full_unpadded = _unpad_tensors(
+ attention_mask_full_padded.bool(),
+ attention_mask_full_padded,
+ as_nested=False,
+ )
+ masks_obj.all_attention_mask = attention_mask_full_unpadded
+ masks_obj.all_assistant_mask = None
+ masks_obj.padded = MetaData(self.pad_output)
+ out.set(self.masks_key, masks_obj)
+
+ if self.return_log_probs:
+ log_probs_obj = LogProbs._from_tensordict(out.empty())
+ with log_probs_obj.view(-1) as log_probs_obj_flat:
+ # Unfortunate but we only have the log-probs for the new tokens, not for the prompt - to get that, we'd need to run the forward pass again
+ if self.pad_output:
+ log_probs_obj_flat.prompt = None
+ log_probs_obj_flat.response = log_probs
+ log_probs_obj_flat.full = None
+ else:
+ log_probs_unpadded = _unpad_tensors(
+ log_probs, attention_mask_response_padded, as_nested=False
+ )
+ log_probs_obj_flat.prompt = None
+ log_probs_obj_flat.response = log_probs_unpadded
+ log_probs_obj_flat.full = None
+ log_probs_obj.padded = MetaData(self.pad_output)
+ out.set(self.log_probs_key, log_probs_obj)
+
+ # Add logits to output if we're in a get_dist call
+ if self._in_get_dist_call:
+ if self.pad_output:
+ out.set("logits", logits)
+ else:
+ logits_full_unpadded = _unpad_tensors(
+ logits, attention_mask_full_padded, as_nested=False
+ )
+ out.set("logits", logits_full_unpadded)
+
return out
- def _from_transformers_generate_tokens(self, td, out, cfg=None):
+ def _cat_tensors(
+ self,
+ tokens: torch.Tensor | list[torch.Tensor],
+ response_tokens: torch.Tensor | list[torch.Tensor],
+ cast: torch.dtype | None = None,
+ ):
+ """Concatenate tokens and response tokens."""
+ if isinstance(tokens, list) or isinstance(response_tokens, list):
+ return [
+ self._cat_tensors(t, t_, cast=cast)
+ for t, t_ in _zip_strict(tokens, response_tokens)
+ ]
+ else:
+ result = torch.cat([tokens, response_tokens], dim=-1)
+ if cast is not None:
+ result = result.to(cast)
+ return result
+
+ def _logprobs_from_history_tokens(self, response_tokens, cfg, out):
+ """Compute log-probs from history tokens."""
pad_val = self.tokenizer.pad_token_id
- input_ids = td.get(
- self.token_key,
+ # unfortunately HF wants us to use padded tensors
+ tokens_full_padded = response_tokens.get(
+ "input_ids",
as_padded_tensor=True,
padding_side="left",
padding_value=pad_val,
)
- attention_mask = td.get(
- self.attention_mask_key,
+ if not isinstance(tokens_full_padded, torch.Tensor):
+ raise ValueError(
+ f"Expected Tensor for tokens_full_padded, got {type(tokens_full_padded)}"
+ )
+ attention_mask_full_padded = response_tokens.get(
+ "attention_mask",
as_padded_tensor=True,
padding_side="left",
padding_value=0,
)
- if attention_mask is None:
- attention_mask = (input_ids != pad_val).to(torch.int64)
+ if not isinstance(attention_mask_full_padded, torch.Tensor):
+ raise ValueError(
+ f"Expected Tensor for attention_mask_full_padded, got {type(attention_mask_full_padded)}"
+ )
+
if cfg is not None:
kwargs = copy(self.generate_kwargs)
kwargs["generation_config"] = cfg
else:
kwargs = self.generate_kwargs
- tokens_out = self.model.generate(
- input_ids=input_ids, attention_mask=attention_mask, **kwargs
+
+ tokens_out_struct = self.model(
+ tokens_full_padded, attention_mask=attention_mask_full_padded, **kwargs
)
- sequences = tokens_out["sequences"]
- sequences = sequences[:, input_ids.shape[-1] :]
- mask_sequences = sequences != pad_val
- sequences = _unpad_tensors(sequences, mask_sequences, as_nested=False)
- if self.return_log_probs:
- logits = tokens_out["logits"]
- logits = torch.stack(list(logits), 1)
- logits = _unpad_tensors(logits, mask_sequences, as_nested=False)
- log_probs, logits = self._log_probs_generate(
- sequences, logits, pad_val=pad_val
- )
- out.set(
- self.token_response_key,
- sequences,
+ (
+ log_probs_full_padded,
+ logits_full_padded,
+ ) = self._compute_log_probs_from_model_output(
+ tokens_out_struct,
+ tokens_full_padded,
+ attention_mask_full_padded,
+ pad_val,
)
- out.set(
- self.token_key, _unpad_tensors(input_ids, attention_mask, as_nested=False)
+
+ # Build output TensorClass objects
+ text_obj = Text._from_tensordict(
+ TensorDict(batch_size=out.batch_size).to_lazystack(0)
)
- out.set(
- self.attention_mask_key,
- _unpad_tensors(attention_mask, attention_mask, as_nested=False),
+ text_obj.prompt = None
+ text_obj.response = None
+ text_obj.full = None
+ out.set(self.text_key, text_obj)
+
+ all_assistant_mask_padded = response_tokens.get(
+ "assistant_masks",
+ as_padded_tensor=True,
+ padding_side="left",
+ padding_value=0,
)
- if self.return_log_probs:
- out.set(
- self.log_prob_key,
- _unpad_tensors(log_probs, mask_sequences, as_nested=False),
+ if all_assistant_mask_padded is not None:
+ all_assistant_mask_padded = all_assistant_mask_padded.bool()
+ masks_obj = Masks._from_tensordict(
+ TensorDict(batch_size=out.batch_size).to_lazystack(0)
+ )
+ if self.pad_output:
+ masks_obj.all_attention_mask = attention_mask_full_padded.bool()
+ if all_assistant_mask_padded is not None:
+ masks_obj.all_assistant_mask = all_assistant_mask_padded
+ else:
+ masks_obj.all_attention_mask = _unpad_tensors(
+ attention_mask_full_padded.bool(),
+ attention_mask_full_padded,
+ as_nested=False,
)
- out.set("logits", _unpad_tensors(logits, mask_sequences, as_nested=False))
- return out
+ if all_assistant_mask_padded is not None:
+ masks_obj.all_assistant_mask = _unpad_tensors(
+ all_assistant_mask_padded,
+ attention_mask_full_padded,
+ as_nested=False,
+ )
+ masks_obj.padded = MetaData(self.pad_output)
+ out.set(self.masks_key, masks_obj)
- def _from_transformers_logprobs_text(self, td, out, cfg=None):
- pad_val = self.tokenizer.pad_token_id
+ tokens_obj = Tokens._from_tensordict(
+ TensorDict(batch_size=out.batch_size).to_lazystack(0)
+ )
+ if self.pad_output:
+ tokens_obj.full = tokens_full_padded
+ else:
+ input_ids_full_unpadded = _unpad_tensors(
+ tokens_full_padded, attention_mask_full_padded, as_nested=False
+ )
+ tokens_obj.full = input_ids_full_unpadded
+ tokens_obj.response = None
+ tokens_obj.padded = MetaData(self.pad_output)
+ out.set(self.tokens_key, tokens_obj)
- prompt_txt = td.get(self.text_key)
- response_txt = td.get(self.text_response_key)
- if prompt_txt is None or response_txt is None:
- if prompt_txt is not None and response_txt is not None:
- raise ValueError(
- "No text or history provided to the TransformersWrapper. Either both are provided or none of them."
- )
- # Fallback on history parsing
- history = td.get(self.history_key)
- if history is None:
- raise ValueError(
- "No text or history provided to the TransformersWrapper."
- )
- tokenizer_kwargs = {}
- if self.chat_template_name is not None:
- tokenizer_kwargs.setdefault(
- "chat_template_name", self.chat_template_name
- )
- if self.chat_template is not None:
- tokenizer_kwargs.setdefault("chat_template", self.chat_template)
- tokenizer_kwargs.setdefault("add_generation_prompt", False)
- response_txt = history.apply_chat_template(
- tokenizer=self.tokenizer, **tokenizer_kwargs
+ log_probs_obj = LogProbs._from_tensordict(
+ TensorDict(batch_size=out.batch_size).to_lazystack(0)
+ )
+ if self.pad_output:
+ log_probs_obj.full = log_probs_full_padded
+ else:
+ log_probs_full_unpadded = _unpad_tensors(
+ log_probs_full_padded, attention_mask_full_padded, as_nested=False
)
- if isinstance(response_txt, list):
- prompt_txt = ["" for _ in response_txt]
+ log_probs_obj.full = log_probs_full_unpadded
+ log_probs_obj.response = None
+ log_probs_obj.padded = MetaData(self.pad_output)
+ out.set(self.log_probs_key, log_probs_obj)
+
+ # Add logits to output if we're in a get_dist call
+ if self._in_get_dist_call:
+ if self.pad_output:
+ out.set("logits", logits_full_padded)
else:
- prompt_txt = ""
-
- if not isinstance(prompt_txt, (list, str)):
- prompt_txt = prompt_txt.tolist()
- if not isinstance(response_txt, (list, str)):
- response_txt = response_txt.tolist()
- total_txt = [x + y for x, y in _zip_strict(prompt_txt, response_txt)]
- total_tokens_in = self.tokenizer(total_txt, **self.tokenizer_kwargs)
- prompt_tokens_in = self.tokenizer(prompt_txt, **self.tokenizer_kwargs)
- if self._device is not None:
- total_tokens_in = total_tokens_in.to(self._device)
- prompt_tokens_in = prompt_tokens_in.to(self._device)
+ logits_full_unpadded = _unpad_tensors(
+ logits_full_padded, attention_mask_full_padded, as_nested=False
+ )
+ out.set("logits", logits_full_unpadded)
+
+ return out
+
+ def _from_transformers_generate_text(self, td, cfg, out) -> TensorDictBase:
+ """Generate text from text input."""
+ # Validate input
+ if self.input_key not in td:
+ raise ValueError(
+ f"Expected '{self.input_key}' key for text input mode, "
+ f"but found keys: {list(td.keys())}"
+ )
- total_input_ids = total_tokens_in["input_ids"]
- total_attention_mask = total_tokens_in["attention_mask"]
- prompt_input_ids = prompt_tokens_in["input_ids"]
- prompt_attention_mask = prompt_tokens_in["attention_mask"]
+ text = td.get(self.input_key)
+ if text is None:
+ raise ValueError(f"Expected '{self.input_key}' key for text input mode")
+ if isinstance(text, NonTensorStack):
+ text = text.tolist()
+ if not isinstance(text, list):
+ raise ValueError(f"Expected list of text for text input, got {type(text)}")
+ return self._generate_from_text(text, cfg, out)
+
+ def _from_transformers_logprobs_text(self, td, cfg, out):
+ """Compute log-probs from text input."""
+ # Validate input
+ if self.input_key not in td:
+ raise ValueError(
+ f"Expected '{self.input_key}' key for text input mode, "
+ f"but found keys: {list(td.keys())}"
+ )
+
+ text = td.get(self.input_key)
+ if isinstance(text, NonTensorStack):
+ text = text.tolist()
+ if text is None:
+ raise ValueError(f"Expected '{self.input_key}' key for text input mode")
+ if not isinstance(text, list):
+ raise ValueError(f"Expected list of text for text input, got {type(text)}")
+        if self.tokenizer is None:
+            raise ValueError(
+                "Tokenizer is required for log-probs computation with text input"
+            )
+
+        # Tokenize the text (guaranteed to be a list of strings at this point)
+ tokenizer_kwargs = dict(self.tokenizer_kwargs)
+ with torch.device(
+ self._device
+ ) if self._device is not None else contextlib.nullcontext():
+ tokens_in = self.tokenizer(text, **tokenizer_kwargs)
if cfg is not None:
kwargs = copy(self.generate_kwargs)
@@ -475,138 +1140,425 @@ def _from_transformers_logprobs_text(self, td, out, cfg=None):
else:
kwargs = self.generate_kwargs
- total_tokens_out = self.model(
- total_input_ids, attention_mask=total_attention_mask, **kwargs
+        # Map tokens_in to a TensorDict so that padding/unpadding is easy to handle downstream
+ tokens_in = (
+ TensorDict(batch_size=len(tokens_in["input_ids"]))
+ .to_lazystack(0)
+ .update(dict(tokens_in))
+ )
+ input_ids_full_padded = tokens_in.get(
+ "input_ids",
+ as_padded_tensor=True,
+ padding_side="left",
+ padding_value=self.padding_value,
+ )
+ attention_mask_full_padded = tokens_in.get(
+ "attention_mask",
+ as_padded_tensor=True,
+ padding_side="left",
+ padding_value=0,
+ )
+
+ tokens_out_struct = self.model(
+ input_ids_full_padded, attention_mask=attention_mask_full_padded, **kwargs
)
- total_input_ids = _unpad_tensors(
- total_input_ids, total_attention_mask, as_nested=False
+ # Compute log-probs for the input tokens
+ (
+ log_probs_full_padded,
+ logits_full_padded,
+ ) = self._compute_log_probs_from_model_output(
+ tokens_out_struct,
+ input_ids_full_padded,
+ attention_mask_full_padded,
+ self.tokenizer.pad_token_id,
)
- prompt_input_ids = _unpad_tensors(
- prompt_input_ids, prompt_attention_mask, as_nested=False
+
+ # Build output TensorClass objects
+ text_obj = Text._from_tensordict(
+ TensorDict(batch_size=out.batch_size).to_lazystack(0)
)
- sequences = [
- _total_input_ids[_prompt_input_ids.shape[-1] :]
- if _prompt_input_ids.shape[-1] > 0
- else _total_input_ids
- for _total_input_ids, _prompt_input_ids in zip(
- total_input_ids, prompt_input_ids
+ text_obj.prompt = None
+ text_obj.response = None
+ text_obj.full = text
+ out.set(self.text_key, text_obj)
+
+ tokens_obj = Tokens._from_tensordict(
+ TensorDict(batch_size=out.batch_size).to_lazystack(0)
+ )
+ if self.pad_output:
+ tokens_obj.full = input_ids_full_padded
+ else:
+ input_ids_full_unpadded = _unpad_tensors(
+ input_ids_full_padded, attention_mask_full_padded, as_nested=False
)
- ]
- # response_attention_mask = total_attention_mask[
- # :, prompt_attention_mask.shape[-1] :
- # ]
- log_probs, logits = self._log_probs_from_logits(
- total_tokens_out, sequences, pad_val=pad_val
+ tokens_obj.full = input_ids_full_unpadded
+ tokens_obj.response = None
+ tokens_obj.padded = MetaData(self.pad_output)
+ out.set(self.tokens_key, tokens_obj)
+
+ masks_obj = Masks._from_tensordict(
+ TensorDict(batch_size=out.batch_size).to_lazystack(0)
+ )
+ if self.pad_output:
+ masks_obj.all_attention_mask = attention_mask_full_padded.bool()
+ masks_obj.all_assistant_mask = td.get(("masks", "all_assistant_mask"))
+ else:
+ attention_mask_full_unpadded = _unpad_tensors(
+ attention_mask_full_padded.bool(),
+ attention_mask_full_padded,
+ as_nested=False,
+ )
+ masks_obj.all_attention_mask = attention_mask_full_unpadded
+ masks_obj.all_assistant_mask = td.get(
+ ("masks", "all_assistant_mask"), as_list=True
+ )
+ masks_obj.padded = MetaData(self.pad_output)
+ out.set(self.masks_key, masks_obj)
+
+ log_probs_obj = LogProbs._from_tensordict(
+ TensorDict(batch_size=out.batch_size).to_lazystack(0)
)
+ if self.pad_output:
+ log_probs_obj.full = log_probs_full_padded
+ else:
+ log_probs_full_unpadded = _unpad_tensors(
+ log_probs_full_padded, attention_mask_full_padded, as_nested=False
+ )
+ log_probs_obj.full = log_probs_full_unpadded
+ log_probs_obj.response = None
+ log_probs_obj.padded = MetaData(self.pad_output)
+ out.set(self.log_probs_key, log_probs_obj)
+
+ # Add logits to output if we're in a get_dist call
+ if self._in_get_dist_call:
+ if self.pad_output:
+ out.set("logits", logits_full_padded)
+ else:
+ logits_full_unpadded = _unpad_tensors(
+ logits_full_padded, attention_mask_full_padded, as_nested=False
+ )
+ out.set("logits", logits_full_unpadded)
- out.set("logits", logits)
- out.set(self.log_prob_key, log_probs)
- out.set(self.token_response_key, sequences)
return out
- def _from_transformers_logprobs_tokens(self, td, out, cfg=None):
+ def _from_transformers_generate_tokens(
+ self, td: TensorDictBase, cfg: dict | None, out: TensorDictBase
+ ) -> TensorDictBase:
+ """Generate text from tokens input."""
+ # Validate input
+ if self.input_key not in td:
+ raise ValueError(
+ f"Expected '{self.input_key}' key for tokens input mode, "
+ f"but found keys: {list(td.keys())}"
+ )
+
pad_val = self.tokenizer.pad_token_id
- prompt_input_ids = td.get(
- self.token_key,
- as_list=True,
+ input_ids_prompt_padded = td.get(
+ self.input_key,
+ as_padded_tensor=True,
+ padding_side="left",
+ padding_value=pad_val,
)
- response_input_ids = td.get(
- self.token_response_key,
- as_list=True,
+ attention_mask_prompt_padded = td.get(
+ ("masks", "all_attention_mask"),
+ as_padded_tensor=True,
+ padding_side="left",
+ padding_value=False,
)
- # prompt_attention_mask = td.get(
- # self.attention_mask_key,
- # as_list=True,
- # )
-
- total_input_ids = [
- torch.cat([_prompt_input_ids, _response_input_ids], -1)
- for _prompt_input_ids, _response_input_ids in _zip_strict(
- prompt_input_ids, response_input_ids
+ if attention_mask_prompt_padded is None:
+ attention_mask_prompt_padded = td.get(
+ self.attention_mask_key,
+ as_padded_tensor=True,
+ padding_side="left",
+ padding_value=False,
)
- ]
- total_input_ids = pad_sequence(
- total_input_ids,
- padding_value=pad_val,
- padding_side="left",
- batch_first=True,
+ if attention_mask_prompt_padded is None:
+ attention_mask_prompt_padded = input_ids_prompt_padded != pad_val
+ return self._generate_from_tokens(
+ input_ids_prompt_padded, attention_mask_prompt_padded, cfg, out
)
- total_attention_mask = (total_input_ids != pad_val).to(torch.int64)
-
- # if prompt_attention_mask is None:
- # prompt_attention_mask = [
- # (_prompt_input_ids != pad_val).to(torch.int64)
- # for _prompt_input_ids in prompt_input_ids
- # ]
+ def _generate_from_tokens(
+ self,
+ tokens_prompt_padded: torch.Tensor,
+ attention_mask_prompt_padded: torch.Tensor,
+ cfg: dict | None,
+ out: TensorDictBase,
+ ) -> TensorDictBase:
if cfg is not None:
kwargs = copy(self.generate_kwargs)
kwargs["generation_config"] = cfg
else:
kwargs = self.generate_kwargs
- total_tokens_out = self.model(
- total_input_ids, attention_mask=total_attention_mask, **kwargs
+ tokens_out_struct = self.model.generate(
+ input_ids=tokens_prompt_padded,
+ attention_mask=attention_mask_prompt_padded,
+ **kwargs,
)
- log_probs, logits = self._log_probs_from_logits(
- total_tokens_out, response_input_ids, pad_val=-100
+ tokens_full_padded = tokens_out_struct["sequences"]
+ tokens_response_padded = tokens_full_padded[:, tokens_prompt_padded.shape[-1] :]
+ pad_val = getattr(self.tokenizer, "pad_token_id", None)
+ if pad_val is None:
+ pad_val = self.padding_value
+        attention_mask_response_padded = tokens_response_padded != pad_val
+ attention_mask_full_padded = tokens_full_padded != pad_val
+ tokens_response_unpadded = _unpad_tensors(
+            tokens_response_padded, attention_mask_response_padded, as_nested=False
)
- # for i in range(log_probs.size(0)):
- # assert log_probs[i].shape[-1] == response_input_ids[i].shape[-1]
- out.set("logits", logits)
- out.set(self.log_prob_key, log_probs)
- return out
+ if self.return_log_probs:
+ # These are only for the new tokens, not for the prompt - to get that, we'd need to run the forward pass again
+ logits_response_padded = tokens_out_struct["logits"]
+ logits_response_padded = torch.stack(list(logits_response_padded), 1)
+ (
+ log_probs_response_padded,
+ logits_response_padded,
+ ) = self._log_probs_generate(
+ tokens_response_padded,
+ logits_response_padded,
+ pad_val=pad_val,
+ pad=False,
+ )
- @classmethod
- def _log_probs_from_logits(cls, total_tokens_out, response_input_ids, pad_val=-100):
- response_input_ids = pad_sequence(
- response_input_ids,
- padding_value=pad_val,
- batch_first=True,
- padding_side="left",
+ response_text = self.tokenizer.batch_decode(
+ tokens_response_unpadded, skip_special_tokens=False
)
- pad_mask = response_input_ids != pad_val
- logits = total_tokens_out["logits"]
- # logits = logits.log_softmax(dim=-1)
- if logits.shape[-2] != response_input_ids.shape[-1]:
- logits = logits[..., -response_input_ids.shape[-1] - 1 : -1, :]
+ # Build output TensorClass objects
+ text_obj = Text._from_tensordict(out.empty())
+ text_obj.prompt = None # We don't have text in tokens mode
+ with text_obj.view(-1) as text_obj_flat:
+ text_obj_flat.response = response_text
+        text_obj.full = None  # no prompt text in tokens mode, so the full text cannot be assembled either
+ out.set(self.text_key, text_obj)
- td = TensorDict(
- logits=logits, response_input_ids=response_input_ids
- ).auto_batch_size_()
- with td.flatten() as tdflat:
- tdflat["log_probs"] = -torch.nn.functional.cross_entropy(
- tdflat["logits"],
- tdflat["response_input_ids"],
- reduce=False,
- ignore_index=pad_val,
+ tokens_obj = Tokens._from_tensordict(out.empty())
+ if not self.pad_output:
+ input_ids_prompt_unpadded = _unpad_tensors(
+ tokens_prompt_padded,
+ attention_mask_prompt_padded,
+ as_nested=False,
)
- log_probs = td["log_probs"]
+ if self.num_samples is not None:
+ # replicate tokens
+ for i in range(self.num_samples):
+ tokens_obj[:, i].prompt = (
+ input_ids_prompt_unpadded
+ if not self.pad_output
+ else tokens_prompt_padded
+ )
+ else:
+ tokens_obj.prompt = (
+ input_ids_prompt_unpadded
+ if not self.pad_output
+ else tokens_prompt_padded
+ )
+ with tokens_obj.view(-1) as tokens_obj_flat:
+ if self.pad_output:
+ tokens_obj_flat.response = tokens_response_padded
+ tokens_obj_flat.full = tokens_full_padded
+ else:
+ tokens_obj_flat.response = tokens_response_unpadded
+ tokens_full_unpadded = _unpad_tensors(
+ tokens_full_padded, attention_mask_full_padded, as_nested=False
+ )
+ tokens_obj_flat.full = tokens_full_unpadded
+ tokens_obj.padded = MetaData(self.pad_output)
+ out.set(self.tokens_key, tokens_obj)
+
+ masks_obj = Masks._from_tensordict(out.empty())
+ if out.ndim == 2:
+ attention_mask_full_padded = attention_mask_full_padded.unflatten(
+ 0, (-1, self.num_samples)
+ )
+ if self.pad_output:
+ # Get "real" attention masks
+ masks_obj.all_attention_mask = attention_mask_full_padded.bool()
+ else:
+ # Get "real" attention masks
+ # We can use select to avoid batch-size problems
+ _td = torch.ones_like(
+ out.select(("tokens", "full"))
+ .copy()
+ .rename_key_(("tokens", "full"), "all_attention_mask")
+ ).bool()
+ del _td["tokens"]
+ masks_obj.update(_td)
+ masks_obj.all_assistant_mask = None
+ masks_obj.padded = MetaData(self.pad_output)
+ out.set(self.masks_key, masks_obj)
- # Recover the list
- log_probs = _unpad_tensors(log_probs, pad_mask)
- logits = _unpad_tensors(logits, pad_mask)
- return log_probs, logits
+ if self.return_log_probs:
+ log_probs_obj = LogProbs._from_tensordict(out.empty())
+ if self.num_samples is None:
+ if self.pad_output:
+ log_probs_obj.response = log_probs_response_padded
+ else:
+ log_probs_response_unpadded = _unpad_tensors(
+ log_probs_response_padded,
+                        attention_mask_response_padded,
+ as_nested=False,
+ )
+ log_probs_obj.response = log_probs_response_unpadded
+ else:
+ with log_probs_obj.view(-1) as log_probs_obj_flat:
+ if self.pad_output:
+ log_probs_obj_flat.response = log_probs_response_padded
+ else:
+ log_probs_response_unpadded = _unpad_tensors(
+ log_probs_response_padded,
+                            attention_mask_response_padded,
+ as_nested=False,
+ )
+ log_probs_obj_flat.response = log_probs_response_unpadded
+ log_probs_obj.padded = MetaData(self.pad_output)
+ out.set(self.log_probs_key, log_probs_obj)
- @classmethod
- def _log_probs_generate(cls, sequences, logits, pad_val=-100):
- tokens = pad_sequence(
- sequences,
- padding_value=pad_val,
- batch_first=True,
+ return out
+
+ def _from_transformers_logprobs_tokens(
+ self, td: TensorDictBase, cfg: dict | None, out: TensorDictBase
+ ) -> TensorDictBase:
+ """Compute log-probs from tokens input."""
+ # Validate input
+ if self.input_key not in td:
+ raise ValueError(
+ f"Expected '{self.input_key}' key for tokens input mode, "
+ f"but found keys: {list(td.keys(isinstance(self.input_key, tuple)))}"
+ )
+
+ pad_val = self.tokenizer.pad_token_id
+
+ input_ids_full_padded = td.get(
+ self.input_key,
+ as_padded_tensor=True,
padding_side="left",
+ padding_value=pad_val,
)
- logits = pad_sequence(
- logits,
- padding_value=0.0,
- batch_first=True,
+        # Attention mask: try the canonical ("masks", "all_attention_mask") entry first, then the key
+        # provided in the constructor, and finally fall back to a mask computed from the padding value
+ attention_mask_full_padded = td.get(
+ ("masks", "all_attention_mask"),
+ as_padded_tensor=True,
padding_side="left",
+ padding_value=False,
+ )
+ if attention_mask_full_padded is None:
+ attention_mask_full_padded = td.get(
+ self.attention_mask_key,
+ as_padded_tensor=True,
+ padding_side="left",
+ padding_value=False,
+ )
+ if attention_mask_full_padded is None:
+ attention_mask_full_padded = input_ids_full_padded != pad_val
+
+ if cfg is not None:
+ kwargs = copy(self.generate_kwargs)
+ kwargs["generation_config"] = cfg
+ else:
+ kwargs = self.generate_kwargs
+
+ tokens_out_struct = self.model(
+ input_ids_full_padded, attention_mask=attention_mask_full_padded, **kwargs
)
+ # Compute log-probs for the input tokens
+ (
+ log_probs_full_padded,
+ logits_full_padded,
+ ) = self._compute_log_probs_from_model_output(
+ tokens_out_struct,
+ input_ids_full_padded,
+ attention_mask_full_padded,
+ self.tokenizer.pad_token_id,
+ )
+
+ # Build output TensorClass objects
+ text_obj = Text._from_tensordict(
+ TensorDict(batch_size=out.batch_size).to_lazystack(0)
+ )
+ text_obj.prompt = None
+ text_obj.response = None
+ text_obj.full = None
+ out.set(self.text_key, text_obj)
+
+ tokens_obj = Tokens._from_tensordict(
+ TensorDict(batch_size=out.batch_size).to_lazystack(0)
+ )
+ if not self.pad_output:
+ input_ids_full_unpadded = _unpad_tensors(
+ input_ids_full_padded, attention_mask_full_padded, as_nested=False
+ )
+ tokens_obj.full = input_ids_full_unpadded
+ else:
+ tokens_obj.full = input_ids_full_padded
+ tokens_obj.response = None
+ tokens_obj.padded = MetaData(self.pad_output)
+ out.set(self.tokens_key, tokens_obj)
+
+ masks_obj = Masks._from_tensordict(
+ TensorDict(batch_size=out.batch_size).to_lazystack(0)
+ )
+ if self.pad_output:
+ masks_obj.all_attention_mask = attention_mask_full_padded.bool()
+ masks_obj.all_assistant_mask = td.get(("masks", "all_assistant_mask"))
+ else:
+ masks_obj.all_attention_mask = _unpad_tensors(
+ attention_mask_full_padded.bool(),
+ attention_mask_full_padded,
+ as_nested=False,
+ )
+ masks_obj.all_assistant_mask = td.get(
+ ("masks", "all_assistant_mask"), as_list=True
+ )
+
+ masks_obj.padded = MetaData(self.pad_output)
+ out.set(self.masks_key, masks_obj)
+
+ log_probs_obj = LogProbs._from_tensordict(
+ TensorDict(batch_size=out.batch_size).to_lazystack(0)
+ )
+ if self.pad_output:
+ log_probs_obj.full = log_probs_full_padded
+ else:
+ log_probs_full_unpadded = _unpad_tensors(
+ log_probs_full_padded, attention_mask_full_padded, as_nested=False
+ )
+ log_probs_obj.full = log_probs_full_unpadded
+ log_probs_obj.response = None
+ log_probs_obj.padded = MetaData(self.pad_output)
+ out.set(self.log_probs_key, log_probs_obj)
+
+ # Add logits to output if we're in a get_dist call
+ if self._in_get_dist_call:
+ if self.pad_output:
+ out.set("logits", logits_full_padded)
+ else:
+ logits_full_unpadded = _unpad_tensors(
+ logits_full_padded, attention_mask_full_padded, as_nested=False
+ )
+ out.set("logits", logits_full_unpadded)
+ return out
+
+ @classmethod
+ def _log_probs_generate(cls, tokens, logits, pad_val=-100, pad: bool = True):
+ if pad:
+ tokens = pad_sequence(
+ tokens,
+ padding_value=pad_val,
+ batch_first=True,
+ padding_side="left",
+ )
+ logits = pad_sequence(
+ logits,
+ padding_value=0.0,
+ batch_first=True,
+ padding_side="left",
+ )
+
# logits = logits.log_softmax(dim=-1)
# log_probs = logits.gather(-1, tokens.unsqueeze(-1)).squeeze(-1)
td = TensorDict(logits=logits, tokens=tokens).auto_batch_size_()
@@ -614,5 +1566,250 @@ def _log_probs_generate(cls, sequences, logits, pad_val=-100):
tdflat["log_probs"] = -torch.nn.functional.cross_entropy(
tdflat["logits"], tdflat["tokens"], reduce=False, ignore_index=pad_val
)
+ td["log_probs"][:, 0] = 0
log_probs = td["log_probs"]
return log_probs, logits
+
+ def _compute_log_probs_from_model_output(
+ self, model_output, input_ids, attention_mask, pad_val
+ ):
+ """Compute log-probs from model output without modifying original tensors.
+
+ Args:
+ model_output: Output from the model containing logits
+ input_ids: Original input token ids
+ attention_mask: Original attention mask
+ pad_val: Padding token value to ignore in loss computation
+
+ Returns:
+ tuple: (log_probs, shifted_logits) where log_probs are the computed log probabilities
+ and shifted_logits are the logits shifted to align with tokens
+ """
+ logits = model_output["logits"]
+
+ # Create shifted versions for log-prob computation without modifying originals
+ shifted_logits = logits[:, :-1, :]
+ # shifted_logits = shifted_logits - shifted_logits.logsumexp(dim=-1, keepdim=True)
+ shifted_logits = torch.cat(
+ [torch.zeros_like(shifted_logits[:, :1]), shifted_logits], 1
+ )
+
+ shifted_input_ids = input_ids[:, 1:]
+ shifted_input_ids = torch.cat(
+ [torch.zeros_like(shifted_input_ids[:, :1]), shifted_input_ids], 1
+ )
+
+ # Check that the shape is correct
+ if shifted_logits.shape[-2] != shifted_input_ids.shape[-1]:
+ raise ValueError(
+ f"The logits shape {shifted_logits.shape} does not match the input ids shape {shifted_input_ids.shape}"
+ )
+
+ # Compute log-probs
+ td = TensorDict(
+ logits=shifted_logits, tokens=shifted_input_ids
+ ).auto_batch_size_()
+ with td.flatten() as tdflat:
+ tdflat["log_probs"] = -torch.nn.functional.cross_entropy(
+ tdflat["logits"],
+ tdflat["tokens"],
+ reduce=False,
+ ignore_index=pad_val,
+ )
+ # For consistency with vllm, we set the log-probs of the first token to 0
+ # However, the first element may not be the first - we want the first of the attention mask,
+ # i.e, the first element that is true on the left
+ attention_mask = attention_mask.bool()
+ attention_mask_first_left = ~attention_mask[:, :-1] & attention_mask[:, 1:]
+ attention_mask_first_left = torch.cat(
+ [
+ torch.zeros_like(attention_mask_first_left[..., :1]),
+ attention_mask_first_left,
+ ],
+ -1,
+ )
+ attention_mask_first_left[~(attention_mask_first_left.any(-1)), 0] = True
+ assert attention_mask_first_left.any(-1).all()
+ attention_mask_first_left = attention_mask_first_left | ~attention_mask
+ td["log_probs"][attention_mask_first_left] = 0
+
+ return td["log_probs"], shifted_logits
+
+ def get_dist(
+ self,
+ tensordict: TensorDictBase,
+ tensordict_out: TensorDictBase | None = None,
+ logits_key: NestedKey = "logits",
+ mask_key: NestedKey | None = None,
+ as_padded_tensor: bool | None = None,
+ as_nested_tensor: bool | None = None,
+ padding_value: float | None = None,
+ padding_side: str = "right",
+ layout: torch.layout | None = None,
+ **kwargs,
+ ) -> D.Distribution:
+ """Get distribution from logits/log-probs with optional masking.
+
+ This method enables logits computation for distribution creation.
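+
+        Example (a minimal sketch; ``wrapper`` is assumed to be an already-built wrapper
+        and ``td`` a TensorDict holding the full token sequence):
+
+            >>> dist = wrapper.get_dist(td)
+            >>> log_probs = dist.log_prob(td["tokens", "full"])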
+ """
+ self._in_get_dist_call = True
+ self.out_keys += ["logits"]
+ try:
+ return super().get_dist(
+ tensordict,
+ tensordict_out,
+ logits_key,
+ mask_key,
+ as_padded_tensor,
+ as_nested_tensor,
+ padding_value,
+ padding_side,
+ layout,
+ **kwargs,
+ )
+ finally:
+ self._in_get_dist_call = False
+ self.out_keys.remove("logits")
+
+ def _get_dist_with_prompt_mask(
+ self,
+ tensordict: TensorDictBase,
+ tokens_key: NestedKey = ("tokens", "prompt"),
+ logits_key: NestedKey = "logits",
+ assistant_mask_key: NestedKey = ("masks", "all_assistant_mask"),
+ attention_mask_key: NestedKey = ("masks", "all_attention_mask"),
+ **kwargs,
+ ) -> D.Distribution:
+ """Get distribution masked to only include response tokens (exclude prompt).
+
+ This method enables logits computation for distribution creation.
+
+ This is a provisional method that will be replaced by the `get_dist` method once we have a better masking strategy.
+ """
+ self._in_get_dist_call = True
+ self.out_keys += ["logits"]
+ try:
+ return super()._get_dist_with_prompt_mask(
+ tensordict,
+ tokens_key,
+ logits_key,
+ assistant_mask_key,
+ attention_mask_key,
+ **kwargs,
+ )
+ finally:
+ self._in_get_dist_call = False
+ self.out_keys.remove("logits")
+
+ def _get_dist_with_assistant_mask(
+ self,
+ tensordict: TensorDictBase,
+ assistant_mask_key: NestedKey = ("masks", "all_assistant_mask"),
+ logits_key: NestedKey = "logits",
+ **kwargs,
+ ) -> D.Distribution:
+ """Get distribution masked to only include assistant tokens.
+
+ This method enables logits computation for distribution creation.
+
+ This is a provisional method that will be replaced by the `get_dist` method once we have a better masking strategy.
+ """
+ self._in_get_dist_call = True
+ self.out_keys += ["logits"]
+ try:
+ return super()._get_dist_with_assistant_mask(
+ tensordict, assistant_mask_key, logits_key, **kwargs
+ )
+ finally:
+ self._in_get_dist_call = False
+ self.out_keys.remove("logits")
+
+ def _get_dist_with_attention_mask(
+ self,
+ tensordict: TensorDictBase,
+ attention_mask_key: NestedKey = ("masks", "all_attention_mask"),
+ logits_key: NestedKey = "logits",
+ **kwargs,
+ ) -> D.Distribution:
+ """Get distribution masked using attention mask.
+
+ This method enables logits computation for distribution creation.
+
+ This is a provisional method that will be replaced by the `get_dist` method once we have a better masking strategy.
+ """
+ self._in_get_dist_call = True
+ self.out_keys += ["logits"]
+ try:
+ return super()._get_dist_with_attention_mask(
+ tensordict, attention_mask_key, logits_key, **kwargs
+ )
+ finally:
+ self._in_get_dist_call = False
+ self.out_keys.remove("logits")
+
+ def _get_dist_with_custom_mask(
+ self,
+ tensordict: TensorDictBase,
+ mask: torch.Tensor,
+ logits_key: NestedKey = "logits",
+ **kwargs,
+ ) -> D.Distribution:
+ """Get distribution with custom mask.
+
+ This method enables logits computation for distribution creation.
+ """
+ self._in_get_dist_call = True
+ self.out_keys += ["logits"]
+ try:
+ return super()._get_dist_with_custom_mask(
+ tensordict, mask, logits_key, **kwargs
+ )
+ finally:
+ self._in_get_dist_call = False
+ self.out_keys.remove("logits")
+
+ # Convenience methods for common LLM training scenarios
+ def _get_sft_dist(self, tensordict: TensorDictBase, **kwargs) -> D.Distribution:
+ """Get distribution suitable for SFT loss (response tokens only).
+
+ This method enables logits computation for distribution creation.
+
+ This is a provisional method that will be replaced by the `get_dist` method once we have a better masking strategy.
+ """
+ self._in_get_dist_call = True
+ self.out_keys += ["logits"]
+ try:
+ return super()._get_sft_dist(tensordict, **kwargs)
+ finally:
+ self._in_get_dist_call = False
+ self.out_keys.remove("logits")
+
+ def _get_rlhf_dist(self, tensordict: TensorDictBase, **kwargs) -> D.Distribution:
+ """Get distribution suitable for RLHF loss (assistant tokens only).
+
+ This method enables logits computation for distribution creation.
+
+ This is a provisional method that will be replaced by the `get_dist` method once we have a better masking strategy.
+ """
+ self._in_get_dist_call = True
+ self.out_keys += ["logits"]
+ try:
+ return super()._get_rlhf_dist(tensordict, **kwargs)
+ finally:
+ self._in_get_dist_call = False
+ self.out_keys.remove("logits")
+
+ def _get_generic_dist(self, tensordict: TensorDictBase, **kwargs) -> D.Distribution:
+ """Get distribution suitable for generic losses (all tokens).
+
+ This method enables logits computation for distribution creation.
+
+ This is a provisional method that will be replaced by the `get_dist` method once we have a better masking strategy.
+ """
+ self._in_get_dist_call = True
+ self.out_keys += ["logits"]
+ try:
+ return super()._get_generic_dist(tensordict, **kwargs)
+ finally:
+ self._in_get_dist_call = False
+ self.out_keys.remove("logits")
diff --git a/torchrl/modules/llm/policies/vllm_wrapper.py b/torchrl/modules/llm/policies/vllm_wrapper.py
index 7f3625a46fc..9133b98c3fe 100644
--- a/torchrl/modules/llm/policies/vllm_wrapper.py
+++ b/torchrl/modules/llm/policies/vllm_wrapper.py
@@ -5,164 +5,279 @@
from __future__ import annotations
import collections
-from typing import Literal
+import warnings
+from typing import Any, Literal
import torch
from tensordict import (
lazy_stack,
- LazyStackedTensorDict,
- maybe_dense_stack,
- NestedKey,
+ MetaData,
+ NonTensorStack,
+ set_list_to_stack,
TensorDict,
TensorDictBase,
)
-from tensordict.tensorclass import from_dataclass, NonTensorStack, TensorClass
-from tensordict.utils import _zip_strict, expand_as_right
+from tensordict.tensorclass import from_dataclass, TensorClass
+from tensordict.utils import _zip_strict, NestedKey
+from torch import distributions as D
+from torch.nn.utils.rnn import pad_sequence
from torchrl.envs.utils import _classproperty
+from torchrl.modules.llm.policies.common import (
+ ChatHistory,
+ LLMWrapperBase,
+ LogProbs,
+ Masks,
+ Text,
+ Tokens,
+)
+from torchrl.modules.utils.utils import _unpad_tensors
-from torchrl.modules.llm.policies.common import CategoricalSequential
-
+# Type imports
+try:
+ import transformers
+ import vllm
+ from vllm.outputs import RequestOutput
+ from vllm.sampling_params import SamplingParams
+except ImportError:
+ vllm = None
+ transformers = None
+ SamplingParams = Any # type: ignore
+ RequestOutput = Any # type: ignore
-class vLLMWrapper(CategoricalSequential):
- """A wrapper class for vLLM models, providing a consistent interface for text generation and log probability computation, similar to the Hugging Face Transformers interface.
- This class allows for handling both text and token inputs, enabling text generation and log probability
- computation based on the specified configuration.
+class vLLMWrapper(LLMWrapperBase):
+ """A wrapper class for vLLM models, providing a consistent interface for text generation and log probability computation.
- .. note:: The default arguments of the `vLLMWrapper` class are set to make it easy to run this backend with
- the :class:`~torchrl.envs.custom.llm.LLMEnv` class.
+ This class is a subclass of :class:`~torchrl.modules.llm.policies.LLMWrapperBase` and provides a unified API for handling different input
+ modalities (history, text, tokens) with consistent output structure using :class:`~tensordict.TensorClass` objects.
Args:
- model (vllm.LLM): The vLLM model to wrap.
+ model (vllm.LLM | str): The vLLM model to wrap. If a string, it will be passed to `vllm.LLM`.
Keyword Args:
- return_log_probs (bool | None, optional): Whether to return log probabilities of the generated tokens.
+ tokenizer (transformers.tokenization_utils.PreTrainedTokenizer | str | None, optional): The tokenizer to use for encoding and decoding text.
+ If `None`, the tokenizer associated with the model will be used. If a string, it will be passed to `transformers.AutoTokenizer.from_pretrained`.
Defaults to `None`.
- tokenizer (transformers.tokenization_utils.PreTrainedTokenizer | None, optional): The tokenizer to use for
- encoding and decoding text. If `None`, the tokenizer associated with the model will be used. Defaults to
- `None`.
- from_text (bool, optional): Indicates whether the input is in text format. If `True`, the input is expected to
- be text that will be tokenized. If `False`, the input is expected to be token sequences. Defaults to `True`.
-
- .. note:: If `from_text` is `True`, the input text can be provided in the `"text"` key or in the `"history"` key.
- If using the `"history"` key, the history will be parsed from a :class:`~torchrl.data.llm.History` object to a
- text string using the tokenizer.
-
- device (torch.device | None, optional): The device to use for computation. If `None`, the default device will
- be used. Defaults to `None`.
- generate (bool, optional): Whether to enable text generation. If `True`, the model will generate text based on
- the input. If `False`, only log probabilities will be computed for the response tokens/text. Defaults to `True`.
- generate_kwargs (dict | None, optional): Additional arguments to pass to the model's generate method. These
- arguments can control aspects of the generation process, such as temperature and top-k sampling. Defaults
- to `None`.
-
- .. note:: Sampling params can be overwritten at runtime using the kwargs of the forward method.
-
- tokenizer_kwargs (dict | None, optional): Additional arguments to pass to the tokenizer. These arguments can
- control aspects of the tokenization process, such as padding and truncation. Defaults to `None`.
- pad_output (bool, optional): Whether to pad the output sequences to a uniform length. If `True`, the output
- sequences will be padded and represented as tensors. If `False`, lists of tokens will be used without
- padding. Defaults to `False`.
-
- .. warning:: The default value of `pad_output` differs from :func:`~torchrl.modules.TransformersWrapper`
- which does not handle non-padded inputs.
-
- inplace (Literal[True, False, "empty"] | None, optional): Determines how the module should handle in-place
- operations. If `True`, operations will be performed in-place. If `False`, a new TensorDict instance will be
- created. If `"empty"`, the output data structure will be initialized with `input.empty()` (i.e., it will
- conserve type, batch-size, and device). Defaults to `True` when generating a single sample, `False`
- otherwise.
-
- chat_template_name (str | None, optional): The name of the chat template to use for the history. Defaults to `None`.
- chat_template (str | None, optional): The chat template to use for the history. Defaults to `None`.
-
- .. note:: The tokenizer is used when `from_text` is `True` to convert input text into token sequences. It is also
- required (or retrieved) when `pad_output` is `True` or when using text inputs with `generate=False` to ensure proper
- tokenization and padding.
+ input_mode (str, optional): The input modality to use. Must be one of `"history"`, `"text"`, or `"tokens"`. Defaults to `"history"`.
+ input_key (str | None, optional): The key for the input data. If `None`, defaults to
+ - `("history", "prompt")` for `"history"` when `generate=True`, `("history", "full")` for `"history"` when `generate=False`
+ - `("text", "prompt")` for `"text"` when `generate=True`, `("text", "full")` for `"text"` when `generate=False`
+ - `("tokens", "prompt")` for `"tokens"` when `generate=True`, `("tokens", "full")` for `"tokens"` when `generate=False`
+ attention_mask_key (str, optional): The key for attention masks (used in `"tokens"` mode). Defaults to `"attention_mask"`.
+
+ .. warning:: This argument is under development and may change in the future.
+
+ generate (bool, optional): Whether to enable text generation. If `True`, the model will generate text based on the input.
+ If `False`, only log probabilities will be computed. Defaults to `True`.
+ return_log_probs (bool, optional): Whether to return log probabilities. Defaults to `True`.
+ generate_kwargs (dict | None, optional): Additional arguments to pass to the model's generate method. Defaults to `None`.
+ tokenizer_kwargs (dict | None, optional): Additional arguments to pass to the tokenizer. Defaults to `None`.
+ pad_output (bool, optional): Whether to pad the output sequences to a uniform length. Defaults to `False`.
+ inplace (Literal[True, False, "empty"] | None, optional): Determines how the module should handle in-place operations. Defaults to `True`.
+ device (torch.device | None, optional): The device to use for computation. Defaults to `None`.
+ layout (torch.layout | None, optional): The layout to use for the output tensors when `pad_output=False`. Defaults to `torch.strided`.
+ chat_template_name (Literal["chatml_format", "qwen"] | None, optional): The name of the chat template to use when applying the chat template to the history.
+ Defaults to `None`. For `input_mode="history"` only.
+ chat_template (str | None, optional): The chat template to use when applying the chat template to the history. Defaults to `None`.
+ For `input_mode="history"` only.
+ num_samples (int | None, optional): The number of samples to generate. Defaults to `None` (one sample, and no batch-dimension for it).
+            Can also be set by passing `{"n": value}` within `generate_kwargs`.
+ log_probs_key (NestedKey | None, optional): The key for the log probabilities :class:`~torchrl.modules.llm.policies.LogProbs` object. Defaults to `"log_probs"`.
+        text_key (NestedKey | None, optional): The key for the :class:`~torchrl.modules.llm.policies.Text` object. Defaults to `"text"`.
+        tokens_key (NestedKey | None, optional): The key for the :class:`~torchrl.modules.llm.policies.Tokens` object. Defaults to `"tokens"`.
+        masks_key (NestedKey | None, optional): The key for the :class:`~torchrl.modules.llm.policies.Masks` object. Defaults to `"masks"`.
+        history_key (NestedKey | None, optional): The key for the :class:`~torchrl.modules.llm.policies.ChatHistory` object. Defaults to `"history"`.
Input Keys:
-
- - If `from_text` is `True`:
-
- - `"text"`: The input text to be tokenized.
- - `"text_response"`: the response text (if `generate=False` as the log probabilities are computed for the response.)
-
- - If `from_text` is `False`:
-
- - "tokens": The input token sequences.
- - "attention_mask": The attention mask for the tokens.
- - "tokens_response": The response token sequences (if `generate=False` as the log probabilities are
- computed for the response.)
+ The input key depends on both `input_mode` and `generate`:
+ - If `input_mode="history"` and `generate=True`: `input_key` (defaults to `("history", "prompt")`)
+ - If `input_mode="history"` and `generate=False`: `input_key` (defaults to `("history", "full")`)
+ - If `input_mode="text"` and `generate=True`: `input_key` (defaults to `("text", "prompt")`)
+ - If `input_mode="text"` and `generate=False`: `input_key` (defaults to `("text", "full")`)
+ - If `input_mode="tokens"` and `generate=True`: `input_key` (defaults to `("tokens", "prompt")`)
+ - If `input_mode="tokens"` and `generate=False`: `input_key` (defaults to `("tokens", "full")`)
Output Keys:
-
- - `"tokens_response"`: The generated token sequences.
- - `"log_probs"`: The log probabilities of the generated tokens (if `return_log_probs` is `True`).
- - `"text_response"`: The generated text (if `from_text` is `True` and `generate` is `True`).
+ The output keys are automatically determined based on the input_mode:
+ - **Tokens**: Always returned (`tokens_key`, defaults to `"tokens"`)
+ - **Text**: Returned for `"text"` and `"history"` modes (`text_key`, defaults to `"text"`)
+ - **History**: Returned only for `"history"` mode (`history_key`, defaults to `"history"`)
+ - **Masks**: Always returned (`masks_key`, defaults to `"masks"`)
+ - **Log Probs**: Returned when `return_log_probs=True` (`log_probs_key`, defaults to `"log_probs"`)
+
+    Example output structure for `input_mode="history"`::
+
+        TensorDict(
+            text=Text(prompt=..., response=..., full=...),
+            masks=Masks(all_attention_mask=..., all_assistant_mask=...),
+            tokens=Tokens(prompt=..., response=..., full=...),
+            log_probs=LogProbs(prompt=..., response=..., full=...),
+            history=ChatHistory(prompt=..., response=..., full=...)
+        )
Example:
>>> from vllm import LLM
>>> from transformers import AutoTokenizer
+ >>> from torchrl.data.llm import History
+ >>> from torchrl.modules.llm.policies import ChatHistory
+ >>>
>>> model = LLM("gpt2")
+ >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
+ >>>
+ >>> # History input (recommended for RL environments)
>>> wrapper = vLLMWrapper(
... model,
- ... from_text=True,
- ... generate=True
+ ... tokenizer=tokenizer,
+ ... input_mode="history",
+ ... generate=True,
+ ... return_log_probs=True
... )
- >>> input_data = LLMData(text=NonTensorStack("Hello, world!", "This is another text"), batch_size=1)
- >>> output_data = wrapper(input_data)
- >>> print(output_data.text_response)
-
- .. seealso:: :func:`~torchrl.modules.TransformersWrapper` for a similar interface using the Hugging Face
- Transformers library.
+ >>>
+ >>> history = History.from_chats([[
+ ... {"role": "user", "content": "Hello"},
+ ... {"role": "assistant", "content": "Hi there!"}
+ ... ]])
+ >>> chat_history = ChatHistory(prompt=history)
+ >>> result = wrapper(TensorDict(history=chat_history, batch_size=(1,)))
+ >>> print(result["text"].response) # Generated text
+ >>> print(result["log_probs"].response) # Log probabilities
+ >>> print(result["history"].response) # History with response
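+        >>>
+        >>> # Log-probabilities of an existing conversation (a sketch: same model and
+        >>> # tokenizer as above, `generate=False` switches the wrapper to scoring mode)
+        >>> lp_wrapper = vLLMWrapper(
+        ...     model,
+        ...     tokenizer=tokenizer,
+        ...     input_mode="history",
+        ...     generate=False,
+        ... )
+        >>> chat_full = ChatHistory(full=history)
+        >>> scored = lp_wrapper(TensorDict(history=chat_full, batch_size=(1,)))
+        >>> print(scored["log_probs"].full)  # Log-probs of every token in the conversation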
+
+ Attributes:
+ collector: The collector associated with the module, if it exists.
+
+ .. seealso::
+ - :class:`~torchrl.modules.llm.policies.LLMWrapperBase` (see :ref:`ref_categorical_sequential`)
+ - :class:`~torchrl.modules.llm.policies.TransformersWrapper` (see :ref:`ref_transformers_wrapper`)
"""
- text_key: NestedKey = ("text",)
- token_key: NestedKey = ("tokens",)
- token_response_key: NestedKey = ("tokens_response",)
- text_response_key: NestedKey = ("text_response",)
- attention_mask_key: NestedKey = ("attention_mask",)
- history_key: NestedKey = ("history",)
-
def __init__(
self,
- model: vllm.LLM, # noqa
- # noqa
+ model: vllm.LLM | str,
*,
- return_log_probs: bool | None = None,
- tokenizer: transformers.tokenization_utils.PreTrainedTokenizer # noqa
- | None = None,
- # noqa
- from_text: bool = True,
- device: torch.device | None = None,
+ tokenizer: callable | str | None = None, # type: ignore
+ input_mode: str = "history",
+ input_key: NestedKey | None = None,
+ attention_mask_key: str = "attention_mask",
generate: bool = True,
generate_kwargs: dict | None = None,
tokenizer_kwargs: dict | None = None,
pad_output: bool = False,
inplace: Literal[True, False, "empty"] | None = None,
- chat_template_name: str | None = None,
+ device: torch.device | None = None,
+ layout: torch.layout | None = None,
+ num_samples: int | None = None,
+ chat_template_name: Literal["chatml_format", "qwen"] | None = None,
chat_template: str | None = None,
+ return_log_probs: bool | None = None,
+ history_key: NestedKey | None = "history",
+ text_key: NestedKey | None = "text",
+ tokens_key: NestedKey | None = "tokens",
+ masks_key: NestedKey | None = "masks",
+ log_probs_key: NestedKey | None = "log_probs",
):
super().__init__()
- import vllm
+ if vllm is None:
+ raise ImportError("vllm is required for vLLMWrapper")
+ if transformers is None:
+ raise ImportError("transformers is required for vLLMWrapper")
+
+ if isinstance(model, str):
+ model = vllm.LLM(model)
+
+ if isinstance(tokenizer, str):
+ from transformers import AutoTokenizer
+
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer)
+
from vllm import SamplingParams
+ # Validate input_mode
+ if input_mode not in ["history", "text", "tokens"]:
+ raise ValueError(
+ f"input_mode must be one of 'history', 'text', 'tokens'. Got '{input_mode}'"
+ )
+
self.model = model
self._remote_calls = not isinstance(model, vllm.LLM)
- self.from_text = from_text
- self._device = device
+ self.input_mode = input_mode
+ self.attention_mask_key = attention_mask_key
self.generate = generate
+
+ # Auto-determine what to return based on input mode
+ self.return_history = input_mode in ("history",)
+ self.return_text = input_mode in ("text", "history")
+ self.return_tokens = input_mode in ("tokens", "history", "text")
+ self.return_masks = True
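+        # Log-probs default to True when generating (unless explicitly disabled) and are
+        # mandatory when generate=False, since scoring is then the module's only purpose.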
+ if return_log_probs is False and not generate:
+ raise ValueError("return_log_probs must be True when generate=False.")
+ return_log_probs = (
+ True
+ if (return_log_probs is None and generate) or (not generate)
+ else bool(return_log_probs)
+ )
+ self.return_log_probs = return_log_probs
+
+ self.history_key = history_key
+ self.log_probs_key = log_probs_key
+ self.masks_key = masks_key
+ self.text_key = text_key
+ self.tokens_key = tokens_key
+
+ if not isinstance(pad_output, bool):
+ raise ValueError("pad_output must be a boolean")
self.pad_output = pad_output
- self.chat_template_name = chat_template_name
- self.chat_template = chat_template
+ self._device = device
+ if not pad_output and layout is None:
+ layout = torch.strided
+ self.layout = layout
padding_value = None
+ # Set input keys based on mode and generate parameter
+ if input_mode == "history":
+ if generate:
+ self.in_keys = [
+ ("history", "prompt") if input_key is None else input_key
+ ]
+ else:
+ self.in_keys = [("history", "full") if input_key is None else input_key]
+ elif input_mode == "text":
+ if generate:
+ self.in_keys = [("text", "prompt") if input_key is None else input_key]
+ else:
+ self.in_keys = [("text", "full") if input_key is None else input_key]
+ elif input_mode == "tokens":
+ if generate:
+ self.in_keys = [
+ ("tokens", "prompt") if input_key is None else input_key
+ ]
+ else:
+ self.in_keys = [("tokens", "full") if input_key is None else input_key]
+ else:
+ raise ValueError(f"Invalid input_mode: {input_mode}")
+ self.input_key = self.in_keys[0]
+
+ # Set output keys based on auto-determined return flags
+ self.out_keys = []
+ if self.return_text:
+ self.out_keys.append(self.text_key)
+ if self.return_masks:
+ self.out_keys.append(self.masks_key)
+ if self.return_tokens:
+ self.out_keys.append(self.tokens_key)
+ if self.return_log_probs:
+ self.out_keys.append(self.log_probs_key)
+ if self.return_history:
+ self.out_keys.append(self.history_key)
+
+ # Tokenizer setup
if not tokenizer_kwargs:
tokenizer_kwargs = {}
if not tokenizer_kwargs.setdefault("return_attention_mask", True):
- raise RuntimeError
+ raise RuntimeError("return_attention_mask must be True")
# If we don't pad, we use lists
return_tensors = "pt" if self.pad_output else False
@@ -180,14 +295,16 @@ def __init__(
raise RuntimeError
self.tokenizer_kwargs = tokenizer_kwargs
- if (pad_output or (from_text and not generate)) and tokenizer is None:
- # We need a tokenizer if we pad or when using text inputs with generate=False
- # The latter case is due to the fact that we want the log-probs for the response only,
- # but if the response is presented as a text we have to tokenize the whole prompt + response and
- # identify where the prompt ends and where the response starts.
- tokenizer = model.get_tokenizer()
+
+ # Get tokenizer if needed
+ if tokenizer is None:
+ try:
+ tokenizer = model.get_tokenizer()
+ except AttributeError:
+ warnings.warn("No tokenizer provided and no tokenizer found in model.")
self.tokenizer = tokenizer
- if tokenizer is not None and (
+
+ if self.tokenizer is not None and (
not hasattr(self.tokenizer, "pad_token") or self.tokenizer.pad_token is None
):
self.tokenizer.pad_token = self.tokenizer.eos_token
@@ -195,39 +312,42 @@ def __init__(
padding_value = self.tokenizer(self.tokenizer.pad_token)["input_ids"][0]
self.padding_value = padding_value
+ # Generate kwargs setup
if generate_kwargs is None:
generate_kwargs = {}
else:
generate_kwargs = dict(generate_kwargs)
- if generate_kwargs.get("n", 1) > 1:
+ self.num_samples = num_samples
+ if generate_kwargs.get("n", 1) > 1 or num_samples is not None:
if inplace in (True, "empty"):
raise ValueError(
"inplace must be False (or None) when generating more than one sample."
)
if inplace is None:
inplace = False
+ if (
+ generate_kwargs.get("n", 1) > 1
+ and num_samples is not None
+ and generate_kwargs.get("n", 1) != num_samples
+ ):
+ raise ValueError("num_samples differs from generate_kwargs['n'].")
+ elif num_samples is None:
+ self.num_samples = generate_kwargs.get("n", 1)
+ generate_kwargs["n"] = self.num_samples
elif inplace is None:
inplace = True
self.inplace = inplace
- prompt_logprobs = False
+ prompt_logprobs = return_log_probs
if not generate:
# We want only the log-probs, we generate a single token (that we then discard)
- # and retrieve the prompt log-probs
+ # and retrieve the prompt log-probs
generate_kwargs["max_tokens"] = 1
- prompt_logprobs = True
- if return_log_probs in (None, True):
- return_log_probs = True
- else:
- raise ValueError(
- "return_log_probs must be True or None when generate=False."
- )
- elif return_log_probs in (None, False):
- return_log_probs = False
- self.return_log_probs = return_log_probs
+ if not return_log_probs:
+ raise ValueError("return_log_probs must be True when generate=False.")
generate_kwargs.setdefault("detokenize", not pad_output)
generate_kwargs.setdefault("prompt_logprobs", prompt_logprobs)
@@ -238,16 +358,144 @@ def __init__(
sampling_params = SamplingParams(**generate_kwargs)
self.sampling_params = sampling_params
- if from_text:
- self.in_keys = [self.text_key]
- else:
- self.in_keys = [self.token_key, self.attention_mask_key]
- self.out_keys = [self.token_response_key]
- if from_text:
- self.out_keys += [self.text_response_key, self.token_key]
- if self.return_log_probs:
- self.out_keys += [self.log_prob_key]
+ # Additional transformers-specific settings
+ self.chat_template_name = chat_template_name
+ self.chat_template = chat_template
+
+ def get_new_version(self, **kwargs):
+ """Returns a new version of the module with altered parameters.
+
+        For instance, the ``generate`` parameter can be altered to switch between text generation and
+        log-probability computation. This is especially useful when one wants to reuse the same parameters
+        to gather log-probs without re-initializing the module from scratch.
+
+ Positional arguments are not supported.
+
+ See the class constructor for more details about the parameters.
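+
+        Example (a sketch; ``policy`` is assumed to be a generating ``vLLMWrapper``):
+
+            >>> ref_model = policy.get_new_version(generate=False, return_log_probs=True)
+            >>> # ref_model now scores ("history", "full") / ("tokens", "full") inputs
+            >>> # instead of generating new tokens.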
+ """
+ # Build the constructor arguments by using current values for missing parameters
+ constructor_kwargs = {}
+
+ # Model is always required
+ constructor_kwargs["model"] = kwargs.get("model", self.model)
+
+ # Check for each parameter and use current value if not provided
+ if "tokenizer" in kwargs:
+ constructor_kwargs["tokenizer"] = kwargs["tokenizer"]
+ elif hasattr(self, "tokenizer"):
+ constructor_kwargs["tokenizer"] = self.tokenizer
+
+ if "input_mode" in kwargs:
+ constructor_kwargs["input_mode"] = kwargs["input_mode"]
+ elif hasattr(self, "input_mode"):
+ constructor_kwargs["input_mode"] = self.input_mode
+
+ if "input_key" in kwargs:
+ constructor_kwargs["input_key"] = kwargs["input_key"]
+ # Since the input_key is dynamically determined, we don't want to set it here
+ # elif hasattr(self, "input_key"):
+ # constructor_kwargs["input_key"] = self.input_key
+
+ if "attention_mask_key" in kwargs:
+ constructor_kwargs["attention_mask_key"] = kwargs["attention_mask_key"]
+ elif hasattr(self, "attention_mask_key"):
+ constructor_kwargs["attention_mask_key"] = self.attention_mask_key
+
+ if "generate" in kwargs:
+ constructor_kwargs["generate"] = kwargs["generate"]
+ elif hasattr(self, "generate"):
+ constructor_kwargs["generate"] = self.generate
+
+ if "return_log_probs" in kwargs:
+ constructor_kwargs["return_log_probs"] = kwargs["return_log_probs"]
+ elif not constructor_kwargs.get("generate", True):
+ # if we are not generating, we want to return log-probs
+ constructor_kwargs["return_log_probs"] = True
+ elif hasattr(self, "return_log_probs"):
+ constructor_kwargs["return_log_probs"] = self.return_log_probs
+
+ if "generate_kwargs" in kwargs:
+ constructor_kwargs["generate_kwargs"] = kwargs["generate_kwargs"]
+ elif hasattr(self, "generate_kwargs"):
+ constructor_kwargs["generate_kwargs"] = self.generate_kwargs
+
+ if "pad_output" in kwargs:
+ constructor_kwargs["pad_output"] = kwargs["pad_output"]
+ elif hasattr(self, "pad_output"):
+ constructor_kwargs["pad_output"] = self.pad_output
+
+ if "tokenizer_kwargs" in kwargs:
+ constructor_kwargs["tokenizer_kwargs"] = kwargs["tokenizer_kwargs"]
+ elif hasattr(self, "tokenizer_kwargs"):
+ constructor_kwargs["tokenizer_kwargs"] = dict(self.tokenizer_kwargs)
+ if (
+ "pad_output" in kwargs
+ and kwargs.get("pad_output")
+ != constructor_kwargs["tokenizer_kwargs"]["padding"]
+ ):
+ constructor_kwargs["tokenizer_kwargs"]["padding"] = kwargs.get(
+ "pad_output"
+ )
+ if "inplace" in kwargs:
+ constructor_kwargs["inplace"] = kwargs["inplace"]
+ elif hasattr(self, "inplace"):
+ constructor_kwargs["inplace"] = self.inplace
+
+ if "device" in kwargs:
+ constructor_kwargs["device"] = kwargs["device"]
+ elif hasattr(self, "_device"):
+ constructor_kwargs["device"] = self._device
+
+ if "layout" in kwargs:
+ constructor_kwargs["layout"] = kwargs["layout"]
+ elif hasattr(self, "layout"):
+ constructor_kwargs["layout"] = self.layout
+
+ if "num_samples" in kwargs:
+ constructor_kwargs["num_samples"] = kwargs["num_samples"]
+ elif hasattr(self, "num_samples"):
+ constructor_kwargs["num_samples"] = self.num_samples
+
+ if "chat_template_name" in kwargs:
+ constructor_kwargs["chat_template_name"] = kwargs["chat_template_name"]
+ elif hasattr(self, "chat_template_name"):
+ constructor_kwargs["chat_template_name"] = self.chat_template_name
+
+ if "chat_template" in kwargs:
+ constructor_kwargs["chat_template"] = kwargs["chat_template"]
+ elif hasattr(self, "chat_template"):
+ constructor_kwargs["chat_template"] = self.chat_template
+
+ if "history_key" in kwargs:
+ constructor_kwargs["history_key"] = kwargs["history_key"]
+ elif hasattr(self, "history_key"):
+ constructor_kwargs["history_key"] = self.history_key
+
+ if "text_key" in kwargs:
+ constructor_kwargs["text_key"] = kwargs["text_key"]
+ elif hasattr(self, "text_key"):
+ constructor_kwargs["text_key"] = self.text_key
+
+ if "tokens_key" in kwargs:
+ constructor_kwargs["tokens_key"] = kwargs["tokens_key"]
+ elif hasattr(self, "tokens_key"):
+ constructor_kwargs["tokens_key"] = self.tokens_key
+
+ if "masks_key" in kwargs:
+ constructor_kwargs["masks_key"] = kwargs["masks_key"]
+ elif hasattr(self, "masks_key"):
+ constructor_kwargs["masks_key"] = self.masks_key
+
+ if "log_probs_key" in kwargs:
+ constructor_kwargs["log_probs_key"] = kwargs["log_probs_key"]
+ elif hasattr(self, "log_probs_key"):
+ constructor_kwargs["log_probs_key"] = self.log_probs_key
+
+ # Create and return new instance
+ return type(self)(**constructor_kwargs)
+
+ @set_list_to_stack(True)
def forward(
self,
tensordict: TensorDictBase,
@@ -265,415 +513,1276 @@ def forward(
elif tensordict.ndim > 1:
return self(tensordict.reshape(-1)).view(tensordict.shape)
- if kwargs:
- sampling_params = self.sampling_params.clone()
- for key, val in kwargs.items():
- setattr(sampling_params, key, val)
- else:
- sampling_params = self.sampling_params
-
_source_device = None
if self._device:
_source_device = tensordict.device
if tensordict.device:
tensordict = tensordict.copy().clear_device_()
- out = LazyStackedTensorDict(
- *[
+ if kwargs:
+ from vllm import SamplingParams
+
+ sampling_params = SamplingParams(**kwargs)
+ else:
+ sampling_params = self.sampling_params
+
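+        # Pre-allocate the output structure as a lazy stack over the batch dimension
+        # (plus an extra sample dimension when num_samples is set) so that each element
+        # can later hold ragged, per-sample tensors.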
+ if self.num_samples is not None:
+ out = (
TensorDict(
- device=tensordict.device, batch_size=tensordict.batch_size[1:]
+ device=tensordict.device,
+ batch_size=(
+ tensordict.batch_size[0],
+ self.num_samples,
+ *tensordict.batch_size[1:],
+ ),
)
- for _ in range(tensordict.shape[0])
- ]
- )
- if self.from_text:
+ .to_lazystack(1)
+ .to_lazystack(0)
+ )
+ else:
+ out = TensorDict(
+ device=tensordict.device, batch_size=tensordict.batch_size
+ ).to_lazystack(0)
+
+ if self.input_mode == "history":
if self.generate:
- out = self._from_vllm_generate_text(
- tensordict, sampling_params=sampling_params, out=out
- )
+ out = self._from_vllm_generate_history(tensordict, sampling_params, out)
else:
- out = self._from_vllm_logprobs_text(
- tensordict, sampling_params=sampling_params, out=out
- )
- else:
+ out = self._from_vllm_logprobs_history(tensordict, sampling_params, out)
+ elif self.input_mode == "text":
if self.generate:
- out = self._from_vllm_generate_tokens(
- tensordict, sampling_params=sampling_params, out=out
- )
+ out = self._from_vllm_generate_text(tensordict, sampling_params, out)
else:
- out = self._from_vllm_logprobs_tokens(
- tensordict, sampling_params=sampling_params, out=out
- )
+ out = self._from_vllm_logprobs_text(tensordict, sampling_params, out)
+ elif self.input_mode == "tokens":
+ if self.generate:
+ out = self._from_vllm_generate_tokens(tensordict, sampling_params, out)
+ else:
+ out = self._from_vllm_logprobs_tokens(tensordict, sampling_params, out)
+
if _source_device:
out = out.to(_source_device)
if tensordict_out is None:
if self.inplace is True:
+ # The output is the input
tensordict_out = tensordict
elif self.inplace is False:
+ # The output is the new structure
tensordict_out = out
elif self.inplace == "empty":
+ # The output is empty
tensordict_out = tensordict.empty()
if tensordict_out is not None and tensordict_out is not out:
- result = tensordict_out
+ result = tensordict_out.exclude(*self.out_keys, inplace=True)
result.update(out, keys_to_update=self.out_keys)
- elif tensordict_out is not out:
+ elif tensordict_out is out:
+ result = out.select(*self.out_keys)
+ elif self.inplace:
result = out
keys = list(set(self.out_keys + list(tensordict.keys(True, True))))
- return tensordict.update(result, keys_to_update=keys)
+ result = tensordict.exclude(*self.out_keys, inplace=True).update(
+ result, keys_to_update=keys
+ )
else:
result = out
return result
- def _from_vllm_generate_text(self, td, sampling_params, out) -> TensorDictBase:
- kwargs = {"sampling_params": sampling_params}
- args = ()
- input_ids = None
- attention_mask = None
- text = td.get(self.text_key)
- if text is None:
- # Fallback on history parsing
- history = td.get(self.history_key)
- if history is None:
- raise ValueError("No text or history provided to the vLLMWrapper.")
- tokenizer_kwargs = {}
- if self.chat_template_name is not None:
- tokenizer_kwargs["chat_template_name"] = self.chat_template_name
- if self.chat_template is not None:
- tokenizer_kwargs["chat_template"] = self.chat_template
- text = history.apply_chat_template(self.tokenizer, **tokenizer_kwargs)
- if self.pad_output:
- tokenizer_kwargs = self.tokenizer_kwargs
- if not isinstance(text, (list, str)):
- text = text.tolist()
- tokens_in = TensorDict.from_dict(self.tokenizer(text, **tokenizer_kwargs))
- # out.set("tokens_in", tokens_in)
- input_ids, attention_mask = (
- tokens_in["input_ids"],
- tokens_in["attention_mask"],
- )
- prompt_token_ids = self._to_list(input_ids, attention_mask)
- kwargs["prompt_token_ids"] = prompt_token_ids
- else:
- text = td.get(self.text_key)
- if not isinstance(text, (list, str)):
- text = text.tolist()
- args = (text,)
+ def _from_vllm_generate_history(
+ self,
+ tensordict_input: TensorDictBase,
+ sampling_params: SamplingParams,
+ out: TensorDictBase,
+ ) -> TensorDictBase:
+ """Generate text from history input."""
+ from torchrl.data.llm import History
+
+ assert isinstance(
+ tensordict_input, TensorDictBase
+ ), f"tensordict_input must be TensorDictBase, got {type(tensordict_input)}"
+ assert isinstance(
+ sampling_params, SamplingParams
+ ), f"sampling_params must be SamplingParams, got {type(sampling_params)}"
+ assert isinstance(
+ out, TensorDictBase
+ ), f"out must be TensorDictBase, got {type(out)}"
+
+ # Validate input
+ if self.input_key not in tensordict_input:
+ raise ValueError(
+ f"Expected '{self.input_key}' key for history input mode, "
+ f"but found keys: {list(tensordict_input.keys())}"
+ )
- if not self._remote_calls:
- tokens_out = self.model.generate(*args, **kwargs)
+ history = tensordict_input.get(self.input_key)
+ if not isinstance(history, History):
+ raise TypeError(
+ f"Expected History object for '{self.input_key}', got {type(history)}"
+ )
+
+ # Apply chat template
+ tokenizer_kwargs = {}
+ if self.chat_template_name is not None:
+ tokenizer_kwargs.setdefault("chat_template_name", self.chat_template_name)
+ if self.chat_template is not None:
+ tokenizer_kwargs.setdefault("chat_template", self.chat_template)
+ tokenizer_kwargs.setdefault("add_generation_prompt", True)
+ text_prompt = history.apply_chat_template(
+ tokenizer=self.tokenizer, **tokenizer_kwargs
+ )
+
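+        # Second pass: re-apply the template with tokenization enabled to obtain the
+        # prompt token ids corresponding to the text above.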
+ tokenizer_kwargs.setdefault("return_assistant_tokens_mask", False)
+ tokenizer_kwargs.setdefault("tokenize", True)
+ tokenizer_kwargs.setdefault("padding", False)
+ tokenizer_kwargs.setdefault("return_dict", True)
+ response_struct = history.apply_chat_template(
+ tokenizer=self.tokenizer, **tokenizer_kwargs
+ )
+ tokens_prompt_padded = None
+ tokens_prompt_unpadded = None
+ if self.pad_output:
+ tokens_prompt_padded = response_struct.get(
+ "input_ids",
+ as_padded_tensor=True,
+ padding_value=self.padding_value,
+ padding_side="left",
+ )
else:
- import ray
+ tokens_prompt_unpadded = response_struct.get("input_ids", as_list=True)
- tokens_out = ray.get(self.model.generate.remote(*args, **kwargs))
+ result = self._generate_from_tokens(
+ tokens_prompt_padded=tokens_prompt_padded,
+ tokens_prompt_unpadded=tokens_prompt_unpadded,
+ sampling_params=sampling_params,
+ out=out,
+ )
- tokens_out = self._get_output_tokens_and_log_probs(tokens_out)
+ # Generate using text path
if self.pad_output:
- tokens_out.set(
- self.text_response_key,
- NonTensorStack(
- *self.tokenizer.batch_decode(tokens_out[self.token_response_key])
- ),
- )
- in_keys = [
- self.log_prob_key,
- self.token_response_key,
- self.text_response_key,
- self.token_key,
- self.attention_mask_key,
- ]
- out = out.update(tokens_out.select(*in_keys, strict=False))
- # We might already have the tokens
- if input_ids is not None and self.token_key not in out:
- out[self.token_key] = input_ids
- if attention_mask is not None and self.attention_mask_key not in out:
- out[self.attention_mask_key] = attention_mask
- inputs = td.select(*self.in_keys, strict=False)
- if inputs.ndim < out.ndim:
- # This happens when n > 1
- inputs = inputs.unsqueeze(-1).expand(out.shape)
- out.update(inputs)
- return out
+ result[(self.tokens_key, "prompt")] = (
+ tokens_prompt_padded
+ if not self.num_samples
+ else tokens_prompt_padded.unsqueeze(1).repeat(1, self.num_samples, 1)
+ )
+ else:
+ tokens_prompt_nested = torch.nested.as_nested_tensor(tokens_prompt_unpadded)
+ if not self.num_samples:
+ result[(self.tokens_key, "prompt")] = tokens_prompt_nested
+ else:
+ for r in result.unbind(1):
+ r[(self.tokens_key, "prompt")] = tokens_prompt_nested
- def _from_vllm_logprobs_text(self, td, sampling_params, out):
- text_prompt = td.get(self.text_key)
- text_response = td.get(self.text_response_key)
- if text_response is None or text_prompt is None:
- if text_response is not None and text_prompt is not None:
- raise ValueError(
- "No text or history provided to the vLLMWrapper. Either both are provided or none of them."
- )
- # Fallback on history parsing
- history = td.get(self.history_key)
- if history is None:
- raise ValueError(
- "No text or history provided to the TransformersWrapper."
+ text_result = Text._from_tensordict(result.empty())
+ result.set(self.text_key, text_result)
+ if not self.num_samples:
+ text_result.prompt = text_prompt
+ else:
+ for r in result.unbind(1):
+ r[self.text_key, "prompt"] = text_prompt
+ with result.view(-1) as result_flat:
+ if self.pad_output:
+ tokens_full_padded = result_flat.get(
+ (self.tokens_key, "full"),
+ as_padded_tensor=True,
+ padding_side="right",
+ padding_value=self.padding_value,
)
- tokenizer_kwargs = {}
- if self.chat_template_name is not None:
- tokenizer_kwargs.setdefault(
- "chat_template_name", self.chat_template_name
+ if tokens_full_padded is None:
+ raise ValueError("tokens_full_padded is None")
+ text_full = self.tokenizer.batch_decode(
+ tokens_full_padded, skip_special_tokens=False
)
- if self.chat_template is not None:
- tokenizer_kwargs.setdefault("chat_template", self.chat_template)
- tokenizer_kwargs.setdefault("add_generation_prompt", False)
- text_response = history.apply_chat_template(
- tokenizer=self.tokenizer, **tokenizer_kwargs
- )
- if isinstance(text_response, list):
- text_prompt = ["" for _ in text_response]
else:
- text_prompt = ""
- if not isinstance(text_prompt, list):
- text_prompt = text_prompt.tolist()
- if not isinstance(text_response, list):
- text_response = text_response.tolist()
- text = [_x + _y for _x, _y in _zip_strict(text_prompt, text_response)]
-
- tokenized_total = self.tokenizer(text, **self.tokenizer_kwargs)
- tokenized_prompt_only = self.tokenizer(text_prompt, **self.tokenizer_kwargs)
-
- input_ids_total = tokenized_total["input_ids"]
- attention_mask_total = tokenized_total["attention_mask"]
-
- if not self.pad_output:
- input_ids_prompt = tokenized_prompt_only["input_ids"]
- attention_mask_prompt = tokenized_prompt_only["attention_mask"]
- input_ids_response = []
- for token_total, token_prompt in zip(input_ids_total, input_ids_prompt):
- input_ids_response.append(token_total[len(token_prompt) :])
- attention_mask_response = []
- for mask, mask_prompt in zip(attention_mask_total, attention_mask_prompt):
- attention_mask_response.append(mask[len(mask_prompt) :])
- else:
- input_ids_prompt: torch.Tensor = tokenized_prompt_only["input_ids"]
- # attention_mask_prompt: torch.Tensor = tokenized_prompt_only[
- # "attention_mask"
- # ]
- input_ids_response: torch.Tensor = input_ids_total[
- :, input_ids_prompt.shape[1] :
+ tokens_full_unpadded = result_flat.get(
+ (self.tokens_key, "full"), as_list=True
+ )
+ if tokens_full_unpadded is None:
+ raise ValueError("tokens_full_unpadded is None")
+ text_full = self.tokenizer.batch_decode(
+ tokens_full_unpadded, skip_special_tokens=False
+ )
+ text_prompt = result_flat[self.text_key, "prompt"]
+ text_response = [
+ txt[len(prompt) :]
+ for txt, prompt in _zip_strict(text_full, text_prompt)
]
- # response_attention_mask: torch.Tensor = attention_mask_total[
- # :, attention_mask_prompt.shape[1] :
- # ]
+ result_flat.set((self.text_key, "full"), text_full)
+ result_flat.set((self.text_key, "response"), text_response)
+
+ # Now parse the full text back to a history object, and use the extra history objects
+ # as response
+ history_chat = ChatHistory._from_tensordict(result.empty())
+ if self.num_samples is None:
+ history_chat.prompt = history
+ else:
+ for h in history_chat.unbind(1):
+ h.prompt = history
+ with history_chat.view(-1) as history_chat_flat:
+ history_chat_flat.full = full_histories = History.from_text(text_full)
+ prompt_histories = history_chat_flat.prompt
+ # iterate over batch
+ h_responses = []
+ for h_full, h_prompt in _zip_strict(
+ full_histories.unbind(0), prompt_histories.unbind(0)
+ ):
+ if h_full.shape[0] <= h_prompt.shape[0]:
+                    raise RuntimeError(
+                        "The full history must be strictly longer than the prompt history (no response found)."
+                    )
+ # Note: there can be more than one response, so the response has the same number of dims as prompt
+ h_responses.append(h_full[h_prompt.shape[0] :])
+ history_chat_flat.response = torch.stack(h_responses)
+ result.set(self.history_key, history_chat)
+ return result
- input_ids_total = self._to_list(input_ids_total, attention_mask_total)
- kwargs = {"sampling_params": sampling_params}
- if self.tokenizer is not None:
- kwargs.update({"prompt_token_ids": input_ids_total})
- args = ()
+ def _from_vllm_logprobs_history(
+ self,
+ tensordict_input: TensorDictBase,
+ sampling_params: SamplingParams,
+ out: TensorDictBase,
+ ) -> TensorDictBase:
+ """Compute log-probs from history input."""
+ assert isinstance(
+ tensordict_input, TensorDictBase
+ ), f"tensordict_input must be TensorDictBase, got {type(tensordict_input)}"
+ assert isinstance(
+ sampling_params, SamplingParams
+ ), f"sampling_params must be SamplingParams, got {type(sampling_params)}"
+ assert isinstance(
+ out, TensorDictBase
+ ), f"out must be TensorDictBase, got {type(out)}"
+
+ from torchrl.data.llm import History
+
+ # Validate input
+ if self.input_key not in tensordict_input:
+ raise ValueError(
+ f"Expected '{self.input_key}' key for history input mode, "
+ f"but found keys: {list(tensordict_input.keys())}"
+ )
+
+ history = tensordict_input.get(self.input_key)
+ if not isinstance(history, History):
+ raise TypeError(
+ f"Expected History object for '{self.input_key}', got {type(history)}"
+ )
+
+ # Apply chat template
+ tokenizer_kwargs = {}
+ if self.chat_template_name is not None:
+ tokenizer_kwargs.setdefault("chat_template_name", self.chat_template_name)
+ if self.chat_template is not None:
+ tokenizer_kwargs.setdefault("chat_template", self.chat_template)
+ tokenizer_kwargs.setdefault("add_generation_prompt", False)
+ text_full = history.apply_chat_template(
+ tokenizer=self.tokenizer, **tokenizer_kwargs
+ )
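+        # Second pass: tokenize the same conversation to get token ids and the
+        # assistant-token mask used to isolate response tokens.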
+ tokenizer_kwargs.setdefault("return_assistant_tokens_mask", True)
+ tokenizer_kwargs.setdefault("tokenize", True)
+ tokenizer_kwargs.setdefault("padding", False)
+ tokenizer_kwargs.setdefault("return_dict", True)
+ response_struct = history.apply_chat_template(
+ tokenizer=self.tokenizer, **tokenizer_kwargs
+ )
+
+ result = self._logprobs_from_tokens(
+ response_struct=response_struct, sampling_params=sampling_params, out=out
+ )
+ text_result = Text._from_tensordict(result.empty())
+ result.set(self.text_key, text_result)
+ result[self.text_key, "full"] = text_full
+ result.set(self.history_key, ChatHistory(full=history))
+ return result
+
+ def _from_vllm_generate_text(
+ self, td: TensorDictBase, sampling_params: SamplingParams, out: TensorDictBase
+ ) -> TensorDictBase:
+ """Generate text from text input."""
+ # Type assertions
+ assert isinstance(
+ td, TensorDictBase
+ ), f"td must be TensorDictBase, got {type(td)}"
+ assert isinstance(
+ sampling_params, SamplingParams
+ ), f"sampling_params must be SamplingParams, got {type(sampling_params)}"
+ assert isinstance(
+ out, TensorDictBase
+ ), f"out must be TensorDictBase, got {type(out)}"
+
+ # Validate input
+ if self.input_key not in td:
+ raise ValueError(
+ f"Expected '{self.input_key}' key for text input mode, "
+ f"but found keys: {list(td.keys())}"
+ )
+
+ text = td.get(self.input_key)
+ if text is None:
+ raise ValueError(f"Expected '{self.input_key}' key for text input mode")
+
+ return self._generate_from_text(text, sampling_params, out)
+
+ def _from_vllm_logprobs_text(
+ self, td: TensorDictBase, sampling_params: SamplingParams, out: TensorDictBase
+ ) -> TensorDictBase:
+ """Compute log-probs from text input."""
+ # Type assertions
+ assert isinstance(
+ td, TensorDictBase
+ ), f"td must be TensorDictBase, got {type(td)}"
+ assert isinstance(
+ sampling_params, SamplingParams
+ ), f"sampling_params must be SamplingParams, got {type(sampling_params)}"
+ assert isinstance(
+ out, TensorDictBase
+ ), f"out must be TensorDictBase, got {type(out)}"
+
+ # Validate input
+ if self.input_key not in td:
+ raise ValueError(
+ f"Expected '{self.input_key}' key for text input mode, "
+ f"but found keys: {list(td.keys())}"
+ )
+
+ text = td.get(self.input_key)
+ if text is None:
+ raise ValueError(f"Expected '{self.input_key}' key for text input mode")
+
+ return self._logprobs_from_text(text, sampling_params, out)
+
+ def _from_vllm_generate_tokens(
+ self, td: TensorDictBase, sampling_params: SamplingParams, out: TensorDictBase
+ ) -> TensorDictBase:
+ """Generate text from tokens input."""
+ # Type assertions
+ assert isinstance(
+ td, TensorDictBase
+ ), f"td must be TensorDictBase, got {type(td)}"
+ assert isinstance(
+ sampling_params, SamplingParams
+ ), f"sampling_params must be SamplingParams, got {type(sampling_params)}"
+ assert isinstance(
+ out, TensorDictBase
+ ), f"out must be TensorDictBase, got {type(out)}"
+
+ # Validate input
+ if self.input_key not in td:
+ raise ValueError(
+ f"Expected '{self.input_key}' key for tokens input mode, "
+ f"but found keys: {list(td.keys())}"
+ )
+
+ tokens_prompt_padded = None
+ tokens_prompt_unpadded = None
+ if self.pad_output:
+ tokens_prompt_padded = td.get(self.input_key)
+ else:
+ tokens_prompt_unpadded = list(td.get(self.input_key, as_list=True))
+ # make sure we remove the padding tokens
+ tokens_prompt_unpadded = [
+ tokens[tokens != self.padding_value]
+ for tokens in tokens_prompt_unpadded
+ ]
+
+ return self._generate_from_tokens(
+ tokens_prompt_unpadded=tokens_prompt_unpadded,
+ tokens_prompt_padded=tokens_prompt_padded,
+ sampling_params=sampling_params,
+ out=out,
+ )
+
+ def _from_vllm_logprobs_tokens(
+ self, td: TensorDictBase, sampling_params: SamplingParams, out: TensorDictBase
+ ) -> TensorDictBase:
+ """Compute log-probs from tokens input."""
+ # Type assertions
+ assert isinstance(
+ td, TensorDictBase
+ ), f"td must be TensorDictBase, got {type(td)}"
+ assert isinstance(
+ sampling_params, SamplingParams
+ ), f"sampling_params must be SamplingParams, got {type(sampling_params)}"
+ assert isinstance(
+ out, TensorDictBase
+ ), f"out must be TensorDictBase, got {type(out)}"
+
+ # Validate input
+ if self.input_key not in td:
+ raise ValueError(
+ f"Expected '{self.input_key}' key for tokens input mode, "
+ f"but found keys: {list(td.keys())}"
+ )
+
+ tokens_full_padded = None
+ tokens_full_unpadded = None
+ if self.pad_output:
+ tokens_full_padded = td.get(self.input_key)
else:
- # TODO: this is unreachable as of now - but ultimately we may want to pass the text directly
- args = (td[self.text_key],)
+ tokens_full_unpadded = list(td.get(self.input_key, as_list=True))
+ # make sure we remove the padding tokens
+ tokens_full_unpadded = [
+ tokens[tokens != self.padding_value] for tokens in tokens_full_unpadded
+ ]
+
+ return self._logprobs_from_tokens(
+ response_struct=None,
+ tokens_full_unpadded=tokens_full_unpadded,
+ tokens_full_padded=tokens_full_padded,
+ sampling_params=sampling_params,
+ out=out,
+ )
+
+ def _cat_text(
+ self, text: str | list[str], response_text: str | list[str]
+ ) -> str | list[str]:
+ """Concatenate text and response text."""
+ assert isinstance(
+ text, (str, list)
+ ), f"text must be str or list, got {type(text)}"
+ assert isinstance(
+ response_text, (str, list)
+ ), f"response_text must be str or list, got {type(response_text)}"
+
+ if isinstance(text, list):
+ return [self._cat_text(t, t_) for t, t_ in _zip_strict(text, response_text)]
+ else:
+ return text + response_text
+
+ def _generate_from_text(
+ self,
+ text: str | list[str] | NonTensorStack,
+ sampling_params: SamplingParams,
+ out: TensorDictBase,
+ ) -> TensorDictBase:
+ """Generate text from text input."""
+ # Convert text to list format
+ if isinstance(text, str):
+ text = [text]
+ elif not isinstance(text, list):
+ text = text.tolist()
+
+ assert isinstance(
+ text, (str, list)
+ ), f"text must be str or list, got {type(text)}"
+ assert isinstance(
+ sampling_params, SamplingParams
+ ), f"sampling_params must be SamplingParams, got {type(sampling_params)}"
+ assert isinstance(
+ out, TensorDictBase
+ ), f"out must be TensorDictBase, got {type(out)}"
+
+ generate_kwargs = {"sampling_params": sampling_params}
+ args = ()
+
if not self._remote_calls:
- tokens_out = self.model.generate(*args, **kwargs)
+ request_output = self.model.generate(text, *args, **generate_kwargs)
else:
import ray
- tokens_out = ray.get(self.model.generate.remote(*args, **kwargs))
+ request_output = ray.get(
+ self.model.generate.remote(text, *args, **generate_kwargs)
+ )
- tokens_out = _RequestOutput_tc.from_request_output(tokens_out)
- tokens_out = tokens_out.select(
- "prompt_token_ids", "prompt_logprobs", strict=False
- )._tensordict
+ request_output_tc = _RequestOutput_tc.from_request_output(request_output)
- # we disregard the tokens from the prompt to focus on those of the response
+ # Extract response tokens and text
+ outputs = (
+ request_output_tc.outputs.view(-1)
+ if self.num_samples is not None
+ else request_output_tc.outputs
+ )
if self.pad_output:
- lps = tokens_out.get(
- "prompt_logprobs", as_padded_tensor=True, padding_side="left"
+ response_tokens_padded = outputs.view(-1).get(
+ "token_ids",
+ as_padded_tensor=self.pad_output,
+ padding_value=self.padding_value,
+ padding_side="right",
+ )
+ response_tokens_list = outputs.view(-1).get(
+ "token_ids",
+ as_list=True,
+ )
+ self._check_not_padded(response_tokens_list)
+ if self.tokenizer is not None:
+ response_text = self.tokenizer.batch_decode(
+ response_tokens_list, skip_special_tokens=False
)
- lps = lps[..., -input_ids_response.shape[1] :]
- padded = input_ids_response == self.padding_value
- lps = torch.where(~padded, lps, 0.0)
else:
- lps = tokens_out.get(
- "prompt_logprobs",
- as_list=True,
- )
- # We use a nested tensor as it will be unbound during writing
- lps = torch.nested.nested_tensor(
- [lp[..., -len(tr) :] for lp, tr in zip(lps, input_ids_response)]
- )
-
- out = out.update(tokens_out.empty(recurse=True))
- if isinstance(input_ids_response, list):
- input_ids_response = torch.nested.nested_tensor(input_ids_response)
- out["tokens_response"] = input_ids_response
- out[self.log_prob_key] = lps
- inputs = td.select(*self.in_keys, strict=False)
- if inputs.ndim < out.ndim:
- # This happens when n > 1
- inputs = inputs.unsqueeze(-1).expand(out.shape)
- out.update(inputs)
+ response_text = None
+
+ # Build output TensorClass objects
+
+ masks_obj = Masks._from_tensordict(out.empty())
+ masks_obj.all_attention_mask = None
+ masks_obj.all_assistant_mask = None
+ masks_obj.padded = MetaData(self.pad_output)
+ out.set(self.masks_key, masks_obj)
+
+ if self.num_samples is not None:
+ text = [txt for txt in text for _ in range(self.num_samples)]
+ text_obj = Text._from_tensordict(out.empty())
+ with text_obj.view(-1) as text_obj_flat:
+ text_obj_flat.prompt = text
+ text_obj_flat.response = response_text
+ text_obj_flat.full = self._cat_text(text, response_text)
+ out.set(self.text_key, text_obj)
+
+ tokens_obj = Tokens._from_tensordict(out.empty())
+ with tokens_obj.view(-1) as tokens_obj_flat:
+ tokens_obj_flat.prompt = None # We don't have prompt tokens in this path
+ if self.pad_output:
+ tokens_obj_flat.response = response_tokens_padded
+ self._check_padded(response_tokens_padded)
+ else:
+ tokens_obj_flat.response = response_tokens_list
+ self._check_not_padded(response_tokens_list)
+ tokens_obj_flat.full = (
+ None # we don't have prompt tokens in this path so no all_tokens either
+ )
+ tokens_obj.padded = MetaData(self.pad_output)
+ out.set(self.tokens_key, tokens_obj)
+
+ if self.return_log_probs:
+ log_probs_obj = LogProbs._from_tensordict(out.empty())
+ with log_probs_obj.view(-1) as log_probs_obj_flat:
+ if self.pad_output:
+ log_probs_padded = outputs.get(
+ "logprobs",
+ as_padded_tensor=self.pad_output,
+ padding_value=self.padding_value,
+ padding_side="right",
+ )
+ self._check_padded(log_probs_padded)
+ log_probs_obj_flat.response = log_probs_padded
+ log_probs_obj_flat.full = log_probs_padded
+ else:
+ log_probs_list = outputs.get(
+ "logprobs",
+ as_list=True,
+ )
+ self._check_not_padded(log_probs_list)
+ log_probs_obj_flat.response = log_probs_list
+ log_probs_obj_flat.full = log_probs_list
+ log_probs_obj_flat.prompt = None
+ log_probs_obj.padded = MetaData(self.pad_output)
+ out.set(self.log_probs_key, log_probs_obj)
+
return out
- def _from_vllm_generate_tokens(self, td, sampling_params, out):
- input_ids = td.get(self.token_key)
- attention_mask = td.get(self.attention_mask_key)
- input_ids_list = self._to_list(input_ids, attention_mask)
- args = ()
- kwargs = {
+ def _logprobs_from_text(
+ self,
+ text: str | list[str] | NonTensorStack,
+ sampling_params: SamplingParams,
+ out: TensorDictBase,
+ ) -> TensorDictBase:
+ """Compute log-probs from text input."""
+ # Convert text to list format
+ if isinstance(text, str):
+ text = [text]
+ elif not isinstance(text, list):
+ text = text.tolist()
+
+ assert isinstance(
+ text, (str, list)
+ ), f"text must be str or list, got {type(text)}"
+ assert isinstance(
+ sampling_params, SamplingParams
+ ), f"sampling_params must be SamplingParams, got {type(sampling_params)}"
+ assert isinstance(
+ out, TensorDictBase
+ ), f"out must be TensorDictBase, got {type(out)}"
+
+        # A tokenizer is required to convert the text to tokens
+ if self.tokenizer is None:
+ raise ValueError(
+ "Tokenizer is required for log-probs computation with text input"
+ )
+
+ # Tokenize the text
+ tokenized_output = self.tokenizer(text, **self.tokenizer_kwargs)
+ if self.pad_output:
+ tokens_full_padded = tokenized_output["input_ids"]
+ attention_mask_full_padded = tokenized_output["attention_mask"]
+ tokens_full_list = self._to_list(
+ tokens_full_padded, attention_mask_full_padded
+ )
+ else:
+ tokens_full_unpadded = tokenized_output["input_ids"]
+ tokens_full_list = self._to_list(tokens_full_unpadded, None)
+ attention_mask_full_unpadded = tokenized_output["attention_mask"]
+ attention_mask_full_unpadded = [
+ am.bool()
+ if isinstance(am, torch.Tensor)
+ else torch.tensor(am, dtype=torch.bool)
+ for am in attention_mask_full_unpadded
+ ]
+
+ # Convert to list format for vLLM
+ generate_kwargs = {
"sampling_params": sampling_params,
- "prompt_token_ids": input_ids_list,
+ "prompt_token_ids": tokens_full_list,
}
+
+ # Generate with vLLM to get prompt_logprobs
if not self._remote_calls:
- tokens_out = self.model.generate(*args, **kwargs)
+ request_output = self.model.generate(**generate_kwargs)
else:
import ray
- tokens_out = ray.get(self.model.generate.remote(*args, **kwargs))
- tokens_out = _RequestOutput_tc.from_request_output(tokens_out)
- # When not generate, we don't want to overwrite this
- tokens_response_td = tokens_out.outputs._tensordict.select(
- "token_ids", "logprobs", strict=False
+ request_output = ray.get(self.model.generate.remote(**generate_kwargs))
+
+ request_output_tc = _RequestOutput_tc.from_request_output(request_output)
+
+ # Extract log-probs from prompt_logprobs
+ if self.pad_output:
+ # For padded case, use all prompt_logprobs
+ log_probs_full_padded = request_output_tc.get(
+ "prompt_logprobs",
+ as_padded_tensor=True,
+ padding_value=0,
+ padding_side="left",
+ )
+
+ # Mask out padding
+ attention_mask_full_padded = tokens_full_padded != self.padding_value
+ log_probs_full_padded = torch.where(
+ attention_mask_full_padded, log_probs_full_padded, 0.0
+ )
+ else:
+ # For unpadded case, extract from each sequence
+ log_probs_full_unpadded = request_output_tc.get(
+ "prompt_logprobs", as_list=True
+ )
+ self._check_not_padded(log_probs_full_unpadded)
+
+ masks_obj = Masks._from_tensordict(
+ TensorDict(batch_size=out.batch_size).to_lazystack(0)
)
if self.pad_output:
- tokens_response_td = tokens_response_td.densify(
- layout=torch.strided
- ).to_padded_tensor(padding=self.padding_value)
- tokens_response_td.rename_key_("token_ids", "tokens_response")
+ self._check_padded(attention_mask_full_padded)
+ masks_obj.all_attention_mask = attention_mask_full_padded.bool()
+ else:
+ self._check_not_padded(attention_mask_full_unpadded)
+ masks_obj.all_attention_mask = attention_mask_full_unpadded
+ masks_obj.padded = MetaData(self.pad_output)
+ out.set(self.masks_key, masks_obj)
+
+ # Build output TensorClass objects
+ text_obj = Text._from_tensordict(
+ TensorDict(batch_size=out.batch_size).to_lazystack(0)
+ )
+ text_obj.prompt = None
+ text_obj.response = None
+ text_obj.full = text
+ out.set(self.text_key, text_obj)
+
+ tokens_obj = Tokens._from_tensordict(
+ TensorDict(batch_size=out.batch_size).to_lazystack(0)
+ )
+ if self.pad_output:
+ self._check_padded(tokens_full_padded)
+ tokens_obj.full = tokens_full_padded
+ else:
+ tokens_obj.full = tokens_full_unpadded
+ tokens_obj.response = None
+ tokens_obj.padded = MetaData(self.pad_output)
+ out.set(self.tokens_key, tokens_obj)
+
if self.return_log_probs:
- tokens_response_td.rename_key_("logprobs", self.log_prob_key)
+ log_probs_obj = LogProbs._from_tensordict(
+ TensorDict(batch_size=out.batch_size).to_lazystack(0)
+ )
if self.pad_output:
- padded_values = (
- tokens_response_td["tokens_response"] == self.padding_value
- )
- if padded_values.any():
- lps = tokens_response_td[self.log_prob_key]
- lps = torch.where(expand_as_right(~padded_values, lps), lps, 0.0)
- tokens_response_td[self.log_prob_key] = lps
- out = out.update(tokens_response_td.empty(recurse=True))
- out.update(
- tokens_response_td,
- keys_to_update=(self.token_response_key, self.log_prob_key),
- )
- inputs = td.select(*self.in_keys, strict=False)
- if inputs.ndim < out.ndim:
- # This happens when n > 1
- inputs = inputs.unsqueeze(-1).expand(out.shape)
- out.update(inputs)
+ self._check_padded(log_probs_full_padded)
+ log_probs_obj.full = log_probs_full_padded
+ else:
+ self._check_not_padded(log_probs_full_unpadded)
+ log_probs_obj.full = log_probs_full_unpadded
+ log_probs_obj.response = None
+ log_probs_obj.padded = MetaData(self.pad_output)
+ out.set(self.log_probs_key, log_probs_obj)
+
return out
- def _from_vllm_logprobs_tokens(self, td, sampling_params, out):
+ def _cat_tensors(
+ self,
+ tokens: list[torch.Tensor] | torch.Tensor,
+ response_tokens: list[torch.Tensor] | torch.Tensor,
+ ) -> list[torch.Tensor] | torch.Tensor:
+ """Concatenate tokens and response tokens."""
+ if isinstance(tokens, list) or isinstance(response_tokens, list):
+ return [
+ self._cat_tensors(t, t_)
+ for t, t_ in _zip_strict(tokens, response_tokens)
+ ]
+ else:
+ return torch.cat([tokens, response_tokens], dim=-1)
- tokens = td.get(self.token_key)
- tokens_response = td.get(self.token_response_key)
- attention_mask = td.get(self.attention_mask_key)
+ def _generate_from_tokens(
+ self,
+ tokens_prompt_unpadded: list[torch.Tensor] | None,
+ tokens_prompt_padded: torch.Tensor | None,
+ sampling_params: SamplingParams,
+ out: TensorDictBase,
+ ) -> TensorDictBase:
+ """Generate text from tokens input."""
+ assert isinstance(
+ tokens_prompt_padded, (torch.Tensor, type(None))
+ ), f"tokens_prompt_padded must be torch.Tensor or None, got {type(tokens_prompt_padded)}"
+ assert isinstance(
+ tokens_prompt_unpadded, (list, type(None))
+ ), f"tokens_prompt_unpadded must be list or None, got {type(tokens_prompt_unpadded)}"
+ assert isinstance(
+ sampling_params, SamplingParams
+ ), f"sampling_params must be SamplingParams, got {type(sampling_params)}"
+ assert isinstance(
+ out, TensorDictBase
+ ), f"out must be TensorDictBase, got {type(out)}"
+
+ generate_kwargs = {"sampling_params": sampling_params}
+ args = ()
- tokens = torch.cat([tokens, tokens_response], -1)
- if attention_mask is not None:
- attention_mask = torch.cat(
- [attention_mask, attention_mask.new_ones(tokens_response.shape)], -1
+ if tokens_prompt_unpadded is None:
+ # TODO: To be on the safe side, we may do this even in the unpadded case since we're not sure
+ # the user passed an unpadded tensor in the first place.
+ tokens_prompt_list = self._to_list(
+ tokens_prompt_padded, tokens_prompt_padded != self.padding_value
)
- input_ids_list = self._to_list(tokens, attention_mask)
- args = ()
- kwargs = {
+ else:
+ tokens_prompt_list = self._to_list(tokens_prompt_unpadded, None)
+ generate_kwargs.update({"prompt_token_ids": tokens_prompt_list})
+
+ if not self._remote_calls:
+ request_output = self.model.generate(*args, **generate_kwargs)
+ else:
+ import ray
+
+ request_output = ray.get(
+ self.model.generate.remote(*args, **generate_kwargs)
+ )
+
+ request_output_tc = _RequestOutput_tc.from_request_output(request_output)
+
+ # Extract response tokens and text
+ outputs = (
+ request_output_tc.outputs.view(-1)
+ if self.num_samples is not None
+ else request_output_tc.outputs
+ )
+ if self.pad_output:
+ tokens_response_padded = outputs.get(
+ "token_ids",
+ as_padded_tensor=self.pad_output,
+ padding_value=self.padding_value,
+ padding_side="right",
+ )
+ self._check_padded(tokens_response_padded)
+ tokens_response_unpadded = outputs.get(
+ "token_ids",
+ as_list=True,
+ )
+ self._check_not_padded(tokens_response_unpadded)
+
+ tokens_obj = Tokens._from_tensordict(out.empty())
+ if self.pad_output:
+ self._check_padded(tokens_response_padded)
+ self._check_padded(tokens_prompt_padded)
+ else:
+ self._check_not_padded(tokens_response_unpadded)
+ self._check_not_padded(tokens_prompt_unpadded)
+
+ if self.num_samples is not None:
+ # replicate tokens
+ for i in range(self.num_samples):
+ tokens_obj[:, i].prompt = (
+ tokens_prompt_unpadded
+ if not self.pad_output
+ else tokens_prompt_padded
+ )
+ else:
+ tokens_obj.prompt = (
+ tokens_prompt_unpadded if not self.pad_output else tokens_prompt_padded
+ )
+ with tokens_obj.view(-1) as tokens_obj_flat:
+ if self.pad_output:
+ tokens_obj_flat.response = tokens_response_padded
+ tokens_full_padded = self._cat_tensors(
+ tokens_obj_flat.prompt, tokens_response_padded
+ )
+ tokens_obj_flat.full = tokens_full_padded
+ else:
+ tokens_obj_flat.response = tokens_response_unpadded
+ tokens_full_unpadded = self._cat_tensors(
+ tokens_obj_flat.get("prompt", as_list=True),
+ tokens_response_unpadded,
+ )
+ tokens_obj_flat.full = tokens_full_unpadded
+ tokens_obj.padded = MetaData(self.pad_output)
+ out.set(self.tokens_key, tokens_obj)
+
+ masks_obj = Masks._from_tensordict(out.empty())
+ # self.return_tokens must be True
+ if self.pad_output:
+ # Get "real" attention masks
+ full_attention_mask_padded = tokens_obj.get("full") != self.padding_value
+ masks_obj.all_attention_mask = full_attention_mask_padded.bool()
+ else:
+ # Get "real" attention masks
+ # We can use select to avoid batch-size problems
+ _td = torch.ones_like(
+ out.select(("tokens", "full"))
+ .copy()
+ .rename_key_(("tokens", "full"), "all_attention_mask")
+ ).bool()
+ del _td["tokens"]
+ masks_obj.update(_td)
+ masks_obj.all_assistant_mask = None
+ masks_obj.padded = MetaData(self.pad_output)
+ out.set(self.masks_key, masks_obj)
+
+ if self.return_log_probs:
+ if self.pad_output:
+ log_probs_padded = outputs.get(
+ "logprobs",
+ as_padded_tensor=self.pad_output,
+ padding_value=self.padding_value,
+ padding_side="right",
+ )
+ else:
+ log_probs_list = outputs.get(
+ "logprobs",
+ as_list=True,
+ )
+ self._check_not_padded(log_probs_list)
+ if self.num_samples is None:
+ # TODO: this is not correct, we should use the prompt_logprobs
+ # but they're not returned by vLLM
+ if self.pad_output:
+ prompt_logprobs_padded = request_output_tc.get(
+ "prompt_logprobs",
+ as_padded_tensor=self.pad_output,
+ padding_value=self.padding_value,
+ padding_side="right",
+ )
+ else:
+ prompt_logprobs_list = request_output_tc.get(
+ "prompt_logprobs",
+ as_list=True,
+ )
+ self._check_not_padded(prompt_logprobs_list)
+ log_probs_obj = LogProbs._from_tensordict(out.empty())
+ if self.pad_output:
+ self._check_padded(log_probs_padded)
+ if self.num_samples is None:
+ self._check_padded(prompt_logprobs_padded)
+ log_probs_obj.prompt = prompt_logprobs_padded
+ else:
+ self._check_not_padded(log_probs_list)
+ if self.num_samples is None:
+ self._check_not_padded(prompt_logprobs_list)
+ log_probs_obj.prompt = prompt_logprobs_list
+ with log_probs_obj.view(-1) as log_probs_obj_flat:
+ log_probs_obj_flat.response = (
+ log_probs_padded if self.pad_output else log_probs_list
+ )
+ if self.num_samples is None:
+ if self.pad_output:
+ log_probs_obj_flat.full = self._cat_tensors(
+ log_probs_obj_flat.prompt, log_probs_padded
+ )
+ else:
+ log_probs_obj_flat.full = self._cat_tensors(
+ log_probs_obj_flat.get("prompt", as_list=True),
+ log_probs_list,
+ )
+ else:
+ log_probs_obj_flat.full = None
+ log_probs_obj.padded = MetaData(self.pad_output)
+ out.set(self.log_probs_key, log_probs_obj)
+ return out
+
+ def _logprobs_from_tokens(
+ self,
+ *,
+ response_struct: TensorDictBase | None = None,
+ tokens_full_unpadded: list[torch.Tensor] | None = None,
+ tokens_full_padded: torch.Tensor | None = None,
+ sampling_params: SamplingParams | None = None,
+ out: TensorDictBase | None = None,
+ ) -> TensorDictBase:
+ """Compute log-probs from tokens input."""
+ assert isinstance(
+ response_struct, (TensorDictBase, type(None))
+ ), f"response_struct must be TensorDictBase or None, got {type(response_struct)}"
+ assert isinstance(
+ tokens_full_unpadded, (list, type(None))
+ ), f"tokens_full_unpadded must be list or None, got {type(tokens_full_unpadded)}"
+ assert isinstance(
+ tokens_full_padded, (torch.Tensor, type(None))
+ ), f"tokens_full_padded must be torch.Tensor or None, got {type(tokens_full_padded)}"
+ assert isinstance(
+ sampling_params, (SamplingParams, type(None))
+ ), f"sampling_params must be SamplingParams or None, got {type(sampling_params)}"
+ assert isinstance(
+ out, (TensorDictBase, type(None))
+ ), f"out must be TensorDictBase or None, got {type(out)}"
+
+ # Convert to list format for vLLM
+ if response_struct is not None:
+ tokens_full_padded = response_struct.get(
+ "input_ids",
+ as_padded_tensor=True,
+ padding_value=self.padding_value,
+ padding_side="left",
+ )
+ attention_mask_full_padded = response_struct.get(
+ "attention_mask",
+ as_padded_tensor=True,
+ padding_value=False,
+ padding_side="left",
+ ).bool()
+ attention_mask_full_unpadded = _unpad_tensors(
+ attention_mask_full_padded, attention_mask_full_padded, as_nested=False
+ )
+ elif tokens_full_unpadded is not None:
+ tokens_full_padded = pad_sequence(
+ tokens_full_unpadded,
+ padding_value=self.padding_value,
+ batch_first=True,
+ padding_side="left",
+ )
+ attention_mask_full_unpadded = [
+ t != self.padding_value for t in tokens_full_unpadded
+ ]
+ attention_mask_full_padded = pad_sequence(
+ attention_mask_full_unpadded,
+ padding_value=False,
+ batch_first=True,
+ padding_side="left",
+ )
+ elif tokens_full_padded is not None:
+ attention_mask_full_padded = tokens_full_padded != self.padding_value
+ else:
+ raise ValueError("Either response_struct or tokens must be provided")
+
+ assert isinstance(tokens_full_padded, torch.Tensor)
+ assert isinstance(attention_mask_full_padded, torch.Tensor)
+ if tokens_full_unpadded is None:
+ tokens_full_list = self._to_list(
+ tokens_full_padded, attention_mask_full_padded
+ )
+ else:
+ tokens_full_list = self._to_list(tokens_full_unpadded, None)
+
+ generate_kwargs = {
"sampling_params": sampling_params,
- "prompt_token_ids": input_ids_list,
+ "prompt_token_ids": tokens_full_list,
}
+
+ # Generate with vLLM to get prompt_logprobs
if not self._remote_calls:
- tokens_out = self.model.generate(*args, **kwargs)
+            tokens_out_struct = self.model.generate(**generate_kwargs)
else:
import ray
- tokens_out = ray.get(self.model.generate.remote(*args, **kwargs))
- tokens_out = _RequestOutput_tc.from_request_output(tokens_out)
- prompt_logprobs = tokens_out.prompt_logprobs
- prompt_logprobs = prompt_logprobs[..., -tokens_response.shape[-1] :]
- padded = tokens_response == self.padding_value
- prompt_logprobs = torch.where(~padded, prompt_logprobs, 0.0)
- out = out.update(tokens_out._tensordict.empty(recurse=True))
- out.set(self.log_prob_key, prompt_logprobs)
- out.set(self.token_response_key, tokens_response)
- inputs = td.select(*self.in_keys, strict=False)
- if inputs.ndim < out.ndim:
- # This happens when n > 1
- inputs = inputs.unsqueeze(-1).expand(out.shape)
- out.update(inputs)
- return out
+            tokens_out_struct = ray.get(self.model.generate.remote(**generate_kwargs))
+
+        request_output_tc = _RequestOutput_tc.from_request_output(tokens_out_struct)
- def _get_output_tokens_and_log_probs(self, tokens_out):
- padding_value = self.padding_value
- tokens_out = _RequestOutput_tc.from_request_output(tokens_out)
+ # Extract log-probs from prompt_logprobs
+ if self.pad_output:
+ # For padded case, use all prompt_logprobs
+ log_probs_full_padded = request_output_tc.get(
+ "prompt_logprobs",
+ as_padded_tensor=True,
+ padding_value=0,
+ padding_side="left",
+ )
+
+ # Mask out padding
+ attention_mask_full_padded = tokens_full_padded != self.padding_value
+ log_probs_full_padded = torch.where(
+ attention_mask_full_padded, log_probs_full_padded, 0.0
+ )
+ else:
+ # For unpadded case, extract from each sequence
+ log_probs_full_unpadded = request_output_tc.get(
+ "prompt_logprobs", as_list=True
+ )
+ self._check_not_padded(log_probs_full_unpadded)
+
+ assistant_mask_full_padded = None
+ if response_struct is not None:
+ assistant_mask_full_padded = response_struct.get(
+ "assistant_masks",
+ as_padded_tensor=True,
+ padding_side="left",
+ padding_value=0,
+ )
+ if assistant_mask_full_padded is not None:
+ assistant_mask_full_padded = assistant_mask_full_padded.bool()
+ if not self.pad_output:
+ assistant_mask_full_unpadded = _unpad_tensors(
+ assistant_mask_full_padded,
+ attention_mask_full_padded,
+ as_nested=False,
+ )
+ else:
+ assistant_mask_full_unpadded = None
+ else:
+ assistant_mask_full_unpadded = None
- # When not generate, we don't want to overwrite this
- tokens_response_td = tokens_out.outputs._tensordict.select(
- "text", "token_ids", "logprobs", strict=False
+ masks_obj = Masks._from_tensordict(
+ TensorDict(batch_size=out.batch_size).to_lazystack(0)
)
if self.pad_output:
- tokens_response_td = tokens_response_td.densify(
- layout=torch.strided
- ).to_padded_tensor(padding=padding_value)
- tokens_response_td.rename_key_("token_ids", "tokens_response")
- tokens_response_td.rename_key_("text", "text_response")
- if not self.pad_output:
- # Then we can safely move the input tokens, but otherwise they
- # may need padding
- tokens_out = tokens_out.select("prompt_token_ids")
- if tokens_out.ndim < tokens_response_td.ndim:
- tokens_out = tokens_out.unsqueeze(1).expand(tokens_response_td.shape)
- tokens_response_td.update(tokens_out).rename_key_(
- "prompt_token_ids", self.token_key
- )
-
- if self.return_log_probs or "logprobs" in tokens_response_td:
- tokens_response_td.rename_key_("logprobs", self.log_prob_key)
- if self.pad_output:
- padded_values = tokens_response_td["tokens_response"] == padding_value
- if padded_values.any():
- lps = tokens_response_td[self.log_prob_key]
- lps = torch.where(expand_as_right(~padded_values, lps), lps, 0.0)
- tokens_response_td[self.log_prob_key] = lps
- return tokens_response_td
-
- def _to_list(self, tokens, attention_mask):
- """Converts a tensor of integer in a masked list (of lists) of integers."""
- if isinstance(tokens, torch.Tensor):
- # TODO: make this an ND NonTensorStack
+ self._check_padded(attention_mask_full_padded)
+ masks_obj.all_attention_mask = attention_mask_full_padded.bool()
+ if assistant_mask_full_padded is not None:
+ masks_obj.all_assistant_mask = assistant_mask_full_padded
+ else:
+ self._check_not_padded(attention_mask_full_unpadded)
+ masks_obj.all_attention_mask = attention_mask_full_unpadded
+ if assistant_mask_full_unpadded is not None:
+ masks_obj.all_assistant_mask = assistant_mask_full_unpadded
+ masks_obj.padded = MetaData(self.pad_output)
+ out.set(self.masks_key, masks_obj)
+
+ tokens_obj = Tokens._from_tensordict(
+ TensorDict(batch_size=out.batch_size).to_lazystack(0)
+ )
+ if self.pad_output:
+ self._check_padded(tokens_full_padded)
+ tokens_obj.full = tokens_full_padded
+ else:
+ tokens_obj.full = tokens_full_unpadded
+ tokens_obj.response = None
+ tokens_obj.padded = MetaData(self.pad_output)
+ out.set(self.tokens_key, tokens_obj)
+
+ log_probs_obj = LogProbs._from_tensordict(
+ TensorDict(batch_size=out.batch_size).to_lazystack(0)
+ )
+ if self.pad_output:
+ self._check_padded(log_probs_full_padded)
+ log_probs_obj.full = log_probs_full_padded
+ else:
+ self._check_not_padded(log_probs_full_unpadded)
+ log_probs_obj.full = log_probs_full_unpadded
+ log_probs_obj.response = None
+ log_probs_obj.padded = MetaData(self.pad_output)
+ out.set(self.log_probs_key, log_probs_obj)
+
+ return out
+
+ def _to_list(
+ self,
+ tokens_padded: torch.Tensor | list[torch.Tensor],
+ attention_mask_padded: torch.Tensor | None,
+ ) -> list[list[int]]:
+ """Converts a tensor of integers into a masked list (of lists) of integers."""
+ if isinstance(tokens_padded, torch.Tensor):
parent = []
queue = collections.deque()
- if attention_mask is None:
- attention_mask = torch.ones_like(tokens)
- queue.append((tokens, attention_mask.bool(), parent))
+ if attention_mask_padded is None:
+ attention_mask_padded = torch.ones_like(tokens_padded)
+ queue.append((tokens_padded, attention_mask_padded.bool(), parent))
while queue:
- token, amask, _parent = queue.popleft()
- if token.ndim == 1:
- _parent.extend(token[amask].tolist())
+ token_tensor, attention_mask_bool, _parent = queue.popleft()
+ if token_tensor.ndim == 1:
+ _parent.extend(token_tensor[attention_mask_bool].tolist())
else:
- _parent.extend([[] for _ in range(token.shape[0])])
+ _parent.extend([[] for _ in range(token_tensor.shape[0])])
queue.extend(
[
(t, m, local_parent)
- for t, m, local_parent in zip(token, amask, _parent)
+ for t, m, local_parent in zip(
+ token_tensor, attention_mask_bool, _parent
+ )
]
)
- tokens = parent
- return tokens
+ tokens_list = parent
+ elif isinstance(tokens_padded, list):
+ parent = []
+ queue = collections.deque()
+ queue.append((tokens_padded, parent))
+ while queue:
+ tokens_list, _parent = queue.popleft()
+ if isinstance(tokens_list, list) and isinstance(
+ tokens_list[0], (list, torch.Tensor)
+ ):
+ _parent.extend([[] for _ in tokens_list])
+ queue.extend(
+ [
+ (t, local_parent)
+ for t, local_parent in zip(tokens_list, _parent)
+ ]
+ )
+ continue
+ elif isinstance(tokens_list, torch.Tensor):
+ tokens_list = tokens_list.tolist()
+ _parent.extend(tokens_list)
+ tokens_list = parent
+
+ return tokens_list
@_classproperty
def CompletionOutput_tc(cls):
- import vllm
+ if vllm is None:
+ raise ImportError("vllm is required for CompletionOutput_tc")
if hasattr(cls, "_CompletionOutput_tc"):
return cls._CompletionOutput_tc
- CompletionOutput_tc = from_dataclass(vllm.outputs.CompletionOutput)
+ CompletionOutput_tc = from_dataclass(vllm.outputs.CompletionOutput) # type: ignore
cls._CompletionOutput_tc = CompletionOutput_tc
return CompletionOutput_tc
+ def get_dist(
+ self,
+ tensordict: TensorDictBase,
+ tensordict_out: TensorDictBase | None = None,
+ logits_key: NestedKey = "logits",
+ mask_key: NestedKey | None = None,
+ as_padded_tensor: bool | None = None,
+ as_nested_tensor: bool | None = None,
+ padding_value: float | None = None,
+ padding_side: str = "right",
+ layout: torch.layout | None = None,
+ **kwargs,
+ ) -> D.Distribution:
+ """Get distribution from logits/log-probs with optional masking.
+
+ vLLM does not return logits, so this method is not supported.
+ """
+ raise NotImplementedError(
+ "vLLM does not return logits, so get_dist is not supported"
+ )
+
+ def get_dist_with_prompt_mask(
+ self,
+ tensordict: TensorDictBase,
+ tokens_key: NestedKey = ("tokens", "full"),
+ logits_key: NestedKey = "logits",
+ assistant_mask_key: NestedKey = ("masks", "all_assistant_mask"),
+ attention_mask_key: NestedKey = ("masks", "all_attention_mask"),
+ **kwargs,
+ ) -> D.Distribution:
+ """Get distribution masked to only include response tokens (exclude prompt).
+
+ vLLM does not return logits, so this method is not supported.
+
+ This is a provisional method that will be replaced by the `get_dist` method once we have a better masking strategy.
+ """
+ raise NotImplementedError(
+ "vLLM does not return logits, so get_dist_with_prompt_mask is not supported"
+ )
+
+ def _get_dist_with_assistant_mask(
+ self,
+ tensordict: TensorDictBase,
+ assistant_mask_key: NestedKey = ("masks", "all_assistant_mask"),
+ logits_key: NestedKey = "logits",
+ **kwargs,
+ ) -> D.Distribution:
+ """Get distribution masked to only include assistant tokens.
+
+ vLLM does not return logits, so this method is not supported.
+
+ This is a provisional method that will be replaced by the `get_dist` method once we have a better masking strategy.
+ """
+ raise NotImplementedError(
+ "vLLM does not return logits, so get_dist_with_assistant_mask is not supported"
+ )
+
+ def _get_dist_with_attention_mask(
+ self,
+ tensordict: TensorDictBase,
+ attention_mask_key: NestedKey = ("masks", "all_attention_mask"),
+ logits_key: NestedKey = "logits",
+ **kwargs,
+ ) -> D.Distribution:
+ """Get distribution masked using attention mask.
+
+ vLLM does not return logits, so this method is not supported.
+
+ This is a provisional method that will be replaced by the `get_dist` method once we have a better masking strategy.
+ """
+ raise NotImplementedError(
+ "vLLM does not return logits, so get_dist_with_attention_mask is not supported"
+ )
+
+ def _get_dist_with_custom_mask(
+ self,
+ tensordict: TensorDictBase,
+ mask: torch.Tensor,
+ logits_key: NestedKey = "logits",
+ **kwargs,
+ ) -> D.Distribution:
+ """Get distribution with custom mask.
+
+ vLLM does not return logits, so this method is not supported.
+
+ This is a provisional method that will be replaced by the `get_dist` method once we have a better masking strategy.
+ """
+ raise NotImplementedError(
+ "vLLM does not return logits, so get_dist_with_custom_mask is not supported"
+ )
+
+ # Convenience methods for common LLM training scenarios
+ def _get_sft_dist(self, tensordict: TensorDictBase, **kwargs) -> D.Distribution:
+ """Get distribution suitable for SFT loss (response tokens only).
+
+ vLLM does not return logits, so this method is not supported.
+
+ This is a provisional method that will be replaced by the `get_dist` method once we have a better masking strategy.
+ """
+ raise NotImplementedError(
+ "vLLM does not return logits, so get_sft_dist is not supported"
+ )
+
+ def _get_rlhf_dist(self, tensordict: TensorDictBase, **kwargs) -> D.Distribution:
+ """Get distribution suitable for RLHF loss (assistant tokens only).
+
+ vLLM does not return logits, so this method is not supported.
+
+ This is a provisional method that will be replaced by the `get_dist` method once we have a better masking strategy.
+ """
+ raise NotImplementedError(
+ "vLLM does not return logits, so get_rlhf_dist is not supported"
+ )
+
+ def _get_generic_dist(self, tensordict: TensorDictBase, **kwargs) -> D.Distribution:
+ """Get distribution suitable for generic losses (all tokens).
+
+ vLLM does not return logits, so this method is not supported.
+
+ This is a provisional method that will be replaced by the `get_dist` method once we have a better masking strategy.
+ """
+ raise NotImplementedError(
+ "vLLM does not return logits, so get_generic_dist is not supported"
+ )
+
class _RequestOutput_tc(TensorClass["nocast"]):
+ """TensorClass wrapper for vLLM RequestOutput."""
+
request_id: str
prompt: str
- prompt_token_ids: str
- prompt_logprobs: str
- outputs: str
+ prompt_token_ids: torch.Tensor
+ prompt_logprobs: torch.Tensor
+ outputs: list # type: ignore
finished: str
metrics: str
lora_request: str
encoder_prompt: str
encoder_prompt_token_ids: str
- num_cached_tokens: str
+ num_cached_tokens: torch.Tensor
def __post_init__(self):
CompletionOutput_tc = vLLMWrapper.CompletionOutput_tc
@@ -703,37 +1812,77 @@ def get_logprob(output):
if len(outputs) == 1:
self.outputs = outputs[0]
else:
- self.outputs = maybe_dense_stack(outputs)
- if self.prompt_logprobs is not None:
- self.prompt_logprobs = torch.tensor(
- [
- v[int(tid)].logprob if v is not None else 0.0
- for v, tid in _zip_strict(
- self.prompt_logprobs, self.prompt_token_ids
- )
- ]
- )
- self.prompt_token_ids = torch.as_tensor(self.prompt_token_ids)
- self.num_cached_tokens = torch.as_tensor(self.num_cached_tokens)
+ # Check if we can stack the outputs (they should have the same shape)
+ try:
+ self.outputs = lazy_stack(outputs)
+ except RuntimeError:
+ # If stacking fails (different sizes), keep as list
+ self.outputs = outputs
@classmethod
- def from_request_output(cls, requests):
- out = lazy_stack(
- [
+ def from_request_output(
+ cls, requests: RequestOutput | list[RequestOutput]
+ ) -> _RequestOutput_tc | list[_RequestOutput_tc]:
+ """Create _RequestOutput_tc from vLLM RequestOutput."""
+ # Type assertions
+ assert isinstance(
+ requests, (RequestOutput, list)
+ ), f"requests must be RequestOutput or list, got {type(requests)}"
+
+ # Check if we can stack the outputs
+ try:
+ out = lazy_stack(
+ [
+ cls(
+ request_id=request.request_id,
+ prompt=request.prompt,
+ prompt_token_ids=torch.as_tensor(request.prompt_token_ids),
+ prompt_logprobs=torch.tensor(
+ [
+ v[int(tid)].logprob if v is not None else 0.0
+ for v, tid in _zip_strict(
+ request.prompt_logprobs, request.prompt_token_ids
+ )
+ ]
+ )
+ if request.prompt_logprobs is not None
+ else torch.tensor([]),
+ outputs=request.outputs,
+ finished=request.finished,
+ metrics=request.metrics,
+ lora_request=request.lora_request,
+ encoder_prompt=request.encoder_prompt,
+ encoder_prompt_token_ids=request.encoder_prompt_token_ids,
+ num_cached_tokens=torch.as_tensor(request.num_cached_tokens),
+ )
+ for request in requests
+ ]
+ )
+ return out
+ except RuntimeError:
+ # If stacking fails, return a list of individual _RequestOutput_tc objects
+ return [
cls(
request_id=request.request_id,
prompt=request.prompt,
- prompt_token_ids=request.prompt_token_ids,
- prompt_logprobs=request.prompt_logprobs,
+ prompt_token_ids=torch.as_tensor(request.prompt_token_ids),
+ prompt_logprobs=torch.tensor(
+ [
+ v[int(tid)].logprob if v is not None else 0.0
+ for v, tid in _zip_strict(
+ request.prompt_logprobs, request.prompt_token_ids
+ )
+ ]
+ )
+ if request.prompt_logprobs is not None
+ else torch.tensor([]),
outputs=request.outputs,
finished=request.finished,
metrics=request.metrics,
lora_request=request.lora_request,
encoder_prompt=request.encoder_prompt,
encoder_prompt_token_ids=request.encoder_prompt_token_ids,
- num_cached_tokens=request.num_cached_tokens,
+ num_cached_tokens=torch.as_tensor(request.num_cached_tokens),
)
for request in requests
]
- )
- return out
diff --git a/torchrl/objectives/llm/grpo.py b/torchrl/objectives/llm/grpo.py
index 376b2e8cedd..b415c08c6c7 100644
--- a/torchrl/objectives/llm/grpo.py
+++ b/torchrl/objectives/llm/grpo.py
@@ -5,6 +5,8 @@
from __future__ import annotations
from collections import defaultdict, deque
+from dataclasses import dataclass
+from typing import Literal
import torch
from tensordict import (
@@ -16,16 +18,17 @@
TensorDictParams,
)
from tensordict.nn import (
+ ProbabilisticTensorDictModule,
ProbabilisticTensorDictSequential,
TensorDictModule,
- TensorDictModuleBase,
)
+from tensordict.utils import expand_as_right
from torch import distributions as d
-
from torchrl._utils import logger as torchrl_logger
from torchrl.envs.transforms.transforms import Transform
+from torchrl.modules.llm import LLMWrapperBase
from torchrl.objectives.ppo import ClipPPOLoss
-from torchrl.objectives.utils import _maybe_get_or_select, _reduce, _sum_td_features
+from torchrl.objectives.utils import _reduce, _sum_td_features
class GRPOLossOutput(TensorClass["nocast"]):
@@ -50,7 +53,7 @@ class GRPOLoss(ClipPPOLoss):
loss = -min( weight * advantage, min(max(weight, 1-eps), 1+eps) * advantage)
Args:
- actor_network (ProbabilisticTensorDictSequential): policy operator.
+ actor_network (LLMWrapperBase): policy operator.
.. note::
It is critical to keep your model in eval mode during GRPO training to ensure deterministic behavior and correct
@@ -63,6 +66,15 @@ class GRPOLoss(ClipPPOLoss):
A value of 1 indicates that all importance weights are equal (ideal case). If ESS drops or increases significantly,
it usually indicates a problem with the model configuration, such as a train/eval mode mismatch or a large policy update.
+ .. note::
+ The masking_strategy parameter is crucial for LLM training scenarios. It determines which tokens are included
+        in the loss computation:
+
+        - "sft": Only response tokens (excludes prompt tokens) - suitable for single-turn conversations
+ - "rlhf": Only assistant tokens (excludes user/system tokens) - suitable for multi-turn conversations
+ - "generic": All valid tokens (excludes padding tokens) - suitable for generic scenarios
+
+ The masking strategy must match the strategy used for advantage computation to avoid shape mismatches.
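+
+        A minimal sketch (assuming ``policy`` is an ``LLMWrapperBase`` instance)::
+
+            loss_fn = GRPOLoss(actor_network=policy, masking_strategy="rlhf")  # multi-turn conversations
+            # use masking_strategy="sft" for single-turn prompt/response data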
+
Keyword Args:
clip_epsilon (scalar, optional): weight clipping threshold in the clipped PPO loss equation.
default: 0.2
@@ -93,34 +105,48 @@ class GRPOLoss(ClipPPOLoss):
kl_to_ref_coeff (float, optional): coefficient for the KL divergence to the reference policy. Defaults to ``None`` (no KL divergence).
kl_to_inference_coeff (float, optional): coefficient for the KL divergence to the inference policy. Defaults to ``None`` (no KL divergence).
device (torch.device, optional): device of the buffers. Defaults to ``None``.
+ masking_strategy (Literal["sft", "rlhf", "generic"], optional): The masking strategy to use for distribution creation.
+ - "sft": Use prompt masking (response tokens only, suitable for single-turn)
+ - "rlhf": Use assistant masking (assistant tokens only, suitable for multi-turn)
+ - "generic": Use attention masking (all valid tokens)
+ Defaults to "sft" since we can't guarantee assistant masks are available.
.. note:: Parameters and buffers from the policy / critic will not be cast to that device to ensure that
the storages match the ones that are passed to other components, such as data collectors.
"""
- actor_network: TensorDictModule
+ actor_network: LLMWrapperBase
critic_network: TensorDictModule
actor_network_params: TensorDictParams
critic_network_params: TensorDictParams
target_actor_network_params: TensorDictParams
target_critic_network_params: TensorDictParams
+ @dataclass
+ class _AcceptedKeys(ClipPPOLoss._AcceptedKeys):
+ """Maintains default values for all configurable tensordict keys.
+
+ This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
+        default values.
+ """
+
+ ref_log_probs: NestedKey = ("next", "ref_log_probs", "full")
+
def __init__(
self,
- actor_network: ProbabilisticTensorDictSequential
- | TensorDictModuleBase
- | None = None,
+ actor_network: LLMWrapperBase | None = None,
*,
clip_epsilon: float = 0.2,
entropy_bonus: bool = True,
samples_mc_entropy: int = 1,
entropy_coeff: float = 0.01,
gamma: float | None = None,
- reduction: str = None,
+ reduction: str | None = None,
clip_value: bool | float | None = None,
kl_to_ref_coeff: float | None = None,
kl_to_inference_coeff: float | None = None,
- device: torch.device = None,
+ device: torch.device | None = None,
+ masking_strategy: Literal["sft", "rlhf", "generic"] = "sft",
**kwargs,
):
# Define clipping of the value loss
@@ -143,12 +169,77 @@ def __init__(
)
# We don't want to use the string action but the tokens
self._set_in_keys()
- self.set_keys(sample_log_prob="log_probs", action="tokens_response")
+ self.masking_strategy = masking_strategy
+ # Always use the full tokens for the action
+ self.set_keys(sample_log_prob=("log_probs", "full"), action=("tokens", "full"))
# TODO: make this a buffer
self.kl_to_ref_coeff = kl_to_ref_coeff
self.kl_to_inference_coeff = kl_to_inference_coeff
+ def _get_cur_log_prob(self, tensordict):
+ """Override to use LLM-specific distribution with explicit masking strategy.
+
+ This ensures that the loss is computed with the correct masking strategy,
+ and provides helpful error messages when there are shape mismatches.
+ """
+ if isinstance(
+ self.actor_network,
+ (ProbabilisticTensorDictSequential, ProbabilisticTensorDictModule),
+ ) or hasattr(self.actor_network, "get_dist"):
+ # Use the specified masking strategy
+ # dists are always defined over the whole sequence, so we can re-use the mask as the dist will always
+ # be a MaskedCategorical
+ # TODO: eventually, we want to always use `get_dist` and just pass the key of the mask
+ # Masks should contain: prompt and response masks, assistant, and attention.
+ # Additionally, we should make sure that the masks are properly updated when log-probs is called (using vllm and transformers)
+ # because in some instances it looks like they can be overwritten with None values.
+ if self.masking_strategy == "sft" and hasattr(
+ self.actor_network, "_get_sft_dist"
+ ):
+ dist = self.actor_network._get_sft_dist(tensordict)
+ elif self.masking_strategy == "rlhf" and hasattr(
+ self.actor_network, "_get_rlhf_dist"
+ ):
+ dist = self.actor_network._get_rlhf_dist(tensordict)
+ elif self.masking_strategy == "generic" and hasattr(
+ self.actor_network, "_get_generic_dist"
+ ):
+ dist = self.actor_network._get_generic_dist(tensordict)
+ elif hasattr(self.actor_network, "get_dist"):
+ # Fallback to generic distribution method
+ dist = self.actor_network.get_dist(
+ tensordict,
+ logits_key="logits",
+ )
+ else:
+ raise NotImplementedError(
+ f"Actor network must have get_dist method or the appropriate method for "
+ f"masking strategy '{self.masking_strategy}'."
+ )
+
+ action = tensordict.get(
+ self.tensor_keys.action,
+ as_padded_tensor=True,
+ padding_side="left",
+ padding_value=-100,
+ )
+ log_prob = dist.log_prob(action)
+ else:
+ raise NotImplementedError(
+ "Only probabilistic modules from tensordict.nn are currently supported. "
+ "If you need to implement a custom logic to retrieve the log-probs (to compute "
+ "the PPO objective) or the distribution (for the PPO entropy), please augment "
+ f"the {type(self).__class__} by implementing your own logic in _get_cur_log_prob."
+ )
+ return log_prob, dist, False
+
def forward(self, tensordict: TensorDictBase) -> GRPOLossOutput:
+ # Some sanity checks and housekeeping:
+        # - We may not have the tokens yet. If not, the tokenizer of the actor could be used to tokenize the text;
+        #   we default to history rather than text because the history accounts for multi-turn or multimodal inputs.
+        #   For now, the action (full tokens) must already be present in the tensordict.
+        if self.tensor_keys.action not in tensordict:
+            raise ValueError(
+                f"Couldn't find the action key {self.tensor_keys.action} in the input tensordict."
+            )
+
tensordict = tensordict.copy()
advantage = tensordict.get(
self.tensor_keys.advantage, None, as_padded_tensor=True
@@ -156,15 +247,20 @@ def forward(self, tensordict: TensorDictBase) -> GRPOLossOutput:
log_weight, dist, kl_approx = self._log_weight(
tensordict, adv_shape=advantage.shape[:-1]
)
+ mask = dist.mask
# ESS for logging
with torch.no_grad():
# In theory, ESS should be computed on particles sampled from the same source. Here we sample according
# to different, unrelated trajectories, which is not standard. Still, it can give an idea of the weights'
# dispersion.
- lw = log_weight.squeeze()
+ lw = log_weight.squeeze(-1)[mask]
+ batch = mask.sum()
ess = (2 * lw.logsumexp(0) - (2 * lw).logsumexp(0)).exp()
- batch = log_weight.shape[0]
+ if advantage.ndim != log_weight.ndim:
+ raise ValueError(
+ f"advantage and log_weight must have the same number of dimensions, got {advantage.ndim=} and {log_weight.ndim=}"
+ )
gain1 = log_weight.exp() * advantage
log_weight_clip = log_weight.clamp(*self._clip_bounds)
@@ -191,14 +287,27 @@ def forward(self, tensordict: TensorDictBase) -> GRPOLossOutput:
if value_clip_fraction is not None:
td_out.set("value_clip_fraction", value_clip_fraction)
- td_out.set("ESS", _reduce(ess, self.reduction) / batch)
+ td_out.set("ESS", _reduce(ess / batch, self.reduction))
td_out = td_out.named_apply(
- lambda name, value: _reduce(value, reduction=self.reduction).squeeze(-1)
+ lambda name, value: _reduce(
+ value, reduction=self.reduction, mask=mask
+ ).squeeze(-1)
if name.startswith("loss_")
else value,
)
if self.kl_to_ref_coeff is not None:
- loss_kl, kl_penalty = self._kl_to_ref(tensordict)
+ # FIXME: parameterize this
+ loss_kl, kl_penalty = self._kl_to_ref(
+ tensordict,
+ mask=mask,
+ dist=dist,
+ ref_log_prob=tensordict.get(
+ self.tensor_keys.ref_log_probs,
+ as_padded_tensor=True,
+ padding_side="left",
+ padding_value=0.0,
+ ),
+ )
td_out["loss_kl_to_ref"] = loss_kl
td_out["kl_to_ref"] = kl_penalty.detach()
if self.kl_to_inference_coeff is not None:
@@ -206,6 +315,8 @@ def forward(self, tensordict: TensorDictBase) -> GRPOLossOutput:
tensordict,
key=self.tensor_keys.sample_log_prob,
coeff=self.kl_to_inference_coeff,
+ mask=mask,
+ dist=dist,
)
td_out["loss_kl_to_inference"] = loss_kl
td_out["kl_to_inference"] = kl_penalty.detach()
@@ -218,6 +329,8 @@ def _kl_to_ref(
key: NestedKey = ("next", "ref_log_prob"),
ref_log_prob: torch.Tensor | None = None,
coeff: float | None = None,
+ mask: torch.Tensor | None = None,
+ dist: d.Distribution | None = None,
):
if coeff is None:
coeff = self.kl_to_ref_coeff
@@ -226,16 +339,27 @@ def _kl_to_ref(
ref_log_prob = tensordict.get(
key,
as_padded_tensor=True,
- ).squeeze(-1)
+ padding_side="left",
+ padding_value=0.0,
+ )
+ if ref_log_prob is None:
+ raise KeyError(
+ f"Couldn't find the ref log-prob {key} in the input data ({tensordict.keys(True)=})."
+ )
+ ref_log_prob = ref_log_prob.squeeze(-1)
cur_log_prob = tensordict.get("_cur_log_prob")
# TODO: remove this
- assert cur_log_prob.shape == ref_log_prob.shape, (
- cur_log_prob.shape,
- ref_log_prob.shape,
- )
- mask = cur_log_prob != 0
- ref_log_prob = ref_log_prob[mask]
- cur_log_prob = cur_log_prob[mask]
+ if cur_log_prob.shape != ref_log_prob.shape:
+ raise ValueError(
+ f"cur_log_prob and ref_log_prob must have the same shape, got {cur_log_prob.shape=} and {ref_log_prob.shape=}"
+ )
+ if mask is not None:
+ ref_log_prob = torch.where(
+ expand_as_right(mask, ref_log_prob), ref_log_prob, 0.0
+ )
+ cur_log_prob = torch.where(
+ expand_as_right(mask, cur_log_prob), cur_log_prob, 0.0
+ )
diff = ref_log_prob - cur_log_prob
kl_penalty = (diff.expm1() - diff).mean()
return coeff * kl_penalty, kl_penalty
@@ -244,12 +368,15 @@ def _log_weight(
self, tensordict: TensorDictBase, adv_shape: torch.Size
) -> tuple[torch.Tensor, d.Distribution, torch.Tensor]:
- prev_log_prob = _maybe_get_or_select(
- tensordict,
+ cur_log_prob, dist, is_composite = self._get_cur_log_prob(tensordict)
+
+ prev_log_prob = tensordict.get(
self.tensor_keys.sample_log_prob,
- adv_shape,
+ as_padded_tensor=True,
+ padding_side="left",
+ padding_value=0.0,
)
- padding_mask = prev_log_prob != 0
+
if prev_log_prob is None:
raise KeyError(
f"Couldn't find the log-prob {self.tensor_keys.sample_log_prob} in the input data."
@@ -259,8 +386,30 @@ def _log_weight(
f"tensordict stored {self.tensor_keys.sample_log_prob} requires grad."
)
- cur_log_prob, dist, is_composite = self._get_cur_log_prob(tensordict)
- cur_log_prob = torch.where(padding_mask, cur_log_prob, 0.0)
+ # Check for shape mismatches and provide helpful error messages
+ if cur_log_prob.shape != prev_log_prob.shape:
+ # Try to provide helpful debugging information
+ error_msg = (
+ f"Shape mismatch detected in GRPOLoss: current log-prob shape {cur_log_prob.shape} "
+ f"!= previous log-prob shape {prev_log_prob.shape}. "
+ f"This usually indicates a mismatch between the masking strategy used for "
+ f"advantage computation and the masking strategy used for loss computation.\n"
+ f"Current masking strategy: '{self.masking_strategy}'\n"
+ f"Possible solutions:\n"
+ f"1. If using RLHF (multi-turn conversations), set masking_strategy='rlhf'\n"
+ f"2. If using SFT (single-turn conversations), set masking_strategy='sft'\n"
+ f"3. If using generic scenarios, set masking_strategy='generic'\n"
+ f"4. Ensure the advantage was computed with the same masking strategy as the loss"
+ )
+ raise ValueError(error_msg)
+
+ attention_mask = dist.mask
+ cur_log_prob = torch.where(
+ expand_as_right(attention_mask, cur_log_prob), cur_log_prob, 0.0
+ )
+ prev_log_prob = torch.where(
+ expand_as_right(attention_mask, prev_log_prob), prev_log_prob, 0.0
+ )
if is_composite:
raise NotImplementedError
@@ -295,7 +444,7 @@ class MCAdvantage(Transform):
Args:
grpo_size (int): Number of trajectories to keep in memory for the advantage computation.
- prompt_key (NestedKey): Key to the prompt in the tensordict. Defaults to "text".
+ prompt_key (NestedKey): Key to the prompt in the tensordict. Defaults to ("text", "prompt").
rewards_key (NestedKey): Key to the rewards in the tensordict. Defaults to ("next", "reward").
advantage_key (NestedKey): Key to the advantage in the tensordict. Defaults to "advantage".
done_key (NestedKey): Key to the done state in the tensordict. Defaults to ("next", "done").
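+
+    A minimal usage sketch (an assumption for illustration: the transform is appended to a replay buffer that
+    receives complete trajectories; ``ReplayBuffer`` and ``ListStorage`` are standard TorchRL components)::
+
+        from torchrl.data import ReplayBuffer, ListStorage
+
+        rb = ReplayBuffer(storage=ListStorage(100))
+        rb.append_transform(MCAdvantage(grpo_size=8))
+        # advantages are computed once grpo_size trajectories sharing the same prompt have been written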
@@ -306,7 +455,7 @@ class MCAdvantage(Transform):
def __init__(
self,
grpo_size: int,
- prompt_key: NestedKey = "text",
+ prompt_key: NestedKey = "query",
rewards_key: NestedKey = ("next", "reward"),
advantage_key: NestedKey = "advantage",
done_key: NestedKey = ("next", "done"),
@@ -327,6 +476,10 @@ def forward(self, tensordict: TensorDictBase) -> GRPOLossOutput:
return tensordict
def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
+ if self.verbose:
+ torchrl_logger.info(
+ f"Invoking MCAdvantage.\nData size: {tensordict.shape}.\nCurrent queue size: {len(self.queues)}.\nTotal queue content: {sum(len(q) for q in self.queues.values())}"
+ )
# Tensordict can be any number of dims, but it must contain entire trajectories
if tensordict.ndim == 1:
# Check how many done states we have
@@ -350,6 +503,7 @@ def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
torchrl_logger.info(f"Computing advantage for {prompt=}")
# Cat is the most robust way to combine the trajs
tds = torch.cat(list(self.queues[prompt]), -1)
+ del self.queues[prompt]
# Collect rewards
reward = tds.get(self.rewards_key, as_nested_tensor=True)
reward_mean = reward.values().mean()
@@ -363,7 +517,7 @@ def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
elif tensordict.ndim > 2:
# keep the time dim at the end
tensordict = tensordict.flatten(0, -2)
- trajs = tensordict.unbind(-1)
+ trajs = tensordict.unbind(0)
# Iterate over the trajectories
result = []
for traj in trajs:
@@ -372,5 +526,5 @@ def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
continue
result.append(td_out)
if result:
- return torch.cat(result, -1)
+ return torch.cat(result, 0)
return
diff --git a/torchrl/objectives/llm/sft.py b/torchrl/objectives/llm/sft.py
index 7bc83256bf1..1b2568c31f6 100644
--- a/torchrl/objectives/llm/sft.py
+++ b/torchrl/objectives/llm/sft.py
@@ -246,9 +246,9 @@ class _AcceptedKeys:
Defaults to ``"log_probs"``.
"""
- history: NestedKey = ("next", "history")
- ref_log_prob: NestedKey = ("next", "ref_log_prob")
- log_probs: NestedKey = "log_probs"
+ history: NestedKey = ("history", "full")
+ ref_log_prob: NestedKey = ("next", "ref_log_prob", "full")
+ log_probs: NestedKey = ("log_probs", "full")
default_keys = _AcceptedKeys
tensor_keys: _AcceptedKeys
@@ -335,23 +335,28 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
# Gather history
history: History = tensordict[self.tensor_keys.history]
- # Apply tokenizer to history and gather mask
- with torch.device(
- self.device
- ) if self.device is not None else contextlib.nullcontext():
- token_struct = history.apply_chat_template(
- tokenizer=self.tokenizer, **self.tokenizer_kwargs
- )
- if "assistant_masks" not in token_struct:
- raise ValueError(
- f"Assistant masks are not present in the token structure: {token_struct=}."
+ # Try to get mask from td
+ token_struct = None
+ assistant_masks = tensordict.get(("masks", "all_assistant_mask"), as_list=True)
+ attention_mask = tensordict.get(("masks", "all_attention_mask"), as_list=True)
+ if assistant_masks is None:
+ # Apply tokenizer to history and gather mask
+ with torch.device(
+ self.device
+ ) if self.device is not None else contextlib.nullcontext():
+ token_struct = history.apply_chat_template(
+ tokenizer=self.tokenizer, **self.tokenizer_kwargs
+ )
+ if "assistant_masks" not in token_struct:
+ raise ValueError(
+ f"Assistant masks are not present in the token structure: {token_struct=}."
+ )
+ assistant_masks = token_struct.get(
+ "assistant_masks",
+ as_list=True,
)
- assistant_masks = token_struct.get(
- "assistant_masks",
- as_list=True,
- )
+ attention_mask = token_struct.get("attention_mask", as_list=True)
assistant_masks = [mask.bool() for mask in assistant_masks]
- attention_mask = token_struct.get("attention_mask", as_list=True)
attention_mask = [mask.bool() for mask in attention_mask]
assistant_masks = [
mask & a_mask for mask, a_mask in zip(assistant_masks, attention_mask)
@@ -359,12 +364,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
if not any(mask.any(-1).all() for mask in assistant_masks):
raise ValueError("Some inputs have no valid assistant masks.")
+
input_loss = tensordict.select(self.tensor_keys.history)
- if (
- isinstance(self.tensor_keys.history, tuple)
- and self.tensor_keys.history[0] == "next"
- ):
- input_loss = input_loss["next"]
with torch.device(
self.device
@@ -376,13 +377,19 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
self.tensor_keys.log_probs,
as_list=True,
)
+
# apply mask
if not all(
mask.shape == lp.shape
for mask, lp in _zip_strict(assistant_masks, log_probs)
):
+ if token_struct is not None:
+ suffix = f"Tokens from current template: {[inp.shape for inp in token_struct.get('input_ids', as_padded_tensor=True)]}"
+ else:
+ suffix = ""
raise ValueError(
- f"Assistant masks and log_probs have different shapes: {[mask.shape for mask in assistant_masks]} vs {[lp.shape for lp in log_probs]}. Tokens from current template: {[inp.shape for inp in token_struct.get('input_ids', as_padded_tensor=True)]}"
+ f"Assistant masks and log_probs have different shapes: {[mask.shape for mask in assistant_masks]} vs "
+ f"{[lp.shape for lp in log_probs]}. {suffix}"
)
log_probs_masked = [
@@ -413,7 +420,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
)
if ref_log_probs is None:
raise ValueError(
- "Reference log probs not found in tensordict but kl_to_ref_coeff was set"
+ f"Reference log probs not found in tensordict at key {self.tensor_keys.ref_log_prob} but kl_to_ref_coeff was set. "
+ f"Existing keys in tensordict: {set(tensordict.keys(include_nested=True, leaves_only=True))}"
)
loss_kl, kl_penalty = self._kl_to_ref(
@@ -431,7 +439,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
ref_log_probs = tensordict.get(self.tensor_keys.ref_log_prob, as_list=True)
if ref_log_probs is None:
raise ValueError(
- f"Reference log probs not found at {self.tensor_keys.ref_log_prob=} in tensordict but loss_function is 'minor_sft'"
+ f"Reference log probs not found at {self.tensor_keys.ref_log_prob=} in tensordict with keys {tensordict.keys()} but loss_function is 'minor_sft'"
)
# we need to re-sum ref_log_probs as they are not summed per-sequence
diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py
index e9de0d7753c..152804512e0 100644
--- a/torchrl/objectives/ppo.py
+++ b/torchrl/objectives/ppo.py
@@ -609,7 +609,6 @@ def _get_cur_log_prob(self, tensordict):
self.actor_network
) if self.functional else contextlib.nullcontext():
dist = self.actor_network.get_dist(tensordict)
-
is_composite = isinstance(dist, CompositeDistribution)
if is_composite:
diff --git a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py
index 59a1897053f..8a95b6979da 100644
--- a/torchrl/objectives/utils.py
+++ b/torchrl/objectives/utils.py
@@ -596,14 +596,31 @@ def new_func(*args, in_dims=in_dims, out_dims=out_dims, **kwargs):
return new_func
-def _reduce(tensor: torch.Tensor, reduction: str) -> float | torch.Tensor:
- """Reduces a tensor given the reduction method."""
+def _reduce(
+ tensor: torch.Tensor, reduction: str, mask: torch.Tensor | None = None
+) -> float | torch.Tensor:
+ """Reduces a tensor given the reduction method.
+
+ Args:
+ tensor (torch.Tensor): The tensor to reduce.
+ reduction (str): The reduction method.
+ mask (torch.Tensor, optional): A mask to apply to the tensor before reducing.
+
+ Returns:
+ float | torch.Tensor: The reduced tensor.
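+
+    Example of the masked path (illustrative values)::
+
+        _reduce(torch.tensor([1.0, 2.0, 3.0]), "mean", mask=torch.tensor([True, False, True]))
+        # -> tensor(2.)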
+ """
if reduction == "none":
result = tensor
elif reduction == "mean":
- result = tensor.mean()
+ if mask is not None:
+ result = tensor[mask].mean()
+ else:
+ result = tensor.mean()
elif reduction == "sum":
- result = tensor.sum()
+ if mask is not None:
+ result = tensor[mask].sum()
+ else:
+ result = tensor.sum()
else:
raise NotImplementedError(f"Unknown reduction method {reduction}")
return result
@@ -668,9 +685,20 @@ def _sum_td_features(data: TensorDictBase) -> torch.Tensor:
return data.sum(dim="feature", reduce=True)
-def _maybe_get_or_select(td, key_or_keys, target_shape=None):
+def _maybe_get_or_select(
+ td,
+ key_or_keys,
+ target_shape=None,
+ padding_side: str = "left",
+ padding_value: int = 0,
+):
if isinstance(key_or_keys, (str, tuple)):
- return td.get(key_or_keys, as_padded_tensor=True)
+ return td.get(
+ key_or_keys,
+ as_padded_tensor=True,
+ padding_side=padding_side,
+ padding_value=padding_value,
+ )
result = td.select(*key_or_keys)
if target_shape is not None and result.shape != target_shape:
result.batch_size = target_shape
diff --git a/tutorials/sphinx-tutorials/llm_wrappers.py b/tutorials/sphinx-tutorials/llm_wrappers.py
new file mode 100644
index 00000000000..fe1ae0f9411
--- /dev/null
+++ b/tutorials/sphinx-tutorials/llm_wrappers.py
@@ -0,0 +1,363 @@
+"""
+LLM Wrappers in TorchRL
+=======================
+
+This tutorial demonstrates how to use TorchRL's LLM wrappers for integrating Large Language Models
+into reinforcement learning workflows. TorchRL provides two main wrappers:
+
+- :class:`~torchrl.modules.llm.policies.vLLMWrapper` for vLLM models
+- :class:`~torchrl.modules.llm.policies.TransformersWrapper` for Hugging Face Transformers models
+
+Both wrappers provide a unified API with consistent input/output interfaces using TensorClass objects,
+making them interchangeable in RL environments.
+
+Key Features:
+- Multiple input modes: history, text, or tokens
+- Configurable outputs: text, tokens, masks, and log probabilities
+- TensorClass-based structured outputs
+- Seamless integration with TorchRL's TensorDict framework
+"""
+
+# %%
+# Setup and Imports
+# -----------------
+# First, let's set up the environment and import the necessary modules.
+
+import os
+import warnings
+
+# Suppress warnings for cleaner output
+warnings.filterwarnings("ignore")
+
+# Set vLLM environment variables
+os.environ["VLLM_USE_V1"] = "0"
+
+import torch
+from tensordict import TensorDict
+from torchrl.data.llm import History
+from torchrl.modules.llm.policies import TransformersWrapper, vLLMWrapper
+
+# %%
+# Example 1: vLLM Wrapper with History Input
+# ------------------------------------------
+# The vLLM wrapper is optimized for high-performance inference and is ideal for production environments.
+
+try:
+ from transformers import AutoTokenizer
+ from vllm import LLM
+
+ print("Loading vLLM model...")
+ model = LLM(model="Qwen/Qwen2.5-0.5B")
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
+
+ # Create conversation history
+ chats = [
+ [
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": "What is the capital of France?"},
+ ],
+ [
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": "What is the capital of Canada?"},
+ ],
+ ]
+ history = History.from_chats(chats)
+
+ # Create vLLM wrapper with history input (recommended for RL environments)
+ vllm_wrapper = vLLMWrapper(
+ model,
+ tokenizer=tokenizer,
+ input_mode="history",
+ generate=True,
+ return_log_probs=True,
+ return_text=True,
+ return_tokens=True,
+ return_masks=True,
+ pad_output=False, # Use False to avoid stacking issues
+ )
+
+ print(f"vLLM wrapper input keys: {vllm_wrapper.in_keys}")
+ print(f"vLLM wrapper output keys: {vllm_wrapper.out_keys}")
+
+ # Process the data
+ data_history = TensorDict(history=history, batch_size=(2,))
+ result = vllm_wrapper(data_history)
+
+ print("vLLM Results:")
+ print(f"Generated responses: {result['text'].response}")
+ print(
+ f"Response tokens shape: {result['tokens'].response.shape if result['tokens'].response is not None else 'None'}"
+ )
+ print(f"Log probabilities available: {result['log_probs'].response is not None}")
+
+except ImportError:
+ print("vLLM not available, skipping vLLM example")
+
+# %%
+# Example 2: Transformers Wrapper with History Input
+# --------------------------------------------------
+# The Transformers wrapper provides more flexibility and is great for research and development.
+
+try:
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ print("\nLoading Transformers model...")
+ transformers_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")
+ transformers_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
+
+ # Create Transformers wrapper with same interface
+ transformers_wrapper = TransformersWrapper(
+ transformers_model,
+ tokenizer=transformers_tokenizer,
+ input_mode="history",
+ generate=True,
+ return_log_probs=True,
+ return_text=True,
+ return_tokens=True,
+ return_masks=True,
+ pad_output=True, # Transformers typically use padded outputs
+ generate_kwargs={"max_new_tokens": 50},
+ )
+
+ print(f"Transformers wrapper input keys: {transformers_wrapper.in_keys}")
+ print(f"Transformers wrapper output keys: {transformers_wrapper.out_keys}")
+
+ # Process the same data
+ result_tf = transformers_wrapper(data_history)
+
+ print("Transformers Results:")
+ print(f"Generated responses: {result_tf['text'].response}")
+ print(
+ f"Response tokens shape: {result_tf['tokens'].response.shape if result_tf['tokens'].response is not None else 'None'}"
+ )
+ print(f"Log probabilities available: {result_tf['log_probs'].response is not None}")
+
+except (ImportError, NameError):
+    print("Transformers or prior example data not available, skipping Transformers example")
+
+# %%
+# Example 3: Text Input Mode
+# --------------------------
+# Both wrappers support direct text input for simpler use cases.
+
+try:
+ # Create text input data
+ prompts = ["The capital of France is", "The capital of Canada is"]
+ data_text = TensorDict(text=prompts, batch_size=(2,))
+
+ # vLLM with text input
+ vllm_text_wrapper = vLLMWrapper(
+ model,
+ tokenizer=tokenizer,
+ input_mode="text",
+ generate=True,
+ return_text=True,
+ return_tokens=True,
+ pad_output=False,
+ )
+
+ result_vllm_text = vllm_text_wrapper(data_text)
+ print("\nvLLM Text Input Results:")
+ print(f"Generated text: {result_vllm_text['text'].response}")
+
+ # Transformers with text input
+ transformers_text_wrapper = TransformersWrapper(
+ transformers_model,
+ tokenizer=transformers_tokenizer,
+ input_mode="text",
+ generate=True,
+ return_text=True,
+ return_tokens=True,
+ pad_output=True,
+ generate_kwargs={"max_new_tokens": 20},
+ )
+
+ result_tf_text = transformers_text_wrapper(data_text)
+ print("Transformers Text Input Results:")
+ print(f"Generated text: {result_tf_text['text'].response}")
+
+except NameError:
+ print("Models not loaded, skipping text input example")
+
+# %%
+# Example 4: Log Probabilities Only Mode
+# --------------------------------------
+# Both wrappers can compute log probabilities without generating new tokens.
+
+try:
+ # vLLM log-probs only
+ vllm_logprobs_wrapper = vLLMWrapper(
+ model,
+ tokenizer=tokenizer,
+ input_mode="history",
+ generate=False, # Only compute log-probs
+ return_log_probs=True,
+ return_text=True,
+ return_tokens=True,
+ pad_output=False,
+ )
+
+ result_vllm_lp = vllm_logprobs_wrapper(data_history)
+ print("\nvLLM Log Probabilities:")
+ print(
+ f"Prompt log-probs shape: {result_vllm_lp['log_probs'].prompt.shape if result_vllm_lp['log_probs'].prompt is not None else 'None'}"
+ )
+
+ # Transformers log-probs only
+ transformers_logprobs_wrapper = TransformersWrapper(
+ transformers_model,
+ tokenizer=transformers_tokenizer,
+ input_mode="history",
+ generate=False,
+ return_log_probs=True,
+ return_text=True,
+ return_tokens=True,
+ pad_output=True,
+ )
+
+ result_tf_lp = transformers_logprobs_wrapper(data_history)
+ print("Transformers Log Probabilities:")
+    print(
+        f"Prompt log-probs shape: {result_tf_lp['log_probs'].prompt.shape if result_tf_lp['log_probs'].prompt is not None else 'None'}"
+    )
+
+except NameError:
+ print("Models not loaded, skipping log-probs example")
+
+# %%
+# Example 5: TensorClass Structure Exploration
+# --------------------------------------------
+# Let's explore the structured outputs provided by both wrappers.
+
+try:
+ # Get a result from vLLM wrapper
+ result = vllm_wrapper(data_history)
+
+ print("\nTensorClass Structure Analysis:")
+ print("=" * 50)
+
+ # Explore Text TensorClass
+ print("\nText TensorClass:")
+ print(f" Fields: {list(result['text'].__class__.__annotations__.keys())}")
+ print(f" Prompt: {result['text'].prompt}")
+ print(f" Response: {result['text'].response}")
+ print(f" Full: {result['text'].full}")
+ print(f" Padded: {result['text'].padded}")
+
+ # Explore Tokens TensorClass
+ print("\nTokens TensorClass:")
+ print(f" Fields: {list(result['tokens'].__class__.__annotations__.keys())}")
+ print(
+ f" Prompt tokens shape: {result['tokens'].prompt.shape if result['tokens'].prompt is not None else 'None'}"
+ )
+ print(
+ f" Response tokens shape: {result['tokens'].response.shape if result['tokens'].response is not None else 'None'}"
+ )
+ print(
+ f" Full tokens shape: {result['tokens'].full.shape if result['tokens'].full is not None else 'None'}"
+ )
+
+ # Explore LogProbs TensorClass
+ print("\nLogProbs TensorClass:")
+ print(f" Fields: {list(result['log_probs'].__class__.__annotations__.keys())}")
+ print(
+ f" Prompt log-probs shape: {result['log_probs'].prompt.shape if result['log_probs'].prompt is not None else 'None'}"
+ )
+ print(
+ f" Response log-probs shape: {result['log_probs'].response.shape if result['log_probs'].response is not None else 'None'}"
+ )
+
+ # Explore Masks TensorClass
+ print("\nMasks TensorClass:")
+ print(f" Fields: {list(result['masks'].__class__.__annotations__.keys())}")
+ print(
+ f" Attention mask shape: {result['masks'].all_attention_mask.shape if result['masks'].all_attention_mask is not None else 'None'}"
+ )
+ print(
+ f" Assistant mask shape: {result['masks'].all_assistant_mask.shape if result['masks'].all_assistant_mask is not None else 'None'}"
+ )
+
+except NameError:
+ print("Models not loaded, skipping structure exploration")
+
+# %%
+# Example 6: Error Handling and Validation
+# ----------------------------------------
+# Both wrappers provide clear error messages for invalid inputs.
+
+print("\nError Handling Examples:")
+print("=" * 30)
+
+# Example of missing required key
+try:
+ wrapper = vLLMWrapper(
+ model,
+ tokenizer=tokenizer,
+ input_mode="tokens",
+ input_key="tokens",
+ )
+ result = wrapper(TensorDict(batch_size=(2,))) # Missing tokens key
+except (ValueError, NameError) as e:
+ print(f"Expected error for missing key: {e}")
+
+# Example of invalid input mode
+try:
+ wrapper = vLLMWrapper(
+ model,
+ tokenizer=tokenizer,
+ input_mode="invalid_mode", # Invalid mode
+ )
+except (ValueError, NameError) as e:
+    print(f"Expected error for invalid input mode: {e}")
+
+# %%
+# Example 7: RL Environment Integration
+# -------------------------------------
+# The wrappers are designed to work seamlessly with TorchRL environments.
+
+print("\nRL Environment Integration:")
+print("=" * 35)
+
+# Simulate an RL environment step
+try:
+ # Create a simple environment state
+ env_state = TensorDict(
+ {
+ "history": history,
+ "action_mask": torch.ones(2, 1000), # Example action mask
+ "reward": torch.zeros(2),
+ "done": torch.zeros(2, dtype=torch.bool),
+ },
+ batch_size=(2,),
+ )
+
+ # Use the wrapper as a policy
+ action_output = vllm_wrapper(env_state)
+
+ print("Environment integration successful!")
+ print(f"Generated actions: {action_output['text'].response}")
+ print(
+ f"Action log probabilities: {action_output['log_probs'].response is not None}"
+ )
+
+except NameError:
+ print("Models not loaded, skipping RL integration example")
+
+# %%
+# Conclusion
+# ----------
+# TorchRL's LLM wrappers provide a unified interface for integrating Large Language Models
+# into reinforcement learning workflows. Key benefits include:
+#
+# 1. **Consistent API**: Both vLLM and Transformers wrappers share the same interface
+# 2. **Flexible Input Modes**: Support for history, text, and token inputs
+# 3. **Structured Outputs**: TensorClass-based outputs for easy data handling
+# 4. **RL Integration**: Seamless integration with TorchRL's TensorDict framework
+# 5. **Configurable Outputs**: Selective return of text, tokens, masks, and log probabilities
+#
+# The wrappers are designed to be interchangeable, allowing you to switch between
+# different LLM backends without changing your RL code, as sketched below.
+
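+# %%
+# Backend Interchangeability (Sketch)
+# -----------------------------------
+# A minimal sketch of the interchangeability point above: both backends accept the
+# same configuration, so switching only changes the class being instantiated. This
+# reuses the models, tokenizers, and ``data_history`` defined in Examples 1 and 2.
+
+try:
+    common_kwargs = dict(
+        input_mode="history",
+        generate=True,
+        return_log_probs=True,
+        return_text=True,
+        return_tokens=True,
+        return_masks=True,
+    )
+    backend_policies = [
+        vLLMWrapper(model, tokenizer=tokenizer, pad_output=False, **common_kwargs),
+        TransformersWrapper(
+            transformers_model,
+            tokenizer=transformers_tokenizer,
+            pad_output=True,
+            generate_kwargs={"max_new_tokens": 20},
+            **common_kwargs,
+        ),
+    ]
+    for policy in backend_policies:
+        # The same TensorDict flows through either backend unchanged
+        out = policy(data_history)
+        print(f"{type(policy).__name__} response: {out['text'].response}")
+
+except NameError:
+    print("Models not loaded, skipping interchangeability sketch")
+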
+print("\n" + "=" * 60)
+print("Tutorial completed successfully!")
+print("=" * 60)