From e08f7fe674d12322065aff1c3754eeffc08c4310 Mon Sep 17 00:00:00 2001 From: SeanKski Date: Tue, 10 Jun 2025 22:19:05 +0000 Subject: [PATCH 01/19] added raw_untokenized_texts to the batch --- compose_rl/ppo/reward_manager.py | 1 + 1 file changed, 1 insertion(+) diff --git a/compose_rl/ppo/reward_manager.py b/compose_rl/ppo/reward_manager.py index b01aabb4..499a6251 100644 --- a/compose_rl/ppo/reward_manager.py +++ b/compose_rl/ppo/reward_manager.py @@ -486,6 +486,7 @@ def _create_batch( base_batch['generated_lens'], 'seq_lens': base_batch['seq_lens'], + 'raw_untokenized_texts': raw_untokenized_texts, } else: raise TypeError( From 8ec9f695331b40492b877136d96bc2291fd644c2 Mon Sep 17 00:00:00 2001 From: SeanKski Date: Sat, 14 Jun 2025 00:02:36 +0000 Subject: [PATCH 02/19] added intial messages dataloader --- compose_rl/data/__init__.py | 6 ++ compose_rl/data/dataloader.py | 10 ++ compose_rl/data/messages_data.py | 155 +++++++++++++++++++++++++++++++ compose_rl/ppo/callback.py | 2 +- pyproject.toml | 1 + 5 files changed, 173 insertions(+), 1 deletion(-) create mode 100644 compose_rl/data/messages_data.py diff --git a/compose_rl/data/__init__.py b/compose_rl/data/__init__.py index af33fa1e..4f742d3b 100644 --- a/compose_rl/data/__init__.py +++ b/compose_rl/data/__init__.py @@ -5,6 +5,7 @@ build_finegrained_preference_dataloader, build_pairwise_preference_dataloader, build_prompt_dataloader, + build_messages_dataloader, ) from compose_rl.data.preference_data import ( finegrained_preference_dataset_collate_fn, @@ -21,14 +22,19 @@ prepare_math_prompt, remove_boxed, ) +from compose_rl.data.messages_data import ( + messages_dataset_collate_fn, +) __all__ = [ 'build_pairwise_preference_dataloader', 'build_finegrained_preference_dataloader', + 'build_messages_dataloader', 'build_prompt_dataloader', 'extract_gsm8k_answer', 'finegrained_preference_dataset_collate_fn', 'pairwise_preference_dataset_collate_fn', + 'messages_dataset_collate_fn', 'prepare_gsm8k_prompt', 'prompt_dataset_collate_fn', 'extract_math_answer', diff --git a/compose_rl/data/dataloader.py b/compose_rl/data/dataloader.py index 2fd37184..93c88289 100644 --- a/compose_rl/data/dataloader.py +++ b/compose_rl/data/dataloader.py @@ -20,11 +20,16 @@ PromptStreamingDataset, prompt_dataset_collate_fn, ) +from compose_rl.data.messages_data import ( + MessagesStreamingDataset, + messages_dataset_collate_fn, +) __all__ = [ 'build_finegrained_preference_dataloader', 'build_pairwise_preference_dataloader', 'build_prompt_dataloader', + 'build_messages_dataloader', ] @@ -111,3 +116,8 @@ def build_preference_dataloader( PromptStreamingDataset, prompt_dataset_collate_fn, ) + +build_messages_dataloader = generate_dataloader_builder( + MessagesStreamingDataset, + messages_dataset_collate_fn, +) \ No newline at end of file diff --git a/compose_rl/data/messages_data.py b/compose_rl/data/messages_data.py new file mode 100644 index 00000000..acc85868 --- /dev/null +++ b/compose_rl/data/messages_data.py @@ -0,0 +1,155 @@ +# Copyright 2025 MosaicML ComposeRL authors +# SPDX-License-Identifier: Apache-2.0 + +"""Build a prompt dataset and dataloader for training.""" + +import logging +from typing import Any + +import torch +from streaming import StreamingDataset +from transformers import ( + DataCollatorForLanguageModeling, + PreTrainedTokenizerBase, +) + +import compose_rl.utils as utils + +log = logging.getLogger(__name__) + + +def messages_dataset_collate_fn( + tokenizer: PreTrainedTokenizerBase, + max_seq_len: int, + batch: list[dict[str, Any]], +) -> dict[str, torch.Tensor]: + """Collator for messages data. + + Args: + batch (List[Dict[str, Any]]): A list of data samples to collate. + tokenizer (PreTrainedTokenizer): The model's tokenizer. + max_seq_len (int): The maximum sequence length of the model. + """ + if tokenizer.pad_token_id is None: + raise ValueError('Tokenizer must have a PAD token.') + + ref_collate_fn = DataCollatorForLanguageModeling( + tokenizer=tokenizer, + mlm=False, + mlm_probability=0.0, + ) + + keys = batch[0].keys() + collated_batch: dict[str, torch.Tensor] = {} + for key in keys: + cur_values = [item[key] for item in batch] + if key in ['prompt_len']: + collated_batch[key] = torch.stack(cur_values).squeeze(dim=1) + continue + if key == 'prompt_id': + collated_batch[key] = torch.tensor(cur_values) + continue + if key in ['verified_answer']: + collated_batch[key] = list( # pyright: ignore[reportGeneralTypeIssues] + utils.flatten(cur_values), + ) + continue + + collated_batch[key] = ref_collate_fn(cur_values)['input_ids'] + + collated_batch['prompt_attention_mask'] = torch.logical_not( + torch.eq(collated_batch['prompt'], + tokenizer.pad_token_id), # type: ignore + ) + return collated_batch + + +class MessagesStreamingDataset(StreamingDataset): + """Dataloader for streaming in messages and converting to prompts.""" + + def __init__( + self, + max_gen_len: int, + max_seq_len: int, + tokenizer: PreTrainedTokenizerBase, + **kwargs: dict[str, Any], + ): + self.max_gen_len = max_gen_len + self.max_seq_len = max_seq_len + self.tokenizer = tokenizer + super().__init__(**kwargs) + + def _validate_messages(self, messages: list[dict[str, str]]) -> bool: + """Validate the messages. A valid message is a list of dictionaries with the following keys: + - role: str + - content: str + + Args: + messages (list[dict[str, str]]): The messages to validate. + Returns: + bool: True if the messages are valid, False otherwise. + """ + if not isinstance(messages, list): + return False + for message in messages: + if not isinstance(message, dict): + return False + if 'role' not in message: + return False + if 'content' not in message: + return False + if not isinstance(message['content'], str): + return False + return True + + def _tokenize_messages(self, messages: list[str]) -> torch.Tensor: + if not self._validate_messages(messages): + raise ValueError(f'Invalid messages received. Got: {messages=}') + return torch.Tensor(self.tokenizer.apply_chat_template( + messages, + tokenize=True, + add_generation_prompt=True, + )) + + # How to process a sample + def __getitem__(self, idx: int) -> dict[str, Any]: + """Get an item from StreamingDataset at a given index. + + Args: + idx (int): the index where we fetch the data in the StreamingDataset. + """ + sample = super().__getitem__(idx) + messages = sample['messages'] + prompt: torch.Tensor = self._tokenize_messages(messages) + + # TODO (bcui): Maybe add in an option to truncate a prompt by a given length? + if len(prompt) + self.max_gen_len > self.max_seq_len: + truncate_len = len(prompt) + self.max_gen_len - self.max_seq_len + log.info(f'Truncating prompt by: {truncate_len}') + prompt = prompt[:-truncate_len] + + prompt_len = torch.Tensor([len(prompt)]).to(dtype=torch.int64) + # Send the prompt id along with prompt data + item_dict = { + 'prompt_id': idx, + 'prompt': prompt, + 'prompt_len': prompt_len, + 'messages': messages, + } + + verified_answer = sample.get('verified_answer', None) + if verified_answer: + if isinstance(verified_answer, str): + _answer = verified_answer + else: + try: + _answer = verified_answer.decode('utf-8', errors='strict') + except UnicodeDecodeError as e: + log.error( + f'Failed to decode verifed_answer with error: {e}', + ) + _answer = '' + + item_dict['verified_answer'] = _answer # type: ignore + + return item_dict diff --git a/compose_rl/ppo/callback.py b/compose_rl/ppo/callback.py index 71c8efdc..f7323d4d 100644 --- a/compose_rl/ppo/callback.py +++ b/compose_rl/ppo/callback.py @@ -966,7 +966,7 @@ def _resolve_outputs( def _log_generations_to_logger(self, state: State): # Gather all prompts, generations, prompt_ids and rewards from all ranks prompts_and_gens = list( - chain(*dist.all_gather_object(self.prompts_and_gens)), + chain(*dist.all_gather_object(self.prompts_and_gens)), ) prompt_ids_rewards_and_answers = list( chain(*dist.all_gather_object(self.prompt_ids_rewards_and_answers)), diff --git a/pyproject.toml b/pyproject.toml index 640720a3..ca30e3cc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,7 @@ hf_critic_free_lm = "compose_rl.ppo:ComposerHFCriticFreePolicyModel" pairwise_preference = "compose_rl.data:build_pairwise_preference_dataloader" finegrained_preference = "compose_rl.data:build_finegrained_preference_dataloader" prompt = "compose_rl.data:build_prompt_dataloader" +messages = "compose_rl.data:build_messages_dataloader" [project.entry-points."llmfoundry_callbacks_with_config"] dpo = "compose_rl.dpo:DPOCallback" From a6edade7722f98f85b03f2769e9338b151850555 Mon Sep 17 00:00:00 2001 From: SeanKski Date: Mon, 16 Jun 2025 20:26:51 +0000 Subject: [PATCH 03/19] pushed tokenizer to messages class --- compose_rl/data/dataloader.py | 3 +++ compose_rl/data/messages_data.py | 8 ++++++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/compose_rl/data/dataloader.py b/compose_rl/data/dataloader.py index 93c88289..e7256194 100644 --- a/compose_rl/data/dataloader.py +++ b/compose_rl/data/dataloader.py @@ -77,6 +77,9 @@ def build_preference_dataloader( if streams_dict is not None: streams = [Stream(**stream) for stream in streams_dict.values()] + if isinstance(dataset_cls, MessagesStreamingDataset) and 'tokenizer' not in dataset_cfg: + dataset_cfg['tokenizer'] = tokenizer + streaming_dataset = dataset_cls( streams=streams, batch_size=device_batch_size, diff --git a/compose_rl/data/messages_data.py b/compose_rl/data/messages_data.py index acc85868..134a1f45 100644 --- a/compose_rl/data/messages_data.py +++ b/compose_rl/data/messages_data.py @@ -9,6 +9,7 @@ import torch from streaming import StreamingDataset from transformers import ( + AutoTokenizer, DataCollatorForLanguageModeling, PreTrainedTokenizerBase, ) @@ -71,12 +72,15 @@ def __init__( self, max_gen_len: int, max_seq_len: int, - tokenizer: PreTrainedTokenizerBase, + tokenizer: str | PreTrainedTokenizerBase, **kwargs: dict[str, Any], ): self.max_gen_len = max_gen_len self.max_seq_len = max_seq_len - self.tokenizer = tokenizer + if isinstance(tokenizer, str): + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer) + else: + self.tokenizer = tokenizer super().__init__(**kwargs) def _validate_messages(self, messages: list[dict[str, str]]) -> bool: From 62b0237fdf66f574ecd7a1c234b0cca69c5eccd5 Mon Sep 17 00:00:00 2001 From: SeanKski Date: Mon, 16 Jun 2025 22:10:19 +0000 Subject: [PATCH 04/19] added messages to the dataset prep --- scripts/data/unified_tokenize_dataset.py | 36 ++++++++++++++++++++---- 1 file changed, 31 insertions(+), 5 deletions(-) diff --git a/scripts/data/unified_tokenize_dataset.py b/scripts/data/unified_tokenize_dataset.py index 2bb9d70f..a01996d2 100644 --- a/scripts/data/unified_tokenize_dataset.py +++ b/scripts/data/unified_tokenize_dataset.py @@ -75,10 +75,22 @@ def __iter__(self) -> Iterator[dict[str, bytes]]: result = self._process_single_prompt_sample(sample) if result is not None: yield result + elif self.dataset_type == 'single_message': + result = self._process_single_prompt_sample(sample, return_messages=True) + if result is not None: + # delete the prompt from the results since it's not needed + result.pop('prompt') + yield result elif self.dataset_type == 'verifiable_answers': result = self._process_verifiable_answer_sample(sample) if result is not None: yield result + elif self.dataset_type == 'messages_with_answer': + result = self._process_verifiable_answer_sample(sample, return_messages=True) + if result is not None: + # delete the prompt from the results since it's not needed + result.pop('prompt') + yield result elif self.dataset_type == 'classifier': yield self._process_classifier_sample(sample) @@ -105,7 +117,7 @@ def _process_preference_sample(self, sample: Any): 'rejected': np.asarray(curr_rejected).tobytes(), } - def _process_single_prompt_sample(self, sample: Any): + def _process_single_prompt_sample(self, sample: Any, return_messages: bool = False): """Process a prompt sample. Args: @@ -118,6 +130,7 @@ def _process_single_prompt_sample(self, sample: Any): 'content': f'Can you summarize the following content in 50 words or less: {prompt}', }] + encoded_prompt = self.tokenizer.apply_chat_template( messages, tokenize=True, @@ -127,7 +140,10 @@ def _process_single_prompt_sample(self, sample: Any): if len(encoded_prompt) > self.max_length: return None - return {'prompt': np.asarray(encoded_prompt).tobytes()} + output = {'prompt': np.asarray(encoded_prompt).tobytes()} + if return_messages: + output['messages'] = messages + return output def _process_classifier_sample(self, sample: Any): """A dummy process a classifier sample. @@ -169,7 +185,7 @@ def _get_processing_fn_from_dataset(self): return prompt_fn, answer_fn - def _process_verifiable_answer_sample(self, sample: Any): + def _process_verifiable_answer_sample(self, sample: Any, return_messages: bool = False): """Process a prompt sample and extract the answer. This function is currently hard-coded for the GSM8K dataset. @@ -207,10 +223,13 @@ def _process_verifiable_answer_sample(self, sample: Any): ) return None - return { + output = { 'prompt': np.asarray(encoded_prompt).tobytes(), 'verified_answer': verified_answer, } + if return_messages: + output['messages'] = messages + return output def _check_for_encoding(self, sample: str) -> bool: """Check if a sample is encodable by streaming. @@ -247,7 +266,7 @@ def main( hashes: list[str], splits: list[str], tokenizer_name: str, - dataset_type: Literal['preference', 'single_prompt', 'verifiable_answers'], + dataset_type: Literal['preference', 'single_prompt', 'verifiable_answers', 'messages_with_answer'], max_length: int = 2048, subset: str | None = None, token: str | None = None, @@ -260,6 +279,9 @@ def main( 'single_prompt': { 'prompt': 'bytes', }, + 'single_message': { + 'messages': 'json', + }, 'verifiable_answers': { 'prompt': 'bytes', 'verified_answer': 'str', @@ -268,6 +290,10 @@ def main( 'input': 'bytes', 'label': 'bytes', }, + 'messages_with_answer': { + 'messages': 'json', + 'verified_answer': 'str', + }, }[dataset_type] tokenizer = AutoTokenizer.from_pretrained( From 455ce9a3d52720c472495d2dbbe45e5cc54e52bf Mon Sep 17 00:00:00 2001 From: SeanKski Date: Mon, 16 Jun 2025 22:18:10 +0000 Subject: [PATCH 05/19] added different messages option to preprocesser --- scripts/data/unified_tokenize_dataset.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/scripts/data/unified_tokenize_dataset.py b/scripts/data/unified_tokenize_dataset.py index a01996d2..4df802ef 100644 --- a/scripts/data/unified_tokenize_dataset.py +++ b/scripts/data/unified_tokenize_dataset.py @@ -266,7 +266,7 @@ def main( hashes: list[str], splits: list[str], tokenizer_name: str, - dataset_type: Literal['preference', 'single_prompt', 'verifiable_answers', 'messages_with_answer'], + dataset_type: Literal['preference', 'single_prompt', 'single_message', 'verifiable_answers', 'messages_with_answer', 'classifier'], max_length: int = 2048, subset: str | None = None, token: str | None = None, @@ -364,6 +364,8 @@ def main( 'single_prompt', 'classifier', 'verifiable_answers', + 'messages_with_answer', + 'single_message', ], required=True, help='Type of dataset to process', From b0fe610c3aeeac8dc989356ac86811caf3edc23f Mon Sep 17 00:00:00 2001 From: SeanKski Date: Mon, 16 Jun 2025 22:43:10 +0000 Subject: [PATCH 06/19] updated README to use messages rather than prompts --- README.md | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 299c2200..501be8ca 100644 --- a/README.md +++ b/README.md @@ -46,7 +46,7 @@ Here is an end-to-end workflow of performing data preparation for training along ### Data preparation -Below is the set of commands to run to prepare datasets into the appropriate Mosaic Data Shard (MDS) format, which is a pre-tokenized version of the data, that we will use for training. +Below is the set of commands to run to prepare datasets into the appropriate Mosaic Data Shard (MDS) format, which is either a pre-tokenized version of the data or the raw chat-message data, that we will use for training. Below is the command to prepare preference data -- which can be used for reward model or offline RL (e.g. DPO) training: @@ -60,26 +60,26 @@ python data/unified_tokenize_dataset.py --dataset_name allenai/ultrafeedback_bin --split train_prefs ``` -Below is the command to prepare prompt data -- which can be used for online RL (e.g. PPO) training: +Below is the command to prepare single-turn message data -- which can be used for online RL (e.g. PPO) training: ```bash cd scripts python data/unified_tokenize_dataset.py --dataset_name allenai/ultrafeedback_binarized_cleaned \ ---local_dir prompt_data \ ---dataset_type single_prompt \ +--local_dir single_message_data \ +--dataset_type single_message \ --tokenizer_name meta-llama/Llama-3.1-8B-Instruct \ --split train_prefs ``` -To further enable online RL with [verifiable rewards](https://arxiv.org/abs/2411.15124) you can use the following command: +To further enable online RL with [verifiable rewards](https://arxiv.org/abs/2411.15124) you can use the following command to prepare the chat-message data and their corresponding verifiable answers: ```bash cd scripts python data/unified_tokenize_dataset.py --dataset_name \ --local_dir verifiable_data \ ---dataset_type verifiable_answers \ +--dataset_type messages_with_answer \ --tokenizer_name meta-llama/Llama-3.1-8B-Instruct \ --split train \ ``` @@ -129,7 +129,7 @@ Below is the command to run Online PPO training: ```bash composer llm-foundry/scripts/train/train.py \ compose-rl/yamls/local_ppo.yaml \ -train_loader.dataset.local=/compose-rl/scripts/prompt_data/ \ +train_loader.dataset.local=/compose-rl/scripts/single_message_data/ \ train_loader.dataset.split=train_prefs ``` From e4ac438db75e814d665c229afa681f059fe9557b Mon Sep 17 00:00:00 2001 From: SeanKski Date: Mon, 16 Jun 2025 22:43:17 +0000 Subject: [PATCH 07/19] added tests --- tests/common/__init__.py | 2 ++ tests/common/datasets.py | 20 ++++++++++++++++++++ tests/test_ppo.py | 21 +++++++++++++++------ 3 files changed, 37 insertions(+), 6 deletions(-) diff --git a/tests/common/__init__.py b/tests/common/__init__.py index 1230c821..1d588da3 100644 --- a/tests/common/__init__.py +++ b/tests/common/__init__.py @@ -6,6 +6,7 @@ PairwisePreference, PromptDataset, VerifiablePromptDataset, + VerifiableMessagesDataset, ) from tests.common.markers import device, world_size @@ -14,6 +15,7 @@ 'FineGrainedPreference', 'PromptDataset', 'VerifiablePromptDataset', + 'VerifiableMessagesDataset', 'device', 'world_size', ] diff --git a/tests/common/datasets.py b/tests/common/datasets.py index 98da559e..f58aee0f 100644 --- a/tests/common/datasets.py +++ b/tests/common/datasets.py @@ -89,3 +89,23 @@ def __getitem__(self, index: int): 'prompt_len': torch.Tensor([self.prompt_len]).to(torch.int64), 'verified_answer': '1', } + + +class VerifiableMessagesDataset(Dataset): + + def __init__(self, size: int = 8, prompt_len: int = 5): + self.size = size + self.prompt_len = prompt_len + + def __len__(self): + return self.size + + def __getitem__(self, index: int): + messages = [{'role': 'user', 'content': 'F' * self.prompt_len}] # bit of a hack, but it works + mock_prompt = torch.ones((len(messages[0]['content']),)).int() + return { + 'messages': messages, + 'prompt': mock_prompt, + 'prompt_len': torch.Tensor([len(messages[0]['content'])]).to(torch.int64), + 'verified_answer': 'Paris', + } \ No newline at end of file diff --git a/tests/test_ppo.py b/tests/test_ppo.py index f6942fb5..0ea733bf 100644 --- a/tests/test_ppo.py +++ b/tests/test_ppo.py @@ -17,14 +17,14 @@ from transformers import PreTrainedModel, PreTrainedTokenizerBase from transformers.models.gpt2 import GPT2LMHeadModel -from compose_rl.data import prompt_dataset_collate_fn +from compose_rl.data import prompt_dataset_collate_fn, messages_dataset_collate_fn from compose_rl.ppo import ( ComposerHFPolicyModel, ComposerMosaicPolicy, PPOCallback, ) from compose_rl.ppo.modeling_hf import ComposerHFPolicy -from tests.common import PromptDataset, VerifiablePromptDataset, world_size +from tests.common import PromptDataset, VerifiablePromptDataset, VerifiableMessagesDataset, world_size def test_hf_ppo_model_construction( @@ -76,19 +76,28 @@ def test_hf_ppo_policy_construction( @pytest.mark.parametrize('model_type', ['mpt', 'hf']) -@pytest.mark.parametrize('dataset_type', ['prompt', 'verifiable_prompt']) +@pytest.mark.parametrize('dataset_type', ['prompt', 'verifiable_prompt', 'verifiable_messages']) def test_model_forward( tiny_gpt2_tokenizer: PreTrainedTokenizerBase, model_type: str, dataset_type: str, ): prompt_len = 10 - data_class = PromptDataset if dataset_type == 'prompt' else VerifiablePromptDataset - dataset = data_class(prompt_len=prompt_len) + if dataset_type == 'prompt': + dataset = PromptDataset(prompt_len=prompt_len) + dataset_collator = prompt_dataset_collate_fn + elif dataset_type == 'verifiable_prompt': + dataset = VerifiablePromptDataset(prompt_len=prompt_len) + dataset_collator = prompt_dataset_collate_fn + elif dataset_type == 'verifiable_messages': + dataset = VerifiableMessagesDataset(prompt_len=prompt_len) + dataset_collator = messages_dataset_collate_fn + else: + raise ValueError(f'Unknown dataset type: {dataset_type}') dataloader = DataLoader( dataset, collate_fn=partial( - prompt_dataset_collate_fn, + dataset_collator, tiny_gpt2_tokenizer, 32, ), From 854a4da5714e170bb46364fede084354a77c2f5a Mon Sep 17 00:00:00 2001 From: SeanKski Date: Tue, 17 Jun 2025 20:24:26 +0000 Subject: [PATCH 08/19] fixed bugs with messages collator --- compose_rl/data/dataloader.py | 4 ++-- compose_rl/data/messages_data.py | 32 +++++++++++++++++++------------- compose_rl/ppo/callback.py | 6 +++++- 3 files changed, 26 insertions(+), 16 deletions(-) diff --git a/compose_rl/data/dataloader.py b/compose_rl/data/dataloader.py index e7256194..601f07b9 100644 --- a/compose_rl/data/dataloader.py +++ b/compose_rl/data/dataloader.py @@ -5,6 +5,7 @@ from functools import partial from typing import Any, Callable +import logging from streaming import Stream, StreamingDataLoader, StreamingDataset from torch.utils.data import DataLoader @@ -76,8 +77,7 @@ def build_preference_dataloader( streams = None if streams_dict is not None: streams = [Stream(**stream) for stream in streams_dict.values()] - - if isinstance(dataset_cls, MessagesStreamingDataset) and 'tokenizer' not in dataset_cfg: + if issubclass(dataset_cls, MessagesStreamingDataset) and 'tokenizer' not in dataset_cfg: dataset_cfg['tokenizer'] = tokenizer streaming_dataset = dataset_cls( diff --git a/compose_rl/data/messages_data.py b/compose_rl/data/messages_data.py index 134a1f45..300f70d1 100644 --- a/compose_rl/data/messages_data.py +++ b/compose_rl/data/messages_data.py @@ -4,7 +4,7 @@ """Build a prompt dataset and dataloader for training.""" import logging -from typing import Any +from typing import Any, TypeAlias import torch from streaming import StreamingDataset @@ -18,6 +18,8 @@ log = logging.getLogger(__name__) +Messages: TypeAlias = list[dict[str, str]] + def messages_dataset_collate_fn( tokenizer: PreTrainedTokenizerBase, @@ -39,29 +41,33 @@ def messages_dataset_collate_fn( mlm=False, mlm_probability=0.0, ) - + keys = batch[0].keys() - collated_batch: dict[str, torch.Tensor] = {} + collated_batch: dict[str, list[str] | torch.Tensor | list[Messages]] = {} for key in keys: cur_values = [item[key] for item in batch] if key in ['prompt_len']: collated_batch[key] = torch.stack(cur_values).squeeze(dim=1) continue - if key == 'prompt_id': + elif key == 'prompt_id': collated_batch[key] = torch.tensor(cur_values) - continue - if key in ['verified_answer']: + elif key in ['verified_answer']: collated_batch[key] = list( # pyright: ignore[reportGeneralTypeIssues] utils.flatten(cur_values), ) + elif key == 'messages': + collated_batch[key] = cur_values continue - - collated_batch[key] = ref_collate_fn(cur_values)['input_ids'] + elif key == 'prompt': + collated_batch[key] = ref_collate_fn(cur_values)['input_ids'] + else: + raise ValueError(f'Invalid key: {key}') collated_batch['prompt_attention_mask'] = torch.logical_not( torch.eq(collated_batch['prompt'], tokenizer.pad_token_id), # type: ignore ) + return collated_batch @@ -83,13 +89,13 @@ def __init__( self.tokenizer = tokenizer super().__init__(**kwargs) - def _validate_messages(self, messages: list[dict[str, str]]) -> bool: + def _validate_messages(self, messages: Messages) -> bool: """Validate the messages. A valid message is a list of dictionaries with the following keys: - role: str - content: str Args: - messages (list[dict[str, str]]): The messages to validate. + messages (Messages): The messages to validate. Returns: bool: True if the messages are valid, False otherwise. """ @@ -106,14 +112,14 @@ def _validate_messages(self, messages: list[dict[str, str]]) -> bool: return False return True - def _tokenize_messages(self, messages: list[str]) -> torch.Tensor: + def _tokenize_messages(self, messages: Messages) -> torch.Tensor: if not self._validate_messages(messages): raise ValueError(f'Invalid messages received. Got: {messages=}') - return torch.Tensor(self.tokenizer.apply_chat_template( + return torch.tensor(self.tokenizer.apply_chat_template( messages, tokenize=True, add_generation_prompt=True, - )) + ), dtype=torch.int64) # How to process a sample def __getitem__(self, idx: int) -> dict[str, Any]: diff --git a/compose_rl/ppo/callback.py b/compose_rl/ppo/callback.py index f7323d4d..6e972237 100644 --- a/compose_rl/ppo/callback.py +++ b/compose_rl/ppo/callback.py @@ -638,7 +638,7 @@ def _get_next_iter_prompts(self): # Explode the batch into multiple batches for each generation for _ in range(self.generations_per_prompt): # For keys that do not require additional processing - if key in ['prompt_len', 'verified_answer', 'prompt_id']: + if key in ['prompt_len', 'verified_answer', 'prompt_id', 'messages']: curr_values.append(batch[key]) continue @@ -668,6 +668,10 @@ def _get_next_iter_prompts(self): else: if key == 'verified_answer': ret_batch[key] = list(utils.flatten(curr_values)) + elif key == 'messages': + # the messages should be [num_batches_per_update, batch_size, num_turns] + # need to flatten this to [num_batches_per_update * batch_size, num_turns] + ret_batch[key] = [message_chain for batch in curr_values for message_chain in batch] else: # this is an edge case that we will not hit currently, but just handling it as needed ret_batch[key] = curr_values From a9b2d8faef7836e76ac03656ef2dc6a8e1aa0e89 Mon Sep 17 00:00:00 2001 From: SeanKski Date: Tue, 17 Jun 2025 21:14:09 +0000 Subject: [PATCH 09/19] adding back changes to callback --- compose_rl/algorithms/online/callback.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/compose_rl/algorithms/online/callback.py b/compose_rl/algorithms/online/callback.py index 0eef5b4a..46586154 100644 --- a/compose_rl/algorithms/online/callback.py +++ b/compose_rl/algorithms/online/callback.py @@ -648,7 +648,7 @@ def _get_next_iter_prompts(self): # Explode the batch into multiple batches for each generation for _ in range(self.generations_per_prompt): # For keys that do not require additional processing - if key in ['prompt_len', 'verified_answer', 'prompt_id']: + if key in ['prompt_len', 'verified_answer', 'prompt_id', 'messages']: curr_values.append(batch[key]) continue @@ -678,6 +678,10 @@ def _get_next_iter_prompts(self): else: if key == 'verified_answer': ret_batch[key] = list(flatten(curr_values)) + elif key == 'messages': + # the messages should be [num_batches_per_update, batch_size, num_turns] + # need to flatten this to [num_batches_per_update * batch_size, num_turns] + ret_batch[key] = [message_chain for batch in curr_values for message_chain in batch] else: # this is an edge case that we will not hit currently, but just handling it as needed ret_batch[key] = curr_values From d0c81feb2b9ae715d10115cfbaaf5b19b5750f30 Mon Sep 17 00:00:00 2001 From: SeanKski Date: Tue, 17 Jun 2025 23:11:18 +0000 Subject: [PATCH 10/19] vllm hotfix --- compose_rl/algorithms/online/generation_utils/vllm_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compose_rl/algorithms/online/generation_utils/vllm_utils.py b/compose_rl/algorithms/online/generation_utils/vllm_utils.py index 21f134a6..b844480d 100644 --- a/compose_rl/algorithms/online/generation_utils/vllm_utils.py +++ b/compose_rl/algorithms/online/generation_utils/vllm_utils.py @@ -249,7 +249,7 @@ def create_vllm_engines( tokenizer_revision=revision, # type: ignore trust_remote_code=True, # type: ignore worker_extension_cls= # type: ignore - 'compose_rl.utils.vllm_utils.WorkerWrap', + 'compose_rl.algorithms.online.generation_utils.vllm_utils.WorkerWrap', tensor_parallel_size=tensor_parallel_size, # type: ignore enforce_eager=enforce_eager, # type: ignore dtype='bfloat16', # type: ignore From 830a1319e1a23d75b2d62557a9bcea5f45150c06 Mon Sep 17 00:00:00 2001 From: SeanKski Date: Tue, 17 Jun 2025 23:55:48 +0000 Subject: [PATCH 11/19] removed raw_untokenized texts from default reward batch --- compose_rl/algorithms/online/reward_manager.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/compose_rl/algorithms/online/reward_manager.py b/compose_rl/algorithms/online/reward_manager.py index f0407261..49ceb848 100644 --- a/compose_rl/algorithms/online/reward_manager.py +++ b/compose_rl/algorithms/online/reward_manager.py @@ -488,8 +488,7 @@ def _create_batch( 'generated_lens': base_batch['generated_lens'], 'seq_lens': - base_batch['seq_lens'], - 'raw_untokenized_texts': raw_untokenized_texts, + base_batch['seq_lens'] } else: raise TypeError( From 576e153d908472c02027024262a0213e9c3627ac Mon Sep 17 00:00:00 2001 From: SeanKski Date: Wed, 18 Jun 2025 00:16:57 +0000 Subject: [PATCH 12/19] updated local_grpo and local_ppo yamls --- yamls/local_grpo.yaml | 2 +- yamls/local_ppo.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/yamls/local_grpo.yaml b/yamls/local_grpo.yaml index 504128e5..9ad1ea02 100644 --- a/yamls/local_grpo.yaml +++ b/yamls/local_grpo.yaml @@ -114,7 +114,7 @@ tokenizer: truncation: true train_loader: - name: prompt + name: messages dataset: # local: # local path split: train diff --git a/yamls/local_ppo.yaml b/yamls/local_ppo.yaml index d89f6701..4ee648ad 100644 --- a/yamls/local_ppo.yaml +++ b/yamls/local_ppo.yaml @@ -112,7 +112,7 @@ tokenizer: truncation: true train_loader: - name: prompt + name: messages dataset: # local: # local path split: train From f212b03e85267eb83e6590955f77db2c1b618395 Mon Sep 17 00:00:00 2001 From: SeanKski Date: Wed, 18 Jun 2025 18:40:19 +0000 Subject: [PATCH 13/19] ruff is a cruel and hard to please master, but looks like i've finally done it --- compose_rl/algorithms/online/callback.py | 12 +++++- .../online/generation_utils/vllm_utils.py | 2 +- .../algorithms/online/reward_manager.py | 2 +- compose_rl/data/__init__.py | 4 +- compose_rl/data/dataloader.py | 20 +++++----- compose_rl/data/messages_data.py | 31 ++++++++------- scripts/data/unified_tokenize_dataset.py | 39 ++++++++++++++----- scripts/launch_composer_ray.py | 6 ++- tests/common/__init__.py | 2 +- tests/common/datasets.py | 19 ++++++--- tests/test_online.py | 17 ++++++-- 11 files changed, 104 insertions(+), 50 deletions(-) diff --git a/compose_rl/algorithms/online/callback.py b/compose_rl/algorithms/online/callback.py index 46586154..78fb1970 100644 --- a/compose_rl/algorithms/online/callback.py +++ b/compose_rl/algorithms/online/callback.py @@ -648,7 +648,12 @@ def _get_next_iter_prompts(self): # Explode the batch into multiple batches for each generation for _ in range(self.generations_per_prompt): # For keys that do not require additional processing - if key in ['prompt_len', 'verified_answer', 'prompt_id', 'messages']: + if key in [ + 'prompt_len', + 'verified_answer', + 'prompt_id', + 'messages', + ]: curr_values.append(batch[key]) continue @@ -681,7 +686,10 @@ def _get_next_iter_prompts(self): elif key == 'messages': # the messages should be [num_batches_per_update, batch_size, num_turns] # need to flatten this to [num_batches_per_update * batch_size, num_turns] - ret_batch[key] = [message_chain for batch in curr_values for message_chain in batch] + ret_batch[key] = [ + message_chain for batch in curr_values + for message_chain in batch + ] else: # this is an edge case that we will not hit currently, but just handling it as needed ret_batch[key] = curr_values diff --git a/compose_rl/algorithms/online/generation_utils/vllm_utils.py b/compose_rl/algorithms/online/generation_utils/vllm_utils.py index b844480d..3349ca0a 100644 --- a/compose_rl/algorithms/online/generation_utils/vllm_utils.py +++ b/compose_rl/algorithms/online/generation_utils/vllm_utils.py @@ -239,7 +239,7 @@ def create_vllm_engines( log.info(f'vllm: {num_gpus=}, {num_engines=}') vllm_engines.append( - LLMRayActor.options( + LLMRayActor.options( # type: ignore num_cpus=num_gpus, num_gpus=num_gpus, scheduling_strategy=scheduling_strategy, diff --git a/compose_rl/algorithms/online/reward_manager.py b/compose_rl/algorithms/online/reward_manager.py index 49ceb848..635dbe9f 100644 --- a/compose_rl/algorithms/online/reward_manager.py +++ b/compose_rl/algorithms/online/reward_manager.py @@ -488,7 +488,7 @@ def _create_batch( 'generated_lens': base_batch['generated_lens'], 'seq_lens': - base_batch['seq_lens'] + base_batch['seq_lens'], } else: raise TypeError( diff --git a/compose_rl/data/__init__.py b/compose_rl/data/__init__.py index 573140b8..5031dd10 100644 --- a/compose_rl/data/__init__.py +++ b/compose_rl/data/__init__.py @@ -7,16 +7,16 @@ ) from compose_rl.data.dataloader import ( build_finegrained_preference_dataloader, + build_messages_dataloader, build_pairwise_preference_dataloader, build_prompt_dataloader, - build_messages_dataloader, ) +from compose_rl.data.messages_data import messages_dataset_collate_fn from compose_rl.data.preference_data import ( finegrained_preference_dataset_collate_fn, pairwise_preference_dataset_collate_fn, ) from compose_rl.data.prompt_data import prompt_dataset_collate_fn -from compose_rl.data.messages_data import messages_dataset_collate_fn __all__ = [ 'build_pairwise_preference_dataloader', diff --git a/compose_rl/data/dataloader.py b/compose_rl/data/dataloader.py index 601f07b9..3085c26a 100644 --- a/compose_rl/data/dataloader.py +++ b/compose_rl/data/dataloader.py @@ -5,12 +5,15 @@ from functools import partial from typing import Any, Callable -import logging from streaming import Stream, StreamingDataLoader, StreamingDataset from torch.utils.data import DataLoader from transformers import PreTrainedTokenizer +from compose_rl.data.messages_data import ( + MessagesStreamingDataset, + messages_dataset_collate_fn, +) from compose_rl.data.preference_data import ( FinegrainedPreferenceStreamingDataset, PairwisePreferenceStreamingDataset, @@ -21,10 +24,6 @@ PromptStreamingDataset, prompt_dataset_collate_fn, ) -from compose_rl.data.messages_data import ( - MessagesStreamingDataset, - messages_dataset_collate_fn, -) __all__ = [ 'build_finegrained_preference_dataloader', @@ -77,12 +76,15 @@ def build_preference_dataloader( streams = None if streams_dict is not None: streams = [Stream(**stream) for stream in streams_dict.values()] - if issubclass(dataset_cls, MessagesStreamingDataset) and 'tokenizer' not in dataset_cfg: + if issubclass( + dataset_cls, + MessagesStreamingDataset, + ) and 'tokenizer' not in dataset_cfg: dataset_cfg['tokenizer'] = tokenizer streaming_dataset = dataset_cls( - streams=streams, - batch_size=device_batch_size, + streams=streams, # type: ignore + batch_size=device_batch_size, # type: ignore **dataset_cfg, ) @@ -123,4 +125,4 @@ def build_preference_dataloader( build_messages_dataloader = generate_dataloader_builder( MessagesStreamingDataset, messages_dataset_collate_fn, -) \ No newline at end of file +) diff --git a/compose_rl/data/messages_data.py b/compose_rl/data/messages_data.py index 300f70d1..07cc2bbf 100644 --- a/compose_rl/data/messages_data.py +++ b/compose_rl/data/messages_data.py @@ -25,7 +25,7 @@ def messages_dataset_collate_fn( tokenizer: PreTrainedTokenizerBase, max_seq_len: int, batch: list[dict[str, Any]], -) -> dict[str, torch.Tensor]: +) -> dict[str, Any]: """Collator for messages data. Args: @@ -41,14 +41,13 @@ def messages_dataset_collate_fn( mlm=False, mlm_probability=0.0, ) - + keys = batch[0].keys() - collated_batch: dict[str, list[str] | torch.Tensor | list[Messages]] = {} + collated_batch: dict[str, Any] = {} for key in keys: cur_values = [item[key] for item in batch] if key in ['prompt_len']: collated_batch[key] = torch.stack(cur_values).squeeze(dim=1) - continue elif key == 'prompt_id': collated_batch[key] = torch.tensor(cur_values) elif key in ['verified_answer']: @@ -57,7 +56,6 @@ def messages_dataset_collate_fn( ) elif key == 'messages': collated_batch[key] = cur_values - continue elif key == 'prompt': collated_batch[key] = ref_collate_fn(cur_values)['input_ids'] else: @@ -90,12 +88,14 @@ def __init__( super().__init__(**kwargs) def _validate_messages(self, messages: Messages) -> bool: - """Validate the messages. A valid message is a list of dictionaries with the following keys: - - role: str - - content: str - + """Validates the message. + + A valid message is a list of dictionaries with + a 'role' key and a 'content' key. + Args: messages (Messages): The messages to validate. + Returns: bool: True if the messages are valid, False otherwise. """ @@ -115,11 +115,14 @@ def _validate_messages(self, messages: Messages) -> bool: def _tokenize_messages(self, messages: Messages) -> torch.Tensor: if not self._validate_messages(messages): raise ValueError(f'Invalid messages received. Got: {messages=}') - return torch.tensor(self.tokenizer.apply_chat_template( - messages, - tokenize=True, - add_generation_prompt=True, - ), dtype=torch.int64) + return torch.tensor( + self.tokenizer.apply_chat_template( + messages, + tokenize=True, + add_generation_prompt=True, + ), + dtype=torch.int64, + ) # How to process a sample def __getitem__(self, idx: int) -> dict[str, Any]: diff --git a/scripts/data/unified_tokenize_dataset.py b/scripts/data/unified_tokenize_dataset.py index e6d98db2..1c7383b1 100644 --- a/scripts/data/unified_tokenize_dataset.py +++ b/scripts/data/unified_tokenize_dataset.py @@ -42,8 +42,9 @@ def __init__( split: str, tokenizer: PreTrainedTokenizerBase, max_length: int, - dataset_type: Literal['preference', 'single_prompt', - 'verifiable_answers'], + dataset_type: Literal['preference', 'single_prompt', 'single_message', + 'verifiable_answers', 'messages_with_answer', + 'classifier'], subset: str | None = None, token: str | None = None, ): @@ -67,7 +68,9 @@ def __init__( token=token, ) - def __iter__(self) -> Iterator[dict[str, bytes]]: + def __iter__( + self, + ) -> Iterator[dict[str, Any]]: for sample in self.hf_dataset: if self.dataset_type == 'preference': yield self._process_preference_sample(sample) @@ -76,7 +79,10 @@ def __iter__(self) -> Iterator[dict[str, bytes]]: if result is not None: yield result elif self.dataset_type == 'single_message': - result = self._process_single_prompt_sample(sample, return_messages=True) + result = self._process_single_prompt_sample( + sample, + return_messages=True, + ) if result is not None: # delete the prompt from the results since it's not needed result.pop('prompt') @@ -86,7 +92,10 @@ def __iter__(self) -> Iterator[dict[str, bytes]]: if result is not None: yield result elif self.dataset_type == 'messages_with_answer': - result = self._process_verifiable_answer_sample(sample, return_messages=True) + result = self._process_verifiable_answer_sample( + sample, + return_messages=True, + ) if result is not None: # delete the prompt from the results since it's not needed result.pop('prompt') @@ -117,11 +126,16 @@ def _process_preference_sample(self, sample: Any): 'rejected': np.asarray(curr_rejected).tobytes(), } - def _process_single_prompt_sample(self, sample: Any, return_messages: bool = False): + def _process_single_prompt_sample( + self, + sample: Any, + return_messages: bool = False, + ): """Process a prompt sample. Args: sample (Any): a sample from the dataset + return_messages (bool): whether to include chat-ml messages in the output """ prompt = sample['prompt'] messages = [{ @@ -142,7 +156,7 @@ def _process_single_prompt_sample(self, sample: Any, return_messages: bool = Fal output = {'prompt': np.asarray(encoded_prompt).tobytes()} if return_messages: - output['messages'] = messages + output['messages'] = messages # type: ignore return output def _process_classifier_sample(self, sample: Any): @@ -185,13 +199,18 @@ def _get_processing_fn_from_dataset(self): return prompt_fn, answer_fn - def _process_verifiable_answer_sample(self, sample: Any, return_messages: bool = False): + def _process_verifiable_answer_sample( + self, + sample: Any, + return_messages: bool = False, + ): """Process a prompt sample and extract the answer. This function is currently hard-coded for the GSM8K dataset. Args: sample (Any): a sample from the dataset + return_messages (bool): whether to include chat-ml messages in the output """ prompt_fn, answer_fn = self._get_processing_fn_from_dataset() @@ -266,7 +285,9 @@ def main( hashes: list[str], splits: list[str], tokenizer_name: str, - dataset_type: Literal['preference', 'single_prompt', 'single_message', 'verifiable_answers', 'messages_with_answer', 'classifier'], + dataset_type: Literal['preference', 'single_prompt', 'single_message', + 'verifiable_answers', 'messages_with_answer', + 'classifier'], max_length: int = 2048, subset: str | None = None, token: str | None = None, diff --git a/scripts/launch_composer_ray.py b/scripts/launch_composer_ray.py index 7ef0b657..b093c220 100644 --- a/scripts/launch_composer_ray.py +++ b/scripts/launch_composer_ray.py @@ -310,8 +310,10 @@ def reassign_train_and_inference_ranks( os.environ['MASTER_PORT'] = master_port # Adding a ray sync actor on global rank 0 to make it work - sync_actor = SyncActor.options(name='sync_actor', - namespace='default').remote() + sync_actor = SyncActor.options( # type: ignore + name='sync_actor', + namespace='default', + ).remote() log.info('after start ray nodes') diff --git a/tests/common/__init__.py b/tests/common/__init__.py index 1d588da3..a815cb38 100644 --- a/tests/common/__init__.py +++ b/tests/common/__init__.py @@ -5,8 +5,8 @@ FineGrainedPreference, PairwisePreference, PromptDataset, - VerifiablePromptDataset, VerifiableMessagesDataset, + VerifiablePromptDataset, ) from tests.common.markers import device, world_size diff --git a/tests/common/datasets.py b/tests/common/datasets.py index f58aee0f..09a49804 100644 --- a/tests/common/datasets.py +++ b/tests/common/datasets.py @@ -101,11 +101,18 @@ def __len__(self): return self.size def __getitem__(self, index: int): - messages = [{'role': 'user', 'content': 'F' * self.prompt_len}] # bit of a hack, but it works + messages = [{ + 'role': 'user', + 'content': 'F' * self.prompt_len, + }] # bit of a hack, but it works mock_prompt = torch.ones((len(messages[0]['content']),)).int() return { - 'messages': messages, - 'prompt': mock_prompt, - 'prompt_len': torch.Tensor([len(messages[0]['content'])]).to(torch.int64), - 'verified_answer': 'Paris', - } \ No newline at end of file + 'messages': + messages, + 'prompt': + mock_prompt, + 'prompt_len': + torch.Tensor([len(messages[0]['content'])]).to(torch.int64), + 'verified_answer': + 'Paris', + } diff --git a/tests/test_online.py b/tests/test_online.py index a10318a6..45f81461 100644 --- a/tests/test_online.py +++ b/tests/test_online.py @@ -24,8 +24,16 @@ ) from compose_rl.algorithms.online.model_methods import OnPolicyEnum from compose_rl.algorithms.online.modeling_hf import ComposerHFPolicy -from compose_rl.data import prompt_dataset_collate_fn, messages_dataset_collate_fn -from tests.common import PromptDataset, VerifiablePromptDataset, VerifiableMessagesDataset, world_size +from compose_rl.data import ( + messages_dataset_collate_fn, + prompt_dataset_collate_fn, +) +from tests.common import ( + PromptDataset, + VerifiableMessagesDataset, + VerifiablePromptDataset, + world_size, +) def test_hf_ppo_model_construction( @@ -77,7 +85,10 @@ def test_hf_ppo_policy_construction( @pytest.mark.parametrize('model_type', ['mpt', 'hf']) -@pytest.mark.parametrize('dataset_type', ['prompt', 'verifiable_prompt', 'verifiable_messages']) +@pytest.mark.parametrize( + 'dataset_type', + ['prompt', 'verifiable_prompt', 'verifiable_messages'], +) def test_model_forward( tiny_gpt2_tokenizer: PreTrainedTokenizerBase, model_type: str, From 6d2b4a3794063a623c9498b0674100fc71f2f0a5 Mon Sep 17 00:00:00 2001 From: SeanKski Date: Mon, 23 Jun 2025 20:10:48 +0000 Subject: [PATCH 14/19] added messages_dataset_to_mds script --- .../generation_utils/generation_utils.py | 7 +- scripts/data/messages_dataset_to_mds.py | 156 ++++++++++++++++++ scripts/data/messages_preprocessing_utils.py | 32 ++++ 3 files changed, 193 insertions(+), 2 deletions(-) create mode 100644 scripts/data/messages_dataset_to_mds.py create mode 100644 scripts/data/messages_preprocessing_utils.py diff --git a/compose_rl/algorithms/online/generation_utils/generation_utils.py b/compose_rl/algorithms/online/generation_utils/generation_utils.py index 56a7f63e..91574efc 100644 --- a/compose_rl/algorithms/online/generation_utils/generation_utils.py +++ b/compose_rl/algorithms/online/generation_utils/generation_utils.py @@ -120,14 +120,17 @@ def vllm_generate( # Pull the necessary variables from the batch and self cur_device = batch['prompt'].device prompt_tokens = batch['prompt'] + messages = batch['messages'] - prompt_all_gather_start_time = time.time() + prompt_and_messages_all_gather_start_time = time.time() + # TODO: (seank) update this to use gather_object(dst=0) rather than all_gather_object all_batched_prompts = dist.all_gather_object(prompt_tokens) + all_batched_messages = dist.all_gather_object(messages) batch_sizes = [len(batch) for batch in all_batched_prompts] log.info( - f'took : {time.time() - prompt_all_gather_start_time} to gather prompts', + f'took : {time.time() - prompt_and_messages_all_gather_start_time} to gather prompts and messages', ) all_prompts = [prompt for batch in all_batched_prompts for prompt in batch] diff --git a/scripts/data/messages_dataset_to_mds.py b/scripts/data/messages_dataset_to_mds.py new file mode 100644 index 00000000..d9d41ab7 --- /dev/null +++ b/scripts/data/messages_dataset_to_mds.py @@ -0,0 +1,156 @@ +# Copyright 2025 MosaicML ComposeRL authors +# SPDX-License-Identifier: Apache-2.0 + +"""A unified script to create datasets of messages for different data datasets.""" + +import argparse +import json +import logging +import os +from typing import Any, Iterator + +import datasets as hf_datasets +import fsspec +from streaming import MDSWriter +from torch.utils.data import IterableDataset + +from messages_preprocessing_utils import ( + prepare_gsm8k_messages, + prepare_math_messages, +) + +log = logging.getLogger(__name__) + + +class UnifiedMessagesDataset(IterableDataset): + """An IterableDataset that returns samples as messages with potential additional metadata. + This can take in either an hf dataset or a jsonl file. + + Args: + dataset_path (str): the path to the hf dataset or jsonl file to process + split (str): the split of the hf dataset to process (only used if dataset_path is an hf dataset) + subset (str | None): the subset of the dataset to process (only used if dataset_path is an hf dataset) + """ + + def __init__( + self, + dataset_path: str, + split: str | None = None, + subset: str | None = None, + ): + self.dataset_path = dataset_path + self.split = split + self.subset = subset + self.dataset_preprocess_fn = self.get_preprocess_fn(dataset_path) + self.dataset = self.load_dataset( + dataset_path, + split=split, + subset=subset, + ) + + def load_dataset(self, dataset_path: str, split: str | None = None, subset: str | None = None): + if dataset_path.endswith('.jsonl'): + log.info(f'Assuming dataset path is a jsonl file. Loading from {dataset_path}') + dataset = [] + # Using fsspec to handle both local and remote files + with fsspec.open(dataset_path, 'r', encoding='utf-8') as f: + for line in f: + dataset.append(json.loads(line)) + return dataset + else: + log.info(f'Assuming dataset path is an hf dataset. Loading from {dataset_path} with split: {split} and subset: {subset}') + return hf_datasets.load_dataset( + path=dataset_path, + split=split, + subset=subset, + ) + + def get_preprocess_fn(self, dataset_path: str): + """Returns the preprocessing function for the dataset.""" + if 'gsm8k' in dataset_path: + return prepare_gsm8k_messages + elif 'math' in dataset_path: + return prepare_math_messages + else: + log.warning(f'No preprocessing function found for dataset path: {dataset_path}. Defaulting to writing the dataset as is.') + return lambda x: x + + def __iter__( + self, + ) -> Iterator[dict[str, Any]]: + """Iterate over the dataset and yield samples, with potential preprocessing of the data. + Each sample must be a valid json object with a "messages" key. + """ + for sample in self.dataset: + processed_sample = self.dataset_preprocess_fn(sample) + assert 'messages' in processed_sample, f'Processed sample must have a "messages" key: {processed_sample} for dataset: {self.dataset_path}' + try: + json.loads(json.dumps(processed_sample)) + except Exception as e: + log.error(f'Error converting sample to json: {e}') + log.error(f'Sample: {processed_sample}') + raise e + yield processed_sample + +def main( + dataset_name: str, + compression: str, + local_dir: str, + hashes: list[str], + splits: list[str], + subset: str | None = None, +): + num_written = 0 + for split in splits: + with MDSWriter( + columns={'row': 'json'}, + out=os.path.join(local_dir, split), + compression=compression, + hashes=hashes, + ) as out: + dataset = UnifiedMessagesDataset( + dataset_name=dataset_name, + split=split, + subset=subset, + ) + log.info('Converting to MDS format') + for sample in dataset: + num_written += 1 + out.write(sample) + log.info(f'Finished writing {num_written} samples') + log.info(f'Dataset has: {num_written} samples') + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + '--dataset_name', + type=str, + required=True, + help='Name of the dataset to process', + ) + parser.add_argument('--compression', type=str, default='zstd') + parser.add_argument('--local_dir', type=str, required=True) + parser.add_argument( + '--hashes', + type=str, + nargs='+', + default=['sha1', 'xxh64'], + ) + parser.add_argument('--subset', type=str, default=None) + parser.add_argument('--splits', type=str, nargs='+', default=['train']) + + args = parser.parse_args() + hf_token = os.environ.get('HF_TOKEN') + main( + dataset_name=args.dataset_name, + compression=args.compression, + local_dir=args.local_dir, + hashes=args.hashes, + splits=args.splits, + tokenizer_name=args.tokenizer_name, + dataset_type=args.dataset_type, + max_length=args.max_length, + subset=args.subset, + token=hf_token, + ) diff --git a/scripts/data/messages_preprocessing_utils.py b/scripts/data/messages_preprocessing_utils.py new file mode 100644 index 00000000..c6ae22bb --- /dev/null +++ b/scripts/data/messages_preprocessing_utils.py @@ -0,0 +1,32 @@ +# Copyright 2025 MosaicML ComposeRL authors +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any +from compose_rl.utils.rlvr_utils import ( + extract_gsm8k_answer, + extract_math_answer, + prepare_gsm8k_prompt, + prepare_math_prompt, +) + +def prepare_gsm8k_messages(sample: Any) -> dict[str, list[dict[str, str]]]: + user_prompt = prepare_gsm8k_prompt(sample) + verified_answer = extract_gsm8k_answer(sample) + messages = [ + { + 'role': 'user', + 'content': user_prompt, + } + ] + return {'messages': messages, 'verified_answer': verified_answer} + +def prepare_math_messages(sample: Any) -> dict[str, Any]: + user_prompt = prepare_math_prompt(sample) + verified_answer = extract_math_answer(sample) + messages = [ + { + 'role': 'user', + 'content': user_prompt, + } + ] + return {'messages': messages, 'verified_answer': verified_answer} From 93556d56454679a042fc23a343c35186bea363f0 Mon Sep 17 00:00:00 2001 From: SeanKski Date: Mon, 23 Jun 2025 20:59:34 +0000 Subject: [PATCH 15/19] added option for metadata in messages dataset --- compose_rl/data/messages_data.py | 12 +++- scripts/data/messages_dataset_to_mds.py | 64 ++++++++++++-------- scripts/data/messages_preprocessing_utils.py | 16 +++-- 3 files changed, 60 insertions(+), 32 deletions(-) diff --git a/compose_rl/data/messages_data.py b/compose_rl/data/messages_data.py index 07cc2bbf..87914a59 100644 --- a/compose_rl/data/messages_data.py +++ b/compose_rl/data/messages_data.py @@ -126,13 +126,14 @@ def _tokenize_messages(self, messages: Messages) -> torch.Tensor: # How to process a sample def __getitem__(self, idx: int) -> dict[str, Any]: - """Get an item from StreamingDataset at a given index. + """Gets the messages and metadata from the StreamingDataset at a given index. Args: idx (int): the index where we fetch the data in the StreamingDataset. """ sample = super().__getitem__(idx) messages = sample['messages'] + metadata = sample.get('metadata', {}) prompt: torch.Tensor = self._tokenize_messages(messages) # TODO (bcui): Maybe add in an option to truncate a prompt by a given length? @@ -142,15 +143,20 @@ def __getitem__(self, idx: int) -> dict[str, Any]: prompt = prompt[:-truncate_len] prompt_len = torch.Tensor([len(prompt)]).to(dtype=torch.int64) - # Send the prompt id along with prompt data + # Create the return dictionary item_dict = { 'prompt_id': idx, 'prompt': prompt, 'prompt_len': prompt_len, 'messages': messages, + 'metadata': metadata, } - verified_answer = sample.get('verified_answer', None) + # extract the verified answer, if there is one + verified_answer = metadata.get('verified_answer', None) + if verified_answer is None: + # for backwards compatibility, we also check the sample directly for the verified answer + verified_answer = sample.get('verified_answer', None) if verified_answer: if isinstance(verified_answer, str): _answer = verified_answer diff --git a/scripts/data/messages_dataset_to_mds.py b/scripts/data/messages_dataset_to_mds.py index d9d41ab7..49797ea7 100644 --- a/scripts/data/messages_dataset_to_mds.py +++ b/scripts/data/messages_dataset_to_mds.py @@ -37,6 +37,7 @@ def __init__( dataset_path: str, split: str | None = None, subset: str | None = None, + token: str | None = None, ): self.dataset_path = dataset_path self.split = split @@ -46,9 +47,10 @@ def __init__( dataset_path, split=split, subset=subset, + token=token, ) - def load_dataset(self, dataset_path: str, split: str | None = None, subset: str | None = None): + def load_dataset(self, dataset_path: str, split: str | None = None, subset: str | None = None, token: str | None = None): if dataset_path.endswith('.jsonl'): log.info(f'Assuming dataset path is a jsonl file. Loading from {dataset_path}') dataset = [] @@ -62,38 +64,54 @@ def load_dataset(self, dataset_path: str, split: str | None = None, subset: str return hf_datasets.load_dataset( path=dataset_path, split=split, - subset=subset, + name=subset, + streaming=True, + token=token, ) def get_preprocess_fn(self, dataset_path: str): - """Returns the preprocessing function for the dataset.""" + """Returns the preprocessing function for the dataset. + + Each preprocessing function should return a tuple of (messages, metadata). + Messages should be a list of dictionaries with a 'role' key and a 'content' key. + Metadata should be a dictionary with any additional metadata. If there is no metadata, then the metadata can just be None. + Both the messages and metadata (if not None)must be json serializable. + + Args: + dataset_path (str): the path to the dataset + + Returns: + A function that takes in a sample and returns a tuple of (messages, metadata). + """ if 'gsm8k' in dataset_path: return prepare_gsm8k_messages elif 'math' in dataset_path: return prepare_math_messages else: log.warning(f'No preprocessing function found for dataset path: {dataset_path}. Defaulting to writing the dataset as is.') - return lambda x: x + return lambda x: (x, None) def __iter__( self, ) -> Iterator[dict[str, Any]]: """Iterate over the dataset and yield samples, with potential preprocessing of the data. - Each sample must be a valid json object with a "messages" key. """ for sample in self.dataset: - processed_sample = self.dataset_preprocess_fn(sample) - assert 'messages' in processed_sample, f'Processed sample must have a "messages" key: {processed_sample} for dataset: {self.dataset_path}' - try: - json.loads(json.dumps(processed_sample)) - except Exception as e: - log.error(f'Error converting sample to json: {e}') - log.error(f'Sample: {processed_sample}') - raise e - yield processed_sample + messages, metadata = self.dataset_preprocess_fn(sample) + if metadata is None: + metadata = {} + # time for some good ol fashioned type checking + for item, name in zip([messages, metadata], ['messages', 'metadata']): + try: + json.loads(json.dumps(item)) + except Exception as e: + log.error(f'Error converting {name} to json: {e}') + log.error(f'{name}: {item}') + raise e + yield {'messages': messages, 'metadata': metadata} def main( - dataset_name: str, + dataset_path: str, compression: str, local_dir: str, hashes: list[str], @@ -103,13 +121,13 @@ def main( num_written = 0 for split in splits: with MDSWriter( - columns={'row': 'json'}, + columns={'messages': 'json', 'metadata': 'json'}, out=os.path.join(local_dir, split), compression=compression, hashes=hashes, ) as out: dataset = UnifiedMessagesDataset( - dataset_name=dataset_name, + dataset_path=dataset_path, split=split, subset=subset, ) @@ -118,16 +136,16 @@ def main( num_written += 1 out.write(sample) log.info(f'Finished writing {num_written} samples') - log.info(f'Dataset has: {num_written} samples') + log.info(f'Dataset has {num_written} samples') if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument( - '--dataset_name', + '--dataset_path', type=str, required=True, - help='Name of the dataset to process', + help='Path to the dataset to process', ) parser.add_argument('--compression', type=str, default='zstd') parser.add_argument('--local_dir', type=str, required=True) @@ -143,14 +161,10 @@ def main( args = parser.parse_args() hf_token = os.environ.get('HF_TOKEN') main( - dataset_name=args.dataset_name, + dataset_path=args.dataset_path, compression=args.compression, local_dir=args.local_dir, hashes=args.hashes, splits=args.splits, - tokenizer_name=args.tokenizer_name, - dataset_type=args.dataset_type, - max_length=args.max_length, subset=args.subset, - token=hf_token, ) diff --git a/scripts/data/messages_preprocessing_utils.py b/scripts/data/messages_preprocessing_utils.py index c6ae22bb..c252d1ad 100644 --- a/scripts/data/messages_preprocessing_utils.py +++ b/scripts/data/messages_preprocessing_utils.py @@ -1,5 +1,13 @@ # Copyright 2025 MosaicML ComposeRL authors # SPDX-License-Identifier: Apache-2.0 +""" +Preprocessing functions for the messages dataset. + +Each preprocessing function should return a tuple of (messages, metadata). +Messages should be a list of dictionaries with a 'role' key and a 'content' key. +Metadata should be a dictionary with any additional metadata. If there is no metadata, then return an empty dictionary. +Both the messages and metadata must be json serializable. +""" from typing import Any from compose_rl.utils.rlvr_utils import ( @@ -9,7 +17,7 @@ prepare_math_prompt, ) -def prepare_gsm8k_messages(sample: Any) -> dict[str, list[dict[str, str]]]: +def prepare_gsm8k_messages(sample: Any) -> tuple[list[dict[str, str]], dict[str, str]]: user_prompt = prepare_gsm8k_prompt(sample) verified_answer = extract_gsm8k_answer(sample) messages = [ @@ -18,9 +26,9 @@ def prepare_gsm8k_messages(sample: Any) -> dict[str, list[dict[str, str]]]: 'content': user_prompt, } ] - return {'messages': messages, 'verified_answer': verified_answer} + return messages, {'verified_answer': verified_answer} -def prepare_math_messages(sample: Any) -> dict[str, Any]: +def prepare_math_messages(sample: Any) -> tuple[list[dict[str, str]], dict[str, str]]: user_prompt = prepare_math_prompt(sample) verified_answer = extract_math_answer(sample) messages = [ @@ -29,4 +37,4 @@ def prepare_math_messages(sample: Any) -> dict[str, Any]: 'content': user_prompt, } ] - return {'messages': messages, 'verified_answer': verified_answer} + return messages, {'verified_answer': verified_answer} From 4f3ae651d4a8710b0510f15a35c67f10b608d8d7 Mon Sep 17 00:00:00 2001 From: SeanKski Date: Mon, 23 Jun 2025 21:42:37 +0000 Subject: [PATCH 16/19] updated readme and made pyright happy --- README.md | 15 +++- compose_rl/data/messages_data.py | 2 +- scripts/data/messages_dataset_to_mds.py | 79 +++++++++++++------- scripts/data/messages_preprocessing_utils.py | 24 ++++-- 4 files changed, 78 insertions(+), 42 deletions(-) diff --git a/README.md b/README.md index 501be8ca..86189d5a 100644 --- a/README.md +++ b/README.md @@ -77,20 +77,27 @@ To further enable online RL with [verifiable rewards](https://arxiv.org/abs/2411 ```bash cd scripts -python data/unified_tokenize_dataset.py --dataset_name \ +python data/messages_dataset_to_mds.py --dataset_path \ --local_dir verifiable_data \ ---dataset_type messages_with_answer \ ---tokenizer_name meta-llama/Llama-3.1-8B-Instruct \ --split train \ ``` -We currently support the following two HuggingFace datasets for verifiable rewards: +For RLVR, We currently support the following two HuggingFace datasets for verifiable rewards: - GMS8k: `openai/gsm8k` - MATH: `DigitalLearningGmbH/MATH-lighteval` The data preparation scripts also supports additional arguments for specifying the subset of the HuggingFace dataset `--subset ` and max sequence length `--max_length ` +For custom datasets, you can create a custom preprocessing function in the `compose-rl/scripts/data/messages_preprocessing_utils.py` file, or you can preprocess your own dataset directly, save it locally as a .jsonl file, and then use the following command to convert it to the MDS format: + +```bash +cd scripts +python data/messages_dataset_to_mds.py --dataset_path \ +--local_dir custom_dataset \ +``` + + ### Model training Below are the scripts to launch training runs assuming you ran the data preparation scripts above. Additionally, these scripts assume that we are in the root directory where Compose RL and LLM Foundry were cloned. This is because we utilize [LLM Foundry's Registry System](https://github.com/mosaicml/llm-foundry/?tab=readme-ov-file#registry) in order to take advantage of existing features in LLM Foundry. diff --git a/compose_rl/data/messages_data.py b/compose_rl/data/messages_data.py index 87914a59..f2b639d9 100644 --- a/compose_rl/data/messages_data.py +++ b/compose_rl/data/messages_data.py @@ -126,7 +126,7 @@ def _tokenize_messages(self, messages: Messages) -> torch.Tensor: # How to process a sample def __getitem__(self, idx: int) -> dict[str, Any]: - """Gets the messages and metadata from the StreamingDataset at a given index. + """Gets the messages + (optionally) metadata at the given index. Args: idx (int): the index where we fetch the data in the StreamingDataset. diff --git a/scripts/data/messages_dataset_to_mds.py b/scripts/data/messages_dataset_to_mds.py index 49797ea7..c35c06a5 100644 --- a/scripts/data/messages_dataset_to_mds.py +++ b/scripts/data/messages_dataset_to_mds.py @@ -1,7 +1,7 @@ # Copyright 2025 MosaicML ComposeRL authors # SPDX-License-Identifier: Apache-2.0 -"""A unified script to create datasets of messages for different data datasets.""" +"""A unified script to create messages-based datasets with Mosaic Streaming.""" import argparse import json @@ -11,20 +11,18 @@ import datasets as hf_datasets import fsspec -from streaming import MDSWriter -from torch.utils.data import IterableDataset - from messages_preprocessing_utils import ( prepare_gsm8k_messages, prepare_math_messages, ) +from streaming import MDSWriter +from torch.utils.data import IterableDataset log = logging.getLogger(__name__) class UnifiedMessagesDataset(IterableDataset): - """An IterableDataset that returns samples as messages with potential additional metadata. - This can take in either an hf dataset or a jsonl file. + """An IterableDataset that returns messages + (optionally) metadata. Args: dataset_path (str): the path to the hf dataset or jsonl file to process @@ -50,17 +48,26 @@ def __init__( token=token, ) - def load_dataset(self, dataset_path: str, split: str | None = None, subset: str | None = None, token: str | None = None): + def load_dataset( + self, + dataset_path: str, + split: str | None = None, + subset: str | None = None, + token: str | None = None, + ): if dataset_path.endswith('.jsonl'): - log.info(f'Assuming dataset path is a jsonl file. Loading from {dataset_path}') + log.info( + f'Assuming dataset path is a jsonl file. Loading from {dataset_path}', + ) dataset = [] # Using fsspec to handle both local and remote files with fsspec.open(dataset_path, 'r', encoding='utf-8') as f: - for line in f: - dataset.append(json.loads(line)) + dataset = [json.loads(line) for line in f] return dataset else: - log.info(f'Assuming dataset path is an hf dataset. Loading from {dataset_path} with split: {split} and subset: {subset}') + log.info( + f'Assuming dataset path is an hf dataset. Loading from {dataset_path} with split: {split} and subset: {subset}', + ) return hf_datasets.load_dataset( path=dataset_path, split=split, @@ -68,10 +75,10 @@ def load_dataset(self, dataset_path: str, split: str | None = None, subset: str streaming=True, token=token, ) - + def get_preprocess_fn(self, dataset_path: str): - """Returns the preprocessing function for the dataset. - + """Returns the preprocessing function for the dataset. + Each preprocessing function should return a tuple of (messages, metadata). Messages should be a list of dictionaries with a 'role' key and a 'content' key. Metadata should be a dictionary with any additional metadata. If there is no metadata, then the metadata can just be None. @@ -83,33 +90,44 @@ def get_preprocess_fn(self, dataset_path: str): Returns: A function that takes in a sample and returns a tuple of (messages, metadata). """ - if 'gsm8k' in dataset_path: + if 'gsm8k' in dataset_path.lower(): + log.info('Using GSM8k preprocessing function') return prepare_gsm8k_messages - elif 'math' in dataset_path: + elif 'math' in dataset_path.lower(): + log.info('Using MATH preprocessing function') return prepare_math_messages else: - log.warning(f'No preprocessing function found for dataset path: {dataset_path}. Defaulting to writing the dataset as is.') + log.warning( + f'No preprocessing function found for dataset path: {dataset_path}. Defaulting to writing the dataset as is.', + ) return lambda x: (x, None) def __iter__( self, ) -> Iterator[dict[str, Any]]: - """Iterate over the dataset and yield samples, with potential preprocessing of the data. - """ + """Iteratively yields messages + (optionally) metadata.""" for sample in self.dataset: messages, metadata = self.dataset_preprocess_fn(sample) if metadata is None: metadata = {} - # time for some good ol fashioned type checking - for item, name in zip([messages, metadata], ['messages', 'metadata']): - try: - json.loads(json.dumps(item)) - except Exception as e: - log.error(f'Error converting {name} to json: {e}') - log.error(f'{name}: {item}') - raise e + + # time for some good ol fashioned validation + self._ensure_jsonable(messages, 'messages') + self._ensure_jsonable(metadata, 'metadata') + yield {'messages': messages, 'metadata': metadata} + @staticmethod + def _ensure_jsonable(obj: Any, label: str) -> None: + """Raise ValueError if obj cannot be round-tripped through JSON.""" + try: + json.loads(json.dumps(obj)) + except Exception as e: + log.error(f'Error converting {label} to JSON: {e}') + log.error(f'{label}: {obj}') + raise ValueError(f'{label} is not JSON-serializable') from e + + def main( dataset_path: str, compression: str, @@ -121,7 +139,10 @@ def main( num_written = 0 for split in splits: with MDSWriter( - columns={'messages': 'json', 'metadata': 'json'}, + columns={ + 'messages': 'json', + 'metadata': 'json', + }, out=os.path.join(local_dir, split), compression=compression, hashes=hashes, @@ -157,7 +178,7 @@ def main( ) parser.add_argument('--subset', type=str, default=None) parser.add_argument('--splits', type=str, nargs='+', default=['train']) - + args = parser.parse_args() hf_token = os.environ.get('HF_TOKEN') main( diff --git a/scripts/data/messages_preprocessing_utils.py b/scripts/data/messages_preprocessing_utils.py index c252d1ad..316bcf86 100644 --- a/scripts/data/messages_preprocessing_utils.py +++ b/scripts/data/messages_preprocessing_utils.py @@ -1,15 +1,17 @@ # Copyright 2025 MosaicML ComposeRL authors # SPDX-License-Identifier: Apache-2.0 -""" -Preprocessing functions for the messages dataset. + +"""Preprocessing functions for the messages dataset. Each preprocessing function should return a tuple of (messages, metadata). Messages should be a list of dictionaries with a 'role' key and a 'content' key. -Metadata should be a dictionary with any additional metadata. If there is no metadata, then return an empty dictionary. -Both the messages and metadata must be json serializable. +Metadata should be a dictionary with any additional metadata. If there is no +metadata, then return an empty dictionary. Both the messages and metadata must +be json serializable. """ from typing import Any + from compose_rl.utils.rlvr_utils import ( extract_gsm8k_answer, extract_math_answer, @@ -17,24 +19,30 @@ prepare_math_prompt, ) -def prepare_gsm8k_messages(sample: Any) -> tuple[list[dict[str, str]], dict[str, str]]: + +def prepare_gsm8k_messages( + sample: Any, +) -> tuple[list[dict[str, str]], dict[str, str | None]]: user_prompt = prepare_gsm8k_prompt(sample) verified_answer = extract_gsm8k_answer(sample) messages = [ { 'role': 'user', 'content': user_prompt, - } + }, ] return messages, {'verified_answer': verified_answer} -def prepare_math_messages(sample: Any) -> tuple[list[dict[str, str]], dict[str, str]]: + +def prepare_math_messages( + sample: Any, +) -> tuple[list[dict[str, str]], dict[str, str | None]]: user_prompt = prepare_math_prompt(sample) verified_answer = extract_math_answer(sample) messages = [ { 'role': 'user', 'content': user_prompt, - } + }, ] return messages, {'verified_answer': verified_answer} From 3c5283fceee4cbae212a55352de9422e0f6fea72 Mon Sep 17 00:00:00 2001 From: SeanKski Date: Mon, 23 Jun 2025 21:42:51 +0000 Subject: [PATCH 17/19] undid accidental commit to generation_utils --- .../algorithms/online/generation_utils/generation_utils.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/compose_rl/algorithms/online/generation_utils/generation_utils.py b/compose_rl/algorithms/online/generation_utils/generation_utils.py index 91574efc..56a7f63e 100644 --- a/compose_rl/algorithms/online/generation_utils/generation_utils.py +++ b/compose_rl/algorithms/online/generation_utils/generation_utils.py @@ -120,17 +120,14 @@ def vllm_generate( # Pull the necessary variables from the batch and self cur_device = batch['prompt'].device prompt_tokens = batch['prompt'] - messages = batch['messages'] - prompt_and_messages_all_gather_start_time = time.time() + prompt_all_gather_start_time = time.time() - # TODO: (seank) update this to use gather_object(dst=0) rather than all_gather_object all_batched_prompts = dist.all_gather_object(prompt_tokens) - all_batched_messages = dist.all_gather_object(messages) batch_sizes = [len(batch) for batch in all_batched_prompts] log.info( - f'took : {time.time() - prompt_and_messages_all_gather_start_time} to gather prompts and messages', + f'took : {time.time() - prompt_all_gather_start_time} to gather prompts', ) all_prompts = [prompt for batch in all_batched_prompts for prompt in batch] From f241fb0785b71b10ca5b2c7bfd44985f52a042c6 Mon Sep 17 00:00:00 2001 From: SeanKski Date: Mon, 23 Jun 2025 22:23:02 +0000 Subject: [PATCH 18/19] removed metadata from batch for now --- compose_rl/data/messages_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compose_rl/data/messages_data.py b/compose_rl/data/messages_data.py index f2b639d9..16b836ff 100644 --- a/compose_rl/data/messages_data.py +++ b/compose_rl/data/messages_data.py @@ -149,7 +149,7 @@ def __getitem__(self, idx: int) -> dict[str, Any]: 'prompt': prompt, 'prompt_len': prompt_len, 'messages': messages, - 'metadata': metadata, + # 'metadata': metadata, # removing metadata for now } # extract the verified answer, if there is one From cc546a772d558acb16094ec30bd03af102858e92 Mon Sep 17 00:00:00 2001 From: SeanKski Date: Mon, 23 Jun 2025 22:23:22 +0000 Subject: [PATCH 19/19] fixed messages collating based on brandon's comment --- compose_rl/algorithms/online/callback.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/compose_rl/algorithms/online/callback.py b/compose_rl/algorithms/online/callback.py index 78fb1970..a655f895 100644 --- a/compose_rl/algorithms/online/callback.py +++ b/compose_rl/algorithms/online/callback.py @@ -683,15 +683,7 @@ def _get_next_iter_prompts(self): else: if key == 'verified_answer': ret_batch[key] = list(flatten(curr_values)) - elif key == 'messages': - # the messages should be [num_batches_per_update, batch_size, num_turns] - # need to flatten this to [num_batches_per_update * batch_size, num_turns] - ret_batch[key] = [ - message_chain for batch in curr_values - for message_chain in batch - ] else: - # this is an edge case that we will not hit currently, but just handling it as needed ret_batch[key] = curr_values return ret_batch