Added MessagesDataloader so we can just use messages in our datasets rather than tokenized inputs #92

Open: wants to merge 22 commits into main from seank/chat_messages
Changes from 16 commits.

Commits (22 total):
e08f7fe  added raw_untokenized_texts to the batch (SeanKski, Jun 10, 2025)
8ec9f69  added intial messages dataloader (SeanKski, Jun 14, 2025)
a6edade  pushed tokenizer to messages class (SeanKski, Jun 16, 2025)
62b0237  added messages to the dataset prep (SeanKski, Jun 16, 2025)
455ce9a  added different messages option to preprocesser (SeanKski, Jun 16, 2025)
b0fe610  updated README to use messages rather than prompts (SeanKski, Jun 16, 2025)
e4ac438  added tests (SeanKski, Jun 16, 2025)
854a4da  fixed bugs with messages collator (SeanKski, Jun 17, 2025)
1ade679  Merge main into seank/chat_messages (SeanKski, Jun 17, 2025)
a9b2d8f  adding back changes to callback (SeanKski, Jun 17, 2025)
d0c81fe  vllm hotfix (SeanKski, Jun 17, 2025)
830a131  removed raw_untokenized texts from default reward batch (SeanKski, Jun 17, 2025)
576e153  updated local_grpo and local_ppo yamls (SeanKski, Jun 18, 2025)
448ab98  Merge branch 'main' into seank/chat_messages (SeanKski, Jun 18, 2025)
f212b03  ruff is a cruel and hard to please master, but looks like i've finall… (SeanKski, Jun 18, 2025)
6d2b4a3  added messages_dataset_to_mds script (SeanKski, Jun 23, 2025)
93556d5  added option for metadata in messages dataset (SeanKski, Jun 23, 2025)
4f3ae65  updated readme and made pyright happy (SeanKski, Jun 23, 2025)
3c5283f  undid accidental commit to generation_utils (SeanKski, Jun 23, 2025)
f241fb0  removed metadata from batch for now (SeanKski, Jun 23, 2025)
cc546a7  fixed messages collating based on brandon's comment (SeanKski, Jun 23, 2025)
db61308  Merge branch 'main' into seank/chat_messages (SeanKski, Jun 23, 2025)
14 changes: 7 additions & 7 deletions README.md
@@ -46,7 +46,7 @@ Here is an end-to-end workflow of performing data preparation for training along

### Data preparation

- Below is the set of commands to run to prepare datasets into the appropriate Mosaic Data Shard (MDS) format, which is a pre-tokenized version of the data, that we will use for training.
+ Below is the set of commands to run to prepare datasets into the appropriate Mosaic Data Shard (MDS) format, which is either a pre-tokenized version of the data or the raw chat-message data, that we will use for training.

Below is the command to prepare preference data -- which can be used for reward model or offline RL (e.g. DPO) training:

@@ -60,26 +60,26 @@ python data/unified_tokenize_dataset.py --dataset_name allenai/ultrafeedback_bin
--split train_prefs
```

- Below is the command to prepare prompt data -- which can be used for online RL (e.g. PPO) training:
+ Below is the command to prepare single-turn message data -- which can be used for online RL (e.g. PPO) training:

<!--pytest.mark.skip-->
```bash
cd scripts
python data/unified_tokenize_dataset.py --dataset_name allenai/ultrafeedback_binarized_cleaned \
- --local_dir prompt_data \
- --dataset_type single_prompt \
+ --local_dir single_message_data \
+ --dataset_type single_message \
--tokenizer_name meta-llama/Llama-3.1-8B-Instruct \
--split train_prefs
```

- To further enable online RL with [verifiable rewards](https://arxiv.org/abs/2411.15124) you can use the following command:
+ To further enable online RL with [verifiable rewards](https://arxiv.org/abs/2411.15124) you can use the following command to prepare the chat-message data and their corresponding verifiable answers:

<!--pytest.mark.skip-->
```bash
cd scripts
python data/unified_tokenize_dataset.py --dataset_name <hf_dataset_name> \
--local_dir verifiable_data \
- --dataset_type verifiable_answers \
+ --dataset_type messages_with_answer \
--tokenizer_name meta-llama/Llama-3.1-8B-Instruct \
--split train \
```
@@ -129,7 +129,7 @@ Below is the command to run Online PPO training:
```bash
composer llm-foundry/scripts/train/train.py \
compose-rl/yamls/local_ppo.yaml \
- train_loader.dataset.local=/compose-rl/scripts/prompt_data/ \
+ train_loader.dataset.local=/compose-rl/scripts/single_message_data/ \
train_loader.dataset.split=train_prefs
```
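For context on the data format this PR moves to, a record consumed by the messages dataloader is a list of chat messages plus an optional verifiable answer, rather than pre-tokenized prompt ids. A minimal sketch of one such sample, with illustrative values and field names taken from the MessagesStreamingDataset code added below:

```python
# A hypothetical messages-style sample: 'messages' is a list of
# {'role', 'content'} dicts, and 'verified_answer' is optional (only needed
# for verifiable-reward training). The values here are illustrative.
sample = {
    'messages': [
        {'role': 'user', 'content': 'What is 7 * 6?'},
    ],
    'verified_answer': '42',
}
```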

14 changes: 13 additions & 1 deletion compose_rl/algorithms/online/callback.py
@@ -648,7 +648,12 @@ def _get_next_iter_prompts(self):
# Explode the batch into multiple batches for each generation
for _ in range(self.generations_per_prompt):
# For keys that do not require additional processing
- if key in ['prompt_len', 'verified_answer', 'prompt_id']:
+ if key in [
+     'prompt_len',
+     'verified_answer',
+     'prompt_id',
+     'messages',
+ ]:
curr_values.append(batch[key])
continue

@@ -678,6 +683,13 @@
else:
if key == 'verified_answer':
ret_batch[key] = list(flatten(curr_values))
elif key == 'messages':
# the messages should be [num_batches_per_update, batch_size, num_turns]
# need to flatten this to [num_batches_per_update * batch_size, num_turns]
ret_batch[key] = [
message_chain for batch in curr_values
for message_chain in batch
]
else:
# this is an edge case that we will not hit currently, but just handling it as needed
ret_batch[key] = curr_values
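To make the new 'messages' branch concrete, the list comprehension above collapses one conversation list per minibatch into a single flat list of conversations; a toy sketch with made-up values:

```python
# Toy illustration of the flattening done for the 'messages' key:
# two minibatches of two conversations each become four conversations.
curr_values = [
    [[{'role': 'user', 'content': 'a'}], [{'role': 'user', 'content': 'b'}]],
    [[{'role': 'user', 'content': 'c'}], [{'role': 'user', 'content': 'd'}]],
]
flattened = [
    message_chain for batch in curr_values for message_chain in batch
]
assert flattened[2] == [{'role': 'user', 'content': 'c'}]
assert len(flattened) == 4
```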
@@ -120,14 +120,17 @@ def vllm_generate(
# Pull the necessary variables from the batch and self
cur_device = batch['prompt'].device
prompt_tokens = batch['prompt']
messages = batch['messages']

- prompt_all_gather_start_time = time.time()
+ prompt_and_messages_all_gather_start_time = time.time()

# TODO: (seank) update this to use gather_object(dst=0) rather than all_gather_object
all_batched_prompts = dist.all_gather_object(prompt_tokens)
all_batched_messages = dist.all_gather_object(messages)
batch_sizes = [len(batch) for batch in all_batched_prompts]

log.info(
- f'took : {time.time() - prompt_all_gather_start_time} to gather prompts',
+ f'took : {time.time() - prompt_and_messages_all_gather_start_time} to gather prompts and messages',
)
all_prompts = [prompt for batch in all_batched_prompts for prompt in batch]
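For intuition, each rank contributes its local prompts (and now messages) to the gather, leaving every rank with one list per rank that is then flattened before generation; a toy two-rank sketch with illustrative token ids and no actual distributed setup:

```python
# Toy illustration of the gather-and-flatten pattern used above;
# in the real code, dist.all_gather_object produces all_batched_prompts.
rank0_prompts = [[1, 2, 3], [4, 5]]    # tokenized prompts held by rank 0
rank1_prompts = [[6, 7], [8, 9, 10]]   # tokenized prompts held by rank 1
all_batched_prompts = [rank0_prompts, rank1_prompts]
batch_sizes = [len(batch) for batch in all_batched_prompts]  # [2, 2]
all_prompts = [prompt for batch in all_batched_prompts for prompt in batch]
assert len(all_prompts) == sum(batch_sizes) == 4
```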

@@ -239,7 +239,7 @@ def create_vllm_engines(
log.info(f'vllm: {num_gpus=}, {num_engines=}')

vllm_engines.append(
- LLMRayActor.options(
+ LLMRayActor.options( # type: ignore
num_cpus=num_gpus,
num_gpus=num_gpus,
scheduling_strategy=scheduling_strategy,
4 changes: 4 additions & 0 deletions compose_rl/data/__init__.py
@@ -7,9 +7,11 @@
)
from compose_rl.data.dataloader import (
build_finegrained_preference_dataloader,
build_messages_dataloader,
build_pairwise_preference_dataloader,
build_prompt_dataloader,
)
from compose_rl.data.messages_data import messages_dataset_collate_fn
from compose_rl.data.preference_data import (
finegrained_preference_dataset_collate_fn,
pairwise_preference_dataset_collate_fn,
@@ -19,10 +21,12 @@
__all__ = [
'build_pairwise_preference_dataloader',
'build_finegrained_preference_dataloader',
'build_messages_dataloader',
'build_prompt_dataloader',
'DummyDataset',
'finegrained_preference_dataset_collate_fn',
'MinibatchRolloutBuffer',
'pairwise_preference_dataset_collate_fn',
'prompt_dataset_collate_fn',
'messages_dataset_collate_fn',
]
19 changes: 17 additions & 2 deletions compose_rl/data/dataloader.py
@@ -10,6 +10,10 @@
from torch.utils.data import DataLoader
from transformers import PreTrainedTokenizer

from compose_rl.data.messages_data import (
MessagesStreamingDataset,
messages_dataset_collate_fn,
)
from compose_rl.data.preference_data import (
FinegrainedPreferenceStreamingDataset,
PairwisePreferenceStreamingDataset,
@@ -25,6 +29,7 @@
'build_finegrained_preference_dataloader',
'build_pairwise_preference_dataloader',
'build_prompt_dataloader',
'build_messages_dataloader',
]


@@ -71,10 +76,15 @@ def build_preference_dataloader(
streams = None
if streams_dict is not None:
streams = [Stream(**stream) for stream in streams_dict.values()]
if issubclass(
dataset_cls,
MessagesStreamingDataset,
) and 'tokenizer' not in dataset_cfg:
dataset_cfg['tokenizer'] = tokenizer

streaming_dataset = dataset_cls(
- streams=streams,
- batch_size=device_batch_size,
+ streams=streams, # type: ignore
+ batch_size=device_batch_size, # type: ignore
**dataset_cfg,
)

@@ -111,3 +121,8 @@ def build_preference_dataloader(
PromptStreamingDataset,
prompt_dataset_collate_fn,
)

build_messages_dataloader = generate_dataloader_builder(
MessagesStreamingDataset,
messages_dataset_collate_fn,
)
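Because of the issubclass check above, a messages dataset config can leave out the tokenizer and the builder passes its own through. A hedged sketch of such a config; max_gen_len and max_seq_len go to MessagesStreamingDataset, the rest are standard StreamingDataset arguments, and all values are illustrative:

```python
# Illustrative dataset_cfg for build_messages_dataloader; 'tokenizer' is
# omitted on purpose because the builder injects the trainer's tokenizer for
# MessagesStreamingDataset. The path and lengths are made-up examples.
dataset_cfg = {
    'local': '/compose-rl/scripts/single_message_data',
    'split': 'train_prefs',
    'max_gen_len': 512,
    'max_seq_len': 2048,
    'shuffle': True,
}
```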
168 changes: 168 additions & 0 deletions compose_rl/data/messages_data.py
@@ -0,0 +1,168 @@
# Copyright 2025 MosaicML ComposeRL authors
# SPDX-License-Identifier: Apache-2.0

"""Build a messages dataset and dataloader for training."""

import logging
from typing import Any, TypeAlias

import torch
from streaming import StreamingDataset
from transformers import (
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    PreTrainedTokenizerBase,
)

import compose_rl.utils as utils

log = logging.getLogger(__name__)

Messages: TypeAlias = list[dict[str, str]]


def messages_dataset_collate_fn(
    tokenizer: PreTrainedTokenizerBase,
    max_seq_len: int,
    batch: list[dict[str, Any]],
) -> dict[str, Any]:
    """Collator for messages data.

    Args:
        batch (List[Dict[str, Any]]): A list of data samples to collate.
        tokenizer (PreTrainedTokenizer): The model's tokenizer.
        max_seq_len (int): The maximum sequence length of the model.
    """
    if tokenizer.pad_token_id is None:
        raise ValueError('Tokenizer must have a PAD token.')

    ref_collate_fn = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
        mlm_probability=0.0,
    )

    keys = batch[0].keys()
    collated_batch: dict[str, Any] = {}
    for key in keys:
        cur_values = [item[key] for item in batch]
        if key in ['prompt_len']:
            collated_batch[key] = torch.stack(cur_values).squeeze(dim=1)
        elif key == 'prompt_id':
            collated_batch[key] = torch.tensor(cur_values)
        elif key in ['verified_answer']:
            collated_batch[key] = list( # pyright: ignore[reportGeneralTypeIssues]
                utils.flatten(cur_values),
            )
        elif key == 'messages':
            collated_batch[key] = cur_values
        elif key == 'prompt':
            collated_batch[key] = ref_collate_fn(cur_values)['input_ids']
        else:
            raise ValueError(f'Invalid key: {key}')

    collated_batch['prompt_attention_mask'] = torch.logical_not(
        torch.eq(collated_batch['prompt'],
                 tokenizer.pad_token_id), # type: ignore
    )

    return collated_batch


class MessagesStreamingDataset(StreamingDataset):
    """Dataloader for streaming in messages and converting to prompts."""

    def __init__(
        self,
        max_gen_len: int,
        max_seq_len: int,
        tokenizer: str | PreTrainedTokenizerBase,
        **kwargs: dict[str, Any],
    ):
        self.max_gen_len = max_gen_len
        self.max_seq_len = max_seq_len
        if isinstance(tokenizer, str):
            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
        else:
            self.tokenizer = tokenizer
        super().__init__(**kwargs)

    def _validate_messages(self, messages: Messages) -> bool:
        """Validates the message.

        A valid message is a list of dictionaries with
        a 'role' key and a 'content' key.

        Args:
            messages (Messages): The messages to validate.

        Returns:
            bool: True if the messages are valid, False otherwise.
        """
        if not isinstance(messages, list):
            return False
        for message in messages:
            if not isinstance(message, dict):
                return False
            if 'role' not in message:
                return False
            if 'content' not in message:
                return False
            if not isinstance(message['content'], str):
                return False
        return True

    def _tokenize_messages(self, messages: Messages) -> torch.Tensor:
        if not self._validate_messages(messages):
            raise ValueError(f'Invalid messages received. Got: {messages=}')
        return torch.tensor(
            self.tokenizer.apply_chat_template(
                messages,
                tokenize=True,
                add_generation_prompt=True,
            ),
            dtype=torch.int64,
        )

    # How to process a sample
    def __getitem__(self, idx: int) -> dict[str, Any]:
        """Get an item from StreamingDataset at a given index.

        Args:
            idx (int): the index where we fetch the data in the StreamingDataset.
        """
        sample = super().__getitem__(idx)
        messages = sample['messages']
        prompt: torch.Tensor = self._tokenize_messages(messages)

        # TODO (bcui): Maybe add in an option to truncate a prompt by a given length?
        if len(prompt) + self.max_gen_len > self.max_seq_len:
            truncate_len = len(prompt) + self.max_gen_len - self.max_seq_len
            log.info(f'Truncating prompt by: {truncate_len}')
            prompt = prompt[:-truncate_len]

        prompt_len = torch.Tensor([len(prompt)]).to(dtype=torch.int64)
        # Send the prompt id along with prompt data
        item_dict = {
            'prompt_id': idx,
            'prompt': prompt,
            'prompt_len': prompt_len,
            'messages': messages,
        }

        verified_answer = sample.get('verified_answer', None)
        if verified_answer:
            if isinstance(verified_answer, str):
                _answer = verified_answer
            else:
                try:
                    _answer = verified_answer.decode('utf-8', errors='strict')
                except UnicodeDecodeError as e:
                    log.error(
                        f'Failed to decode verified_answer with error: {e}',
                    )
                    _answer = ''

            item_dict['verified_answer'] = _answer # type: ignore

        return item_dict
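As a rough end-to-end sketch of the new module used on its own (the path, batch size, and lengths are illustrative assumptions, the MDS shards are assumed to already exist, and in practice the registered messages dataloader builder does this wiring):

```python
from functools import partial

from torch.utils.data import DataLoader

from compose_rl.data.messages_data import (
    MessagesStreamingDataset,
    messages_dataset_collate_fn,
)

# Hypothetical standalone wiring of the dataset and collator defined above.
dataset = MessagesStreamingDataset(
    local='/compose-rl/scripts/single_message_data',  # illustrative path
    split='train_prefs',
    max_gen_len=512,   # illustrative
    max_seq_len=2048,  # illustrative
    tokenizer='meta-llama/Llama-3.1-8B-Instruct',
    batch_size=4,
)
if dataset.tokenizer.pad_token_id is None:
    dataset.tokenizer.pad_token = dataset.tokenizer.eos_token  # collator needs a PAD token
loader = DataLoader(
    dataset,
    batch_size=4,
    collate_fn=partial(messages_dataset_collate_fn, dataset.tokenizer, 2048),
)
batch = next(iter(loader))
# batch holds 'prompt', 'prompt_len', 'prompt_attention_mask', 'prompt_id',
# the raw 'messages', and 'verified_answer' when the dataset provides one.
```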
1 change: 1 addition & 0 deletions pyproject.toml
@@ -59,6 +59,7 @@ hf_ppo_lm = "compose_rl.algorithms.online:ComposerHFPolicyLM"
pairwise_preference = "compose_rl.data:build_pairwise_preference_dataloader"
finegrained_preference = "compose_rl.data:build_finegrained_preference_dataloader"
prompt = "compose_rl.data:build_prompt_dataloader"
messages = "compose_rl.data:build_messages_dataloader"

[project.entry-points."llmfoundry_callbacks_with_config"]
offline_rl = "compose_rl.algorithms.offline:ReferencePolicyCallback"