Added MessagesDataloader so we can just use messages in our datasets rather than tokenized inputs #92
Open: SeanKski wants to merge 22 commits into main from seank/chat_messages
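The practical effect: each dataset sample now stores raw chat turns under a 'messages' key (plus an optional 'verified_answer'), and tokenization happens at load time instead of at dataset-preparation time. A hedged sketch of one such sample (the literal contents are illustrative; 'messages' and 'verified_answer' are the fields the new dataset class actually reads):

sample = {
    'messages': [
        {'role': 'system', 'content': 'You are a helpful assistant.'},
        {'role': 'user', 'content': 'What is 2 + 2?'},
    ],
    'verified_answer': '4',  # optional; forwarded by the dataset when present
}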
Commits (22, all by SeanKski):
- e08f7fe: added raw_untokenized_texts to the batch
- 8ec9f69: added intial messages dataloader
- a6edade: pushed tokenizer to messages class
- 62b0237: added messages to the dataset prep
- 455ce9a: added different messages option to preprocesser
- b0fe610: updated README to use messages rather than prompts
- e4ac438: added tests
- 854a4da: fixed bugs with messages collator
- 1ade679: Merge main into seank/chat_messages
- a9b2d8f: adding back changes to callback
- d0c81fe: vllm hotfix
- 830a131: removed raw_untokenized texts from default reward batch
- 576e153: updated local_grpo and local_ppo yamls
- 448ab98: Merge branch 'main' into seank/chat_messages
- f212b03: ruff is a cruel and hard to please master, but looks like i've finall…
- 6d2b4a3: added messages_dataset_to_mds script
- 93556d5: added option for metadata in messages dataset
- 4f3ae65: updated readme and made pyright happy
- 3c5283f: undid accidental commit to generation_utils
- f241fb0: removed metadata from batch for now
- cc546a7: fixed messages collating based on brandon's comment
- db61308: Merge branch 'main' into seank/chat_messages
The diff adds one new file (168 lines):
# Copyright 2025 MosaicML ComposeRL authors
# SPDX-License-Identifier: Apache-2.0

"""Build a prompt dataset and dataloader for training."""

import logging
from typing import Any, TypeAlias

import torch
from streaming import StreamingDataset
from transformers import (
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    PreTrainedTokenizerBase,
)

import compose_rl.utils as utils

log = logging.getLogger(__name__)

Messages: TypeAlias = list[dict[str, str]]

def messages_dataset_collate_fn(
    tokenizer: PreTrainedTokenizerBase,
    max_seq_len: int,
    batch: list[dict[str, Any]],
) -> dict[str, Any]:
    """Collator for messages data.

    Args:
        tokenizer (PreTrainedTokenizerBase): The model's tokenizer.
        max_seq_len (int): The maximum sequence length of the model.
        batch (list[dict[str, Any]]): A list of data samples to collate.
    """
    if tokenizer.pad_token_id is None:
        raise ValueError('Tokenizer must have a PAD token.')

    ref_collate_fn = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
        mlm_probability=0.0,
    )

    keys = batch[0].keys()
    collated_batch: dict[str, Any] = {}
    for key in keys:
        cur_values = [item[key] for item in batch]
        if key in ['prompt_len']:
            collated_batch[key] = torch.stack(cur_values).squeeze(dim=1)
        elif key == 'prompt_id':
            collated_batch[key] = torch.tensor(cur_values)
        elif key in ['verified_answer']:
            collated_batch[key] = list(  # pyright: ignore[reportGeneralTypeIssues]
                utils.flatten(cur_values),
            )
        elif key == 'messages':
            collated_batch[key] = cur_values
        elif key == 'prompt':
            collated_batch[key] = ref_collate_fn(cur_values)['input_ids']
        else:
            raise ValueError(f'Invalid key: {key}')

    collated_batch['prompt_attention_mask'] = torch.logical_not(
        torch.eq(collated_batch['prompt'],
                 tokenizer.pad_token_id),  # type: ignore
    )

    return collated_batch

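As a usage sketch (not part of the diff; the 'gpt2' checkpoint and the literal token ids are placeholders), the collator can be exercised with a hand-built batch shaped like the dataset's __getitem__ output:

from functools import partial

import torch
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained('gpt2')  # assumed checkpoint, illustration only
tok.pad_token = tok.eos_token  # the collator requires a pad token

collate = partial(messages_dataset_collate_fn, tok, 2048)

fake_batch = [
    {
        'prompt_id': 0,
        'prompt': torch.tensor([101, 102, 103]),
        'prompt_len': torch.tensor([3]),
        'messages': [{'role': 'user', 'content': 'Hello!'}],
    },
    {
        'prompt_id': 1,
        'prompt': torch.tensor([101, 102]),
        'prompt_len': torch.tensor([2]),
        'messages': [{'role': 'user', 'content': 'Hi!'}],
    },
]
out = collate(fake_batch)
# out['prompt'] is a padded [2, 3] tensor, out['prompt_attention_mask'] marks
# the non-pad positions, and out['messages'] is passed through untouched.
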
class MessagesStreamingDataset(StreamingDataset):
    """Dataloader for streaming in messages and converting to prompts."""

    def __init__(
        self,
        max_gen_len: int,
        max_seq_len: int,
        tokenizer: str | PreTrainedTokenizerBase,
        **kwargs: dict[str, Any],
    ):
        self.max_gen_len = max_gen_len
        self.max_seq_len = max_seq_len
        if isinstance(tokenizer, str):
            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
        else:
            self.tokenizer = tokenizer
        super().__init__(**kwargs)

    def _validate_messages(self, messages: Messages) -> bool:
        """Validates the messages.

        A valid message is a list of dictionaries with
        a 'role' key and a 'content' key.

        Args:
            messages (Messages): The messages to validate.

        Returns:
            bool: True if the messages are valid, False otherwise.
        """
        if not isinstance(messages, list):
            return False
        for message in messages:
            if not isinstance(message, dict):
                return False
            if 'role' not in message:
                return False
            if 'content' not in message:
                return False
            if not isinstance(message['content'], str):
                return False
        return True

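    # For example, the validator accepts
    #   [{'role': 'user', 'content': 'What is 2 + 2?'}]
    # and rejects a bare dict, a message missing 'content', or a message
    # whose 'content' is not a string.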
    def _tokenize_messages(self, messages: Messages) -> torch.Tensor:
        if not self._validate_messages(messages):
            raise ValueError(f'Invalid messages received. Got: {messages=}')
        return torch.tensor(
            self.tokenizer.apply_chat_template(
                messages,
                tokenize=True,
                add_generation_prompt=True,
            ),
            dtype=torch.int64,
        )

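    # Note: with add_generation_prompt=True, apply_chat_template appends the
    # assistant-turn header after the rendered messages, so sampling continues
    # as the assistant. For a ChatML-style template the rendered prompt looks
    # roughly like (illustration; the exact rendering depends on the template):
    #   <|im_start|>user\nWhat is 2 + 2?<|im_end|>\n<|im_start|>assistant\n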
    # How to process a sample
    def __getitem__(self, idx: int) -> dict[str, Any]:
        """Get an item from StreamingDataset at a given index.

        Args:
            idx (int): the index where we fetch the data in the StreamingDataset.
        """
        sample = super().__getitem__(idx)
        messages = sample['messages']
        prompt: torch.Tensor = self._tokenize_messages(messages)

        # TODO (bcui): Maybe add in an option to truncate a prompt by a given length?
        if len(prompt) + self.max_gen_len > self.max_seq_len:
            truncate_len = len(prompt) + self.max_gen_len - self.max_seq_len
            log.info(f'Truncating prompt by: {truncate_len}')
            prompt = prompt[:-truncate_len]

        prompt_len = torch.Tensor([len(prompt)]).to(dtype=torch.int64)
        # Send the prompt id along with prompt data
        item_dict = {
            'prompt_id': idx,
            'prompt': prompt,
            'prompt_len': prompt_len,
            'messages': messages,
        }

        verified_answer = sample.get('verified_answer', None)
        if verified_answer:
            if isinstance(verified_answer, str):
                _answer = verified_answer
            else:
                try:
                    _answer = verified_answer.decode('utf-8', errors='strict')
                except UnicodeDecodeError as e:
                    log.error(
                        f'Failed to decode verified_answer with error: {e}',
                    )
                    _answer = ''

            item_dict['verified_answer'] = _answer  # type: ignore

        return item_dict
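Putting the pieces together, a minimal end-to-end sketch (the checkpoint name, local MDS path, and batch size below are placeholder assumptions; the remaining keyword arguments are forwarded to StreamingDataset via **kwargs):

from functools import partial

from torch.utils.data import DataLoader

max_seq_len = 2048
dataset = MessagesStreamingDataset(
    max_gen_len=512,
    max_seq_len=max_seq_len,
    tokenizer='Qwen/Qwen2.5-0.5B-Instruct',  # assumed chat-template checkpoint
    local='/tmp/messages_mds',  # assumed local MDS directory
    batch_size=8,
)
# The collator needs a pad token; fall back to EOS if the checkpoint has none.
if dataset.tokenizer.pad_token is None:
    dataset.tokenizer.pad_token = dataset.tokenizer.eos_token

dataloader = DataLoader(
    dataset,
    batch_size=8,
    collate_fn=partial(messages_dataset_collate_fn, dataset.tokenizer, max_seq_len),
)
for batch in dataloader:
    ...  # batch holds prompt, prompt_len, prompt_attention_mask, messages, prompt_id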