diff --git a/README.md b/README.md
index 299c2200..86189d5a 100644
--- a/README.md
+++ b/README.md
@@ -46,7 +46,7 @@ Here is an end-to-end workflow of performing data preparation for training along
 
 ### Data preparation
 
-Below is the set of commands to run to prepare datasets into the appropriate Mosaic Data Shard (MDS) format, which is a pre-tokenized version of the data, that we will use for training.
+Below is the set of commands to run to prepare datasets into the appropriate Mosaic Data Shard (MDS) format -- either a pre-tokenized version of the data or the raw chat-message data -- that we will use for training.
 
 Below is the command to prepare preference data -- which can be used for reward model or offline RL (e.g. DPO) training:
 
@@ -60,37 +60,44 @@ python data/unified_tokenize_dataset.py --dataset_name allenai/ultrafeedback_bin
 --split train_prefs
 ```
 
-Below is the command to prepare prompt data -- which can be used for online RL (e.g. PPO) training:
+Below is the command to prepare single-turn message data -- which can be used for online RL (e.g. PPO) training:
 
 ```bash
 cd scripts
 python data/unified_tokenize_dataset.py --dataset_name allenai/ultrafeedback_binarized_cleaned \
---local_dir prompt_data \
---dataset_type single_prompt \
+--local_dir single_message_data \
+--dataset_type single_message \
 --tokenizer_name meta-llama/Llama-3.1-8B-Instruct \
 --split train_prefs
 ```
 
-To further enable online RL with [verifiable rewards](https://arxiv.org/abs/2411.15124) you can use the following command:
+To further enable online RL with [verifiable rewards](https://arxiv.org/abs/2411.15124), you can use the following command to prepare the chat-message data along with the corresponding verifiable answers:
 
 ```bash
 cd scripts
-python data/unified_tokenize_dataset.py --dataset_name \
+python data/messages_dataset_to_mds.py --dataset_path \
 --local_dir verifiable_data \
---dataset_type verifiable_answers \
---tokenizer_name meta-llama/Llama-3.1-8B-Instruct \
 --split train \
 ```
 
-We currently support the following two HuggingFace datasets for verifiable rewards:
+For RLVR, we currently support the following two HuggingFace datasets for verifiable rewards:
 - GMS8k: `openai/gsm8k`
 - MATH: `DigitalLearningGmbH/MATH-lighteval`
 
 The data preparation scripts also supports additional arguments for specifying the subset of the HuggingFace dataset `--subset ` and max sequence length `--max_length `
 
+For custom datasets, you can create a custom preprocessing function in the `compose-rl/scripts/data/messages_preprocessing_utils.py` file, or you can preprocess your own dataset directly, save it locally as a `.jsonl` file, and then use the following command to convert it to the MDS format:
+
+```bash
+cd scripts
+python data/messages_dataset_to_mds.py --dataset_path \
+--local_dir custom_dataset \
+```
+
+
 ### Model training
 
 Below are the scripts to launch training runs assuming you ran the data preparation scripts above. Additionally, these scripts assume that we are in the root directory where Compose RL and LLM Foundry were cloned. This is because we utilize [LLM Foundry's Registry System](https://github.com/mosaicml/llm-foundry/?tab=readme-ov-file#registry) in order to take advantage of existing features in LLM Foundry.
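One note on the custom-dataset option described above: a custom preprocessing function follows the same `(messages, metadata)` contract as the GSM8k and MATH helpers added in `scripts/data/messages_preprocessing_utils.py` later in this diff. Below is a minimal, hypothetical sketch; the `question` and `final_answer` field names are assumptions about a custom dataset, not part of this change.

```python
# Hypothetical addition to scripts/data/messages_preprocessing_utils.py -- a sketch only.
# The sample field names below are assumptions; only the (messages, metadata) return
# contract comes from this diff.
from typing import Any


def prepare_my_dataset_messages(
    sample: Any,
) -> tuple[list[dict[str, str]], dict[str, str | None]]:
    # Build a single-turn chat prompt from the raw sample.
    messages = [{
        'role': 'user',
        'content': f"Solve the following problem:\n{sample['question']}",
    }]
    # Values stored under 'verified_answer' are picked up by the RLVR reward path.
    return messages, {'verified_answer': sample.get('final_answer')}
```

To route a HuggingFace dataset to such a function, a matching branch would also need to be added to `get_preprocess_fn` in `scripts/data/messages_dataset_to_mds.py`. For `.jsonl` inputs with no matching preprocessing function, the script's default pass-through writes each line unchanged as the `messages` field, so each line should already be a JSON list of `{"role": ..., "content": ...}` objects.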
@@ -129,7 +136,7 @@ Below is the command to run Online PPO training: ```bash composer llm-foundry/scripts/train/train.py \ compose-rl/yamls/local_ppo.yaml \ -train_loader.dataset.local=/compose-rl/scripts/prompt_data/ \ +train_loader.dataset.local=/compose-rl/scripts/single_message_data/ \ train_loader.dataset.split=train_prefs ``` diff --git a/compose_rl/algorithms/online/callback.py b/compose_rl/algorithms/online/callback.py index 86019263..29bbf7a5 100644 --- a/compose_rl/algorithms/online/callback.py +++ b/compose_rl/algorithms/online/callback.py @@ -649,7 +649,12 @@ def _get_next_iter_prompts(self): # Explode the batch into multiple batches for each generation for _ in range(self.generations_per_prompt): # For keys that do not require additional processing - if key in ['prompt_len', 'verified_answer', 'prompt_id']: + if key in [ + 'prompt_len', + 'verified_answer', + 'prompt_id', + 'messages', + ]: curr_values.append(batch[key]) continue @@ -680,7 +685,6 @@ def _get_next_iter_prompts(self): if key == 'verified_answer': ret_batch[key] = list(flatten(curr_values)) else: - # this is an edge case that we will not hit currently, but just handling it as needed ret_batch[key] = curr_values return ret_batch diff --git a/compose_rl/algorithms/online/generation_utils/vllm_utils.py b/compose_rl/algorithms/online/generation_utils/vllm_utils.py index b844480d..3349ca0a 100644 --- a/compose_rl/algorithms/online/generation_utils/vllm_utils.py +++ b/compose_rl/algorithms/online/generation_utils/vllm_utils.py @@ -239,7 +239,7 @@ def create_vllm_engines( log.info(f'vllm: {num_gpus=}, {num_engines=}') vllm_engines.append( - LLMRayActor.options( + LLMRayActor.options( # type: ignore num_cpus=num_gpus, num_gpus=num_gpus, scheduling_strategy=scheduling_strategy, diff --git a/compose_rl/data/__init__.py b/compose_rl/data/__init__.py index 33b1b0c1..5031dd10 100644 --- a/compose_rl/data/__init__.py +++ b/compose_rl/data/__init__.py @@ -7,9 +7,11 @@ ) from compose_rl.data.dataloader import ( build_finegrained_preference_dataloader, + build_messages_dataloader, build_pairwise_preference_dataloader, build_prompt_dataloader, ) +from compose_rl.data.messages_data import messages_dataset_collate_fn from compose_rl.data.preference_data import ( finegrained_preference_dataset_collate_fn, pairwise_preference_dataset_collate_fn, @@ -19,10 +21,12 @@ __all__ = [ 'build_pairwise_preference_dataloader', 'build_finegrained_preference_dataloader', + 'build_messages_dataloader', 'build_prompt_dataloader', 'DummyDataset', 'finegrained_preference_dataset_collate_fn', 'MinibatchRolloutBuffer', 'pairwise_preference_dataset_collate_fn', 'prompt_dataset_collate_fn', + 'messages_dataset_collate_fn', ] diff --git a/compose_rl/data/dataloader.py b/compose_rl/data/dataloader.py index 2fd37184..3085c26a 100644 --- a/compose_rl/data/dataloader.py +++ b/compose_rl/data/dataloader.py @@ -10,6 +10,10 @@ from torch.utils.data import DataLoader from transformers import PreTrainedTokenizer +from compose_rl.data.messages_data import ( + MessagesStreamingDataset, + messages_dataset_collate_fn, +) from compose_rl.data.preference_data import ( FinegrainedPreferenceStreamingDataset, PairwisePreferenceStreamingDataset, @@ -25,6 +29,7 @@ 'build_finegrained_preference_dataloader', 'build_pairwise_preference_dataloader', 'build_prompt_dataloader', + 'build_messages_dataloader', ] @@ -71,10 +76,15 @@ def build_preference_dataloader( streams = None if streams_dict is not None: streams = [Stream(**stream) for stream in streams_dict.values()] + if 
issubclass( + dataset_cls, + MessagesStreamingDataset, + ) and 'tokenizer' not in dataset_cfg: + dataset_cfg['tokenizer'] = tokenizer streaming_dataset = dataset_cls( - streams=streams, - batch_size=device_batch_size, + streams=streams, # type: ignore + batch_size=device_batch_size, # type: ignore **dataset_cfg, ) @@ -111,3 +121,8 @@ def build_preference_dataloader( PromptStreamingDataset, prompt_dataset_collate_fn, ) + +build_messages_dataloader = generate_dataloader_builder( + MessagesStreamingDataset, + messages_dataset_collate_fn, +) diff --git a/compose_rl/data/messages_data.py b/compose_rl/data/messages_data.py new file mode 100644 index 00000000..16b836ff --- /dev/null +++ b/compose_rl/data/messages_data.py @@ -0,0 +1,174 @@ +# Copyright 2025 MosaicML ComposeRL authors +# SPDX-License-Identifier: Apache-2.0 + +"""Build a prompt dataset and dataloader for training.""" + +import logging +from typing import Any, TypeAlias + +import torch +from streaming import StreamingDataset +from transformers import ( + AutoTokenizer, + DataCollatorForLanguageModeling, + PreTrainedTokenizerBase, +) + +import compose_rl.utils as utils + +log = logging.getLogger(__name__) + +Messages: TypeAlias = list[dict[str, str]] + + +def messages_dataset_collate_fn( + tokenizer: PreTrainedTokenizerBase, + max_seq_len: int, + batch: list[dict[str, Any]], +) -> dict[str, Any]: + """Collator for messages data. + + Args: + batch (List[Dict[str, Any]]): A list of data samples to collate. + tokenizer (PreTrainedTokenizer): The model's tokenizer. + max_seq_len (int): The maximum sequence length of the model. + """ + if tokenizer.pad_token_id is None: + raise ValueError('Tokenizer must have a PAD token.') + + ref_collate_fn = DataCollatorForLanguageModeling( + tokenizer=tokenizer, + mlm=False, + mlm_probability=0.0, + ) + + keys = batch[0].keys() + collated_batch: dict[str, Any] = {} + for key in keys: + cur_values = [item[key] for item in batch] + if key in ['prompt_len']: + collated_batch[key] = torch.stack(cur_values).squeeze(dim=1) + elif key == 'prompt_id': + collated_batch[key] = torch.tensor(cur_values) + elif key in ['verified_answer']: + collated_batch[key] = list( # pyright: ignore[reportGeneralTypeIssues] + utils.flatten(cur_values), + ) + elif key == 'messages': + collated_batch[key] = cur_values + elif key == 'prompt': + collated_batch[key] = ref_collate_fn(cur_values)['input_ids'] + else: + raise ValueError(f'Invalid key: {key}') + + collated_batch['prompt_attention_mask'] = torch.logical_not( + torch.eq(collated_batch['prompt'], + tokenizer.pad_token_id), # type: ignore + ) + + return collated_batch + + +class MessagesStreamingDataset(StreamingDataset): + """Dataloader for streaming in messages and converting to prompts.""" + + def __init__( + self, + max_gen_len: int, + max_seq_len: int, + tokenizer: str | PreTrainedTokenizerBase, + **kwargs: dict[str, Any], + ): + self.max_gen_len = max_gen_len + self.max_seq_len = max_seq_len + if isinstance(tokenizer, str): + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer) + else: + self.tokenizer = tokenizer + super().__init__(**kwargs) + + def _validate_messages(self, messages: Messages) -> bool: + """Validates the message. + + A valid message is a list of dictionaries with + a 'role' key and a 'content' key. + + Args: + messages (Messages): The messages to validate. + + Returns: + bool: True if the messages are valid, False otherwise. 
+ """ + if not isinstance(messages, list): + return False + for message in messages: + if not isinstance(message, dict): + return False + if 'role' not in message: + return False + if 'content' not in message: + return False + if not isinstance(message['content'], str): + return False + return True + + def _tokenize_messages(self, messages: Messages) -> torch.Tensor: + if not self._validate_messages(messages): + raise ValueError(f'Invalid messages received. Got: {messages=}') + return torch.tensor( + self.tokenizer.apply_chat_template( + messages, + tokenize=True, + add_generation_prompt=True, + ), + dtype=torch.int64, + ) + + # How to process a sample + def __getitem__(self, idx: int) -> dict[str, Any]: + """Gets the messages + (optionally) metadata at the given index. + + Args: + idx (int): the index where we fetch the data in the StreamingDataset. + """ + sample = super().__getitem__(idx) + messages = sample['messages'] + metadata = sample.get('metadata', {}) + prompt: torch.Tensor = self._tokenize_messages(messages) + + # TODO (bcui): Maybe add in an option to truncate a prompt by a given length? + if len(prompt) + self.max_gen_len > self.max_seq_len: + truncate_len = len(prompt) + self.max_gen_len - self.max_seq_len + log.info(f'Truncating prompt by: {truncate_len}') + prompt = prompt[:-truncate_len] + + prompt_len = torch.Tensor([len(prompt)]).to(dtype=torch.int64) + # Create the return dictionary + item_dict = { + 'prompt_id': idx, + 'prompt': prompt, + 'prompt_len': prompt_len, + 'messages': messages, + # 'metadata': metadata, # removing metadata for now + } + + # extract the verified answer, if there is one + verified_answer = metadata.get('verified_answer', None) + if verified_answer is None: + # for backwards compatibility, we also check the sample directly for the verified answer + verified_answer = sample.get('verified_answer', None) + if verified_answer: + if isinstance(verified_answer, str): + _answer = verified_answer + else: + try: + _answer = verified_answer.decode('utf-8', errors='strict') + except UnicodeDecodeError as e: + log.error( + f'Failed to decode verifed_answer with error: {e}', + ) + _answer = '' + + item_dict['verified_answer'] = _answer # type: ignore + + return item_dict diff --git a/pyproject.toml b/pyproject.toml index 038ff73b..f13f1efe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,6 +59,7 @@ hf_ppo_lm = "compose_rl.algorithms.online:ComposerHFPolicyLM" pairwise_preference = "compose_rl.data:build_pairwise_preference_dataloader" finegrained_preference = "compose_rl.data:build_finegrained_preference_dataloader" prompt = "compose_rl.data:build_prompt_dataloader" +messages = "compose_rl.data:build_messages_dataloader" [project.entry-points."llmfoundry_callbacks_with_config"] offline_rl = "compose_rl.algorithms.offline:ReferencePolicyCallback" diff --git a/scripts/data/messages_dataset_to_mds.py b/scripts/data/messages_dataset_to_mds.py new file mode 100644 index 00000000..c35c06a5 --- /dev/null +++ b/scripts/data/messages_dataset_to_mds.py @@ -0,0 +1,191 @@ +# Copyright 2025 MosaicML ComposeRL authors +# SPDX-License-Identifier: Apache-2.0 + +"""A unified script to create messages-based datasets with Mosaic Streaming.""" + +import argparse +import json +import logging +import os +from typing import Any, Iterator + +import datasets as hf_datasets +import fsspec +from messages_preprocessing_utils import ( + prepare_gsm8k_messages, + prepare_math_messages, +) +from streaming import MDSWriter +from torch.utils.data import IterableDataset + +log = 
logging.getLogger(__name__) + + +class UnifiedMessagesDataset(IterableDataset): + """An IterableDataset that returns messages + (optionally) metadata. + + Args: + dataset_path (str): the path to the hf dataset or jsonl file to process + split (str): the split of the hf dataset to process (only used if dataset_path is an hf dataset) + subset (str | None): the subset of the dataset to process (only used if dataset_path is an hf dataset) + """ + + def __init__( + self, + dataset_path: str, + split: str | None = None, + subset: str | None = None, + token: str | None = None, + ): + self.dataset_path = dataset_path + self.split = split + self.subset = subset + self.dataset_preprocess_fn = self.get_preprocess_fn(dataset_path) + self.dataset = self.load_dataset( + dataset_path, + split=split, + subset=subset, + token=token, + ) + + def load_dataset( + self, + dataset_path: str, + split: str | None = None, + subset: str | None = None, + token: str | None = None, + ): + if dataset_path.endswith('.jsonl'): + log.info( + f'Assuming dataset path is a jsonl file. Loading from {dataset_path}', + ) + dataset = [] + # Using fsspec to handle both local and remote files + with fsspec.open(dataset_path, 'r', encoding='utf-8') as f: + dataset = [json.loads(line) for line in f] + return dataset + else: + log.info( + f'Assuming dataset path is an hf dataset. Loading from {dataset_path} with split: {split} and subset: {subset}', + ) + return hf_datasets.load_dataset( + path=dataset_path, + split=split, + name=subset, + streaming=True, + token=token, + ) + + def get_preprocess_fn(self, dataset_path: str): + """Returns the preprocessing function for the dataset. + + Each preprocessing function should return a tuple of (messages, metadata). + Messages should be a list of dictionaries with a 'role' key and a 'content' key. + Metadata should be a dictionary with any additional metadata. If there is no metadata, then the metadata can just be None. + Both the messages and metadata (if not None)must be json serializable. + + Args: + dataset_path (str): the path to the dataset + + Returns: + A function that takes in a sample and returns a tuple of (messages, metadata). + """ + if 'gsm8k' in dataset_path.lower(): + log.info('Using GSM8k preprocessing function') + return prepare_gsm8k_messages + elif 'math' in dataset_path.lower(): + log.info('Using MATH preprocessing function') + return prepare_math_messages + else: + log.warning( + f'No preprocessing function found for dataset path: {dataset_path}. 
Defaulting to writing the dataset as is.', + ) + return lambda x: (x, None) + + def __iter__( + self, + ) -> Iterator[dict[str, Any]]: + """Iteratively yields messages + (optionally) metadata.""" + for sample in self.dataset: + messages, metadata = self.dataset_preprocess_fn(sample) + if metadata is None: + metadata = {} + + # time for some good ol fashioned validation + self._ensure_jsonable(messages, 'messages') + self._ensure_jsonable(metadata, 'metadata') + + yield {'messages': messages, 'metadata': metadata} + + @staticmethod + def _ensure_jsonable(obj: Any, label: str) -> None: + """Raise ValueError if obj cannot be round-tripped through JSON.""" + try: + json.loads(json.dumps(obj)) + except Exception as e: + log.error(f'Error converting {label} to JSON: {e}') + log.error(f'{label}: {obj}') + raise ValueError(f'{label} is not JSON-serializable') from e + + +def main( + dataset_path: str, + compression: str, + local_dir: str, + hashes: list[str], + splits: list[str], + subset: str | None = None, +): + num_written = 0 + for split in splits: + with MDSWriter( + columns={ + 'messages': 'json', + 'metadata': 'json', + }, + out=os.path.join(local_dir, split), + compression=compression, + hashes=hashes, + ) as out: + dataset = UnifiedMessagesDataset( + dataset_path=dataset_path, + split=split, + subset=subset, + ) + log.info('Converting to MDS format') + for sample in dataset: + num_written += 1 + out.write(sample) + log.info(f'Finished writing {num_written} samples') + log.info(f'Dataset has {num_written} samples') + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + '--dataset_path', + type=str, + required=True, + help='Path to the dataset to process', + ) + parser.add_argument('--compression', type=str, default='zstd') + parser.add_argument('--local_dir', type=str, required=True) + parser.add_argument( + '--hashes', + type=str, + nargs='+', + default=['sha1', 'xxh64'], + ) + parser.add_argument('--subset', type=str, default=None) + parser.add_argument('--splits', type=str, nargs='+', default=['train']) + + args = parser.parse_args() + hf_token = os.environ.get('HF_TOKEN') + main( + dataset_path=args.dataset_path, + compression=args.compression, + local_dir=args.local_dir, + hashes=args.hashes, + splits=args.splits, + subset=args.subset, + ) diff --git a/scripts/data/messages_preprocessing_utils.py b/scripts/data/messages_preprocessing_utils.py new file mode 100644 index 00000000..316bcf86 --- /dev/null +++ b/scripts/data/messages_preprocessing_utils.py @@ -0,0 +1,48 @@ +# Copyright 2025 MosaicML ComposeRL authors +# SPDX-License-Identifier: Apache-2.0 + +"""Preprocessing functions for the messages dataset. + +Each preprocessing function should return a tuple of (messages, metadata). +Messages should be a list of dictionaries with a 'role' key and a 'content' key. +Metadata should be a dictionary with any additional metadata. If there is no +metadata, then return an empty dictionary. Both the messages and metadata must +be json serializable. 
+""" + +from typing import Any + +from compose_rl.utils.rlvr_utils import ( + extract_gsm8k_answer, + extract_math_answer, + prepare_gsm8k_prompt, + prepare_math_prompt, +) + + +def prepare_gsm8k_messages( + sample: Any, +) -> tuple[list[dict[str, str]], dict[str, str | None]]: + user_prompt = prepare_gsm8k_prompt(sample) + verified_answer = extract_gsm8k_answer(sample) + messages = [ + { + 'role': 'user', + 'content': user_prompt, + }, + ] + return messages, {'verified_answer': verified_answer} + + +def prepare_math_messages( + sample: Any, +) -> tuple[list[dict[str, str]], dict[str, str | None]]: + user_prompt = prepare_math_prompt(sample) + verified_answer = extract_math_answer(sample) + messages = [ + { + 'role': 'user', + 'content': user_prompt, + }, + ] + return messages, {'verified_answer': verified_answer} diff --git a/scripts/data/unified_tokenize_dataset.py b/scripts/data/unified_tokenize_dataset.py index 2878497b..1c7383b1 100644 --- a/scripts/data/unified_tokenize_dataset.py +++ b/scripts/data/unified_tokenize_dataset.py @@ -42,8 +42,9 @@ def __init__( split: str, tokenizer: PreTrainedTokenizerBase, max_length: int, - dataset_type: Literal['preference', 'single_prompt', - 'verifiable_answers'], + dataset_type: Literal['preference', 'single_prompt', 'single_message', + 'verifiable_answers', 'messages_with_answer', + 'classifier'], subset: str | None = None, token: str | None = None, ): @@ -67,7 +68,9 @@ def __init__( token=token, ) - def __iter__(self) -> Iterator[dict[str, bytes]]: + def __iter__( + self, + ) -> Iterator[dict[str, Any]]: for sample in self.hf_dataset: if self.dataset_type == 'preference': yield self._process_preference_sample(sample) @@ -75,10 +78,28 @@ def __iter__(self) -> Iterator[dict[str, bytes]]: result = self._process_single_prompt_sample(sample) if result is not None: yield result + elif self.dataset_type == 'single_message': + result = self._process_single_prompt_sample( + sample, + return_messages=True, + ) + if result is not None: + # delete the prompt from the results since it's not needed + result.pop('prompt') + yield result elif self.dataset_type == 'verifiable_answers': result = self._process_verifiable_answer_sample(sample) if result is not None: yield result + elif self.dataset_type == 'messages_with_answer': + result = self._process_verifiable_answer_sample( + sample, + return_messages=True, + ) + if result is not None: + # delete the prompt from the results since it's not needed + result.pop('prompt') + yield result elif self.dataset_type == 'classifier': yield self._process_classifier_sample(sample) @@ -105,11 +126,16 @@ def _process_preference_sample(self, sample: Any): 'rejected': np.asarray(curr_rejected).tobytes(), } - def _process_single_prompt_sample(self, sample: Any): + def _process_single_prompt_sample( + self, + sample: Any, + return_messages: bool = False, + ): """Process a prompt sample. 
Args: sample (Any): a sample from the dataset + return_messages (bool): whether to include chat-ml messages in the output """ prompt = sample['prompt'] messages = [{ @@ -118,6 +144,7 @@ def _process_single_prompt_sample(self, sample: Any): 'content': f'Can you summarize the following content in 50 words or less: {prompt}', }] + encoded_prompt = self.tokenizer.apply_chat_template( messages, tokenize=True, @@ -127,7 +154,10 @@ def _process_single_prompt_sample(self, sample: Any): if len(encoded_prompt) > self.max_length: return None - return {'prompt': np.asarray(encoded_prompt).tobytes()} + output = {'prompt': np.asarray(encoded_prompt).tobytes()} + if return_messages: + output['messages'] = messages # type: ignore + return output def _process_classifier_sample(self, sample: Any): """A dummy process a classifier sample. @@ -169,13 +199,18 @@ def _get_processing_fn_from_dataset(self): return prompt_fn, answer_fn - def _process_verifiable_answer_sample(self, sample: Any): + def _process_verifiable_answer_sample( + self, + sample: Any, + return_messages: bool = False, + ): """Process a prompt sample and extract the answer. This function is currently hard-coded for the GSM8K dataset. Args: sample (Any): a sample from the dataset + return_messages (bool): whether to include chat-ml messages in the output """ prompt_fn, answer_fn = self._get_processing_fn_from_dataset() @@ -207,10 +242,13 @@ def _process_verifiable_answer_sample(self, sample: Any): ) return None - return { + output = { 'prompt': np.asarray(encoded_prompt).tobytes(), 'verified_answer': verified_answer, } + if return_messages: + output['messages'] = messages + return output def _check_for_encoding(self, sample: str) -> bool: """Check if a sample is encodable by streaming. @@ -247,7 +285,9 @@ def main( hashes: list[str], splits: list[str], tokenizer_name: str, - dataset_type: Literal['preference', 'single_prompt', 'verifiable_answers'], + dataset_type: Literal['preference', 'single_prompt', 'single_message', + 'verifiable_answers', 'messages_with_answer', + 'classifier'], max_length: int = 2048, subset: str | None = None, token: str | None = None, @@ -260,6 +300,9 @@ def main( 'single_prompt': { 'prompt': 'bytes', }, + 'single_message': { + 'messages': 'json', + }, 'verifiable_answers': { 'prompt': 'bytes', 'verified_answer': 'str', @@ -268,6 +311,10 @@ def main( 'input': 'bytes', 'label': 'bytes', }, + 'messages_with_answer': { + 'messages': 'json', + 'verified_answer': 'str', + }, }[dataset_type] tokenizer = AutoTokenizer.from_pretrained( @@ -338,6 +385,8 @@ def main( 'single_prompt', 'classifier', 'verifiable_answers', + 'messages_with_answer', + 'single_message', ], required=True, help='Type of dataset to process', diff --git a/scripts/launch_composer_ray.py b/scripts/launch_composer_ray.py index 7ef0b657..b093c220 100644 --- a/scripts/launch_composer_ray.py +++ b/scripts/launch_composer_ray.py @@ -310,8 +310,10 @@ def reassign_train_and_inference_ranks( os.environ['MASTER_PORT'] = master_port # Adding a ray sync actor on global rank 0 to make it work - sync_actor = SyncActor.options(name='sync_actor', - namespace='default').remote() + sync_actor = SyncActor.options( # type: ignore + name='sync_actor', + namespace='default', + ).remote() log.info('after start ray nodes') diff --git a/tests/common/__init__.py b/tests/common/__init__.py index 1230c821..a815cb38 100644 --- a/tests/common/__init__.py +++ b/tests/common/__init__.py @@ -5,6 +5,7 @@ FineGrainedPreference, PairwisePreference, PromptDataset, + VerifiableMessagesDataset, 
VerifiablePromptDataset, ) from tests.common.markers import device, world_size @@ -14,6 +15,7 @@ 'FineGrainedPreference', 'PromptDataset', 'VerifiablePromptDataset', + 'VerifiableMessagesDataset', 'device', 'world_size', ] diff --git a/tests/common/datasets.py b/tests/common/datasets.py index 98da559e..09a49804 100644 --- a/tests/common/datasets.py +++ b/tests/common/datasets.py @@ -89,3 +89,30 @@ def __getitem__(self, index: int): 'prompt_len': torch.Tensor([self.prompt_len]).to(torch.int64), 'verified_answer': '1', } + + +class VerifiableMessagesDataset(Dataset): + + def __init__(self, size: int = 8, prompt_len: int = 5): + self.size = size + self.prompt_len = prompt_len + + def __len__(self): + return self.size + + def __getitem__(self, index: int): + messages = [{ + 'role': 'user', + 'content': 'F' * self.prompt_len, + }] # bit of a hack, but it works + mock_prompt = torch.ones((len(messages[0]['content']),)).int() + return { + 'messages': + messages, + 'prompt': + mock_prompt, + 'prompt_len': + torch.Tensor([len(messages[0]['content'])]).to(torch.int64), + 'verified_answer': + 'Paris', + } diff --git a/tests/test_online.py b/tests/test_online.py index 27aba897..45f81461 100644 --- a/tests/test_online.py +++ b/tests/test_online.py @@ -24,8 +24,16 @@ ) from compose_rl.algorithms.online.model_methods import OnPolicyEnum from compose_rl.algorithms.online.modeling_hf import ComposerHFPolicy -from compose_rl.data import prompt_dataset_collate_fn -from tests.common import PromptDataset, VerifiablePromptDataset, world_size +from compose_rl.data import ( + messages_dataset_collate_fn, + prompt_dataset_collate_fn, +) +from tests.common import ( + PromptDataset, + VerifiableMessagesDataset, + VerifiablePromptDataset, + world_size, +) def test_hf_ppo_model_construction( @@ -77,19 +85,31 @@ def test_hf_ppo_policy_construction( @pytest.mark.parametrize('model_type', ['mpt', 'hf']) -@pytest.mark.parametrize('dataset_type', ['prompt', 'verifiable_prompt']) +@pytest.mark.parametrize( + 'dataset_type', + ['prompt', 'verifiable_prompt', 'verifiable_messages'], +) def test_model_forward( tiny_gpt2_tokenizer: PreTrainedTokenizerBase, model_type: str, dataset_type: str, ): prompt_len = 10 - data_class = PromptDataset if dataset_type == 'prompt' else VerifiablePromptDataset - dataset = data_class(prompt_len=prompt_len) + if dataset_type == 'prompt': + dataset = PromptDataset(prompt_len=prompt_len) + dataset_collator = prompt_dataset_collate_fn + elif dataset_type == 'verifiable_prompt': + dataset = VerifiablePromptDataset(prompt_len=prompt_len) + dataset_collator = prompt_dataset_collate_fn + elif dataset_type == 'verifiable_messages': + dataset = VerifiableMessagesDataset(prompt_len=prompt_len) + dataset_collator = messages_dataset_collate_fn + else: + raise ValueError(f'Unknown dataset type: {dataset_type}') dataloader = DataLoader( dataset, collate_fn=partial( - prompt_dataset_collate_fn, + dataset_collator, tiny_gpt2_tokenizer, 32, ), diff --git a/yamls/local_grpo.yaml b/yamls/local_grpo.yaml index 504128e5..9ad1ea02 100644 --- a/yamls/local_grpo.yaml +++ b/yamls/local_grpo.yaml @@ -114,7 +114,7 @@ tokenizer: truncation: true train_loader: - name: prompt + name: messages dataset: # local: # local path split: train diff --git a/yamls/local_ppo.yaml b/yamls/local_ppo.yaml index d89f6701..4ee648ad 100644 --- a/yamls/local_ppo.yaml +++ b/yamls/local_ppo.yaml @@ -112,7 +112,7 @@ tokenizer: truncation: true train_loader: - name: prompt + name: messages dataset: # local: # local path split: train
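Taken together, the new dataset, collator, and registered `messages` dataloader can also be exercised directly, outside of LLM Foundry. Below is a minimal sketch (not part of this diff); the local path, batch size, and sequence lengths are assumptions based on the README commands above.

```python
# Minimal sketch of wiring MessagesStreamingDataset and messages_dataset_collate_fn
# together by hand. Paths and hyperparameters are assumptions, not part of this diff.
from functools import partial

from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from compose_rl.data import messages_dataset_collate_fn
from compose_rl.data.messages_data import MessagesStreamingDataset

max_seq_len = 2048
tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-3.1-8B-Instruct')
if tokenizer.pad_token is None:
    # The collator requires a pad token; reuse EOS if the tokenizer does not define one.
    tokenizer.pad_token = tokenizer.eos_token

# Stream the MDS shards written by scripts/data/messages_dataset_to_mds.py.
dataset = MessagesStreamingDataset(
    max_gen_len=512,
    max_seq_len=max_seq_len,
    tokenizer=tokenizer,
    local='compose-rl/scripts/verifiable_data',  # assumed output dir from the README
    split='train',
    batch_size=4,
    shuffle=False,
)

dataloader = DataLoader(
    dataset,
    batch_size=4,
    collate_fn=partial(messages_dataset_collate_fn, tokenizer, max_seq_len),
)

batch = next(iter(dataloader))
# The batch keeps the raw chat messages alongside the tokenized prompts, so online RL
# code can re-apply the chat template or hand the messages to a rollout engine.
print(batch['prompt'].shape, batch['prompt_len'], batch['messages'][0])
```

In the YAML configs above, the same wiring happens through the registered `messages` entry point, with the tokenizer injected by the dataloader builder and `max_gen_len`/`max_seq_len` supplied through the dataset config.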