diff --git a/internlm/core/trainer.py b/internlm/core/trainer.py index 121501579..b76601408 100644 --- a/internlm/core/trainer.py +++ b/internlm/core/trainer.py @@ -57,7 +57,7 @@ def __init__(self, config, batch_sampler) -> None: self.lr = config.adam.lr # smapler state - if batch_sampler: + if batch_sampler is not None: self.init_batch_sampler(batch_sampler) # tgs statistic diff --git a/internlm/data/build_dataloader.py b/internlm/data/build_dataloader.py index e7f581dca..20e9c24d9 100644 --- a/internlm/data/build_dataloader.py +++ b/internlm/data/build_dataloader.py @@ -1,4 +1,5 @@ # Copyright (c) InternLM. All rights reserved. +import subprocess from functools import partial import torch.distributed as dist @@ -6,6 +7,11 @@ from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc +from internlm.data.megatron.batch_sampler import MegatronBatchSampler +from internlm.data.megatron.collaters import megatron_collate_fn +from internlm.data.megatron.dataset import build_megatron_dataset +from internlm.data.mocked.batch_sampler import MockedSequentialBatchSampler +from internlm.data.mocked.dataset import MockedDataset from internlm.data.streaming.batch_sampler import StreamingStaticBatchSampler from internlm.data.streaming.collaters import streaming_packed_collate_fn from internlm.data.streaming.dataset import ( @@ -139,6 +145,62 @@ def get_streaming_train_loader_items(data_cfg): return train_ds, train_sampler, streaming_packed_collate_fn +def get_megatron_train_loader_items(data_cfg): + try: + from internlm.data.megatron import helpers # noqa # pylint: disable=W0611 + except ImportError: + if gpc.is_rank_for_log(): + subprocess.run( # noqa # pylint: disable=W1510 + [ + "g++", + "-O3", + "-shared", + "-std=c++11", + "-fPIC", + "-fdiagnostics-color", + "-I$(python3-config --includes)", + "-I$(python3 -m pybind11 --includes)", + "internlm/data/megatron/helpers.cpp", + "-o", + "internlm/data/megatron/helpers.so", + ] + ) + train_ds = build_megatron_dataset( + data_prefix=data_cfg.train_folder, + data_impl=data_cfg.get("data_impl", "infer"), + splits_string="1.0, 0.0, 0.0", + train_valid_test_num_samples=[9600000, 0, 0], + seq_len=data_cfg.seq_len, + seed=data_cfg.get("seed", 1024), + skip_warmup=True, + ) + + train_sampler = MegatronBatchSampler( + total_samples=len(train_ds), + consumed_samples=0, + batch_size=data_cfg.micro_num * data_cfg.micro_bsz, + drop_last=True, + ) + + train_collate_fn = partial( + megatron_collate_fn, micro_num=data_cfg.micro_num, micro_bsz=data_cfg.micro_bsz, seq_len=data_cfg.seq_len + ) + + return train_ds, train_sampler, train_collate_fn + + +def get_mock_train_loader_items(data_cfg): + train_ds = MockedDataset( + data_dir=data_cfg.train_folder, # defined the path of mocked data + micro_bsz=data_cfg.micro_bsz, + seq_len=data_cfg.seq_len, + mocked_steps=data_cfg.mocked_steps, # defined the steps of mocked data + ) + train_sampler = MockedSequentialBatchSampler(train_ds, data_cfg.micro_num) + train_collate_fn = partial(packed_collate_fn, packed_length=data_cfg.seq_len * data_cfg.micro_bsz) + return train_ds, train_sampler, train_collate_fn + + def build_train_loader_with_data_type(): """ Build and return the training data loader based on data type. 
@@ -155,6 +217,14 @@ def build_train_loader_with_data_type():
         train_ds, train_sampler, train_collate_fn = get_streaming_train_loader_items(data_cfg)
         # TODO: support more dataset_types
         dataset_types = ["en"]
+    elif data_cfg.type == DataType.megatron.name:
+        train_ds, train_sampler, train_collate_fn = get_megatron_train_loader_items(data_cfg)
+        # TODO: support more dataset_types
+        dataset_types = ["en"]
+    elif data_cfg.type == DataType.mocked.name:
+        train_ds, train_sampler, train_collate_fn = get_mock_train_loader_items(data_cfg)
+        # TODO: support more dataset_types
+        dataset_types = ["en"]
     else:
         raise ValueError(f"dataset type {data_cfg.type} is not supported")
 
@@ -176,8 +246,13 @@ def build_valid_loader_with_data_type():
 
     data_cfg = gpc.config.data
 
-    # TODO: support streaming dataset for validation
-    if data_cfg.type in [DataType.tokenized.name, DataType.streaming.name]:
+    # TODO: For validation, currently we only support the dummy dataset for the streaming/megatron/mocked DataTypes.
+    if data_cfg.type in [
+        DataType.tokenized.name,
+        DataType.streaming.name,
+        DataType.megatron.name,
+        DataType.mocked.name,
+    ]:
         valid_ds, valid_collate_fn = get_tokenized_valid_loader_items(data_cfg)
     else:
         raise ValueError(f"dataset type {data_cfg.type} is not supported")
diff --git a/internlm/data/megatron/__init__.py b/internlm/data/megatron/__init__.py
new file mode 100644
index 000000000..5e4475969
--- /dev/null
+++ b/internlm/data/megatron/__init__.py
@@ -0,0 +1,9 @@
+from .batch_sampler import MegatronBatchSampler
+from .collaters import megatron_collate_fn
+from .dataset import build_megatron_dataset
+
+__all__ = [
+    "MegatronBatchSampler",
+    "build_megatron_dataset",
+    "megatron_collate_fn",
+]
diff --git a/internlm/data/megatron/batch_sampler.py b/internlm/data/megatron/batch_sampler.py
new file mode 100644
index 000000000..049cfcf7e
--- /dev/null
+++ b/internlm/data/megatron/batch_sampler.py
@@ -0,0 +1,62 @@
+import copy
+import math
+
+from internlm.core.context import ParallelMode
+from internlm.core.context import global_context as gpc
+
+
+class MegatronBatchSampler:
+    """
+    MegatronBatchSampler
+    """
+
+    def __init__(self, total_samples, consumed_samples, batch_size, drop_last=True):
+        # Keep a copy of input params for later use.
+        self.total_samples = total_samples
+        self.consumed_samples = consumed_samples
+        self.batch_size = batch_size
+        self.drop_last = drop_last
+
+        self.dp_rank = gpc.get_local_rank(ParallelMode.DATA)
+        self.dp_size = gpc.get_world_size(ParallelMode.DATA)
+
+        # Sanity checks.
+        assert self.total_samples > 0, "no sample to consume: {}".format(self.total_samples)
+        assert self.consumed_samples < self.total_samples, "no samples left to consume: {}, {}".format(
+            self.consumed_samples, self.total_samples
+        )
+        assert self.batch_size > 0
+        assert self.dp_size > 0
+        assert self.dp_rank < self.dp_size, "dp_rank should be smaller than dp_size: {}, " "{}".format(
+            self.dp_rank, self.dp_size
+        )
+
+    def __len__(self):
+        if self.drop_last and self.total_samples % self.dp_size != 0:
+            return math.ceil((self.total_samples - self.dp_size) / self.dp_size)
+        else:
+            return math.ceil(self.total_samples / self.dp_size)
+
+    def get_start_end_idx(self):
+        start_idx = self.dp_rank * self.batch_size
+        end_idx = start_idx + self.batch_size
+        return start_idx, end_idx
+
+    def __iter__(self):
+        batch = []
+        # The last batch will be dropped unless drop_last is set to False
+        for idx in range(self.consumed_samples, self.total_samples):
+            batch.append(idx)
+            if len(batch) == self.batch_size * self.dp_size:
+                start_idx, end_idx = self.get_start_end_idx()
+                yield batch[start_idx:end_idx]
+                batch = []
+
+        # Check the last partial batch and yield it if drop_last is not set
+        if len(batch) > 0 and not self.drop_last:
+            start_idx, end_idx = self.get_start_end_idx()
+            yield batch[start_idx:end_idx]
+
+    # TODO: implement a copy method that is compatible with the InternEvo TrainState
+    def copy(self):
+        return copy.deepcopy(self)
diff --git a/internlm/data/megatron/collaters.py b/internlm/data/megatron/collaters.py
new file mode 100644
index 000000000..252bc289e
--- /dev/null
+++ b/internlm/data/megatron/collaters.py
@@ -0,0 +1,49 @@
+import torch
+
+
+def megatron_collate_fn(batch, micro_num, micro_bsz, seq_len):
+
+    input_ids_result = [[] for _ in range(micro_num)]
+    labels_result = [[] for _ in range(micro_num)]
+    cu_seqlens = []
+    cu_seqlens_list = []
+    indexes = []
+    indexes_list = []
+
+    for i, item in enumerate(batch):
+        assert i < micro_num * micro_bsz
+        seq_len_list = item["text"]
+        assert len(seq_len_list) == seq_len + 1
+
+        micro_bsz_index = i % micro_bsz
+        micro_num_index = i // micro_bsz
+
+        input_ids_result[micro_num_index].append(seq_len_list[:-1])
+        labels_result[micro_num_index].append(seq_len_list[1:])
+
+        cu_seqlens.append(seq_len * micro_bsz_index)
+        indexes = indexes + list(range(seq_len))
+
+        if micro_bsz_index == micro_bsz - 1:
+            input_ids_result[micro_num_index] = torch.cat(
+                [torch.from_numpy(arr).long() for arr in input_ids_result[micro_num_index]], dim=0
+            )
+            labels_result[micro_num_index] = torch.cat(
+                [torch.from_numpy(arr).long() for arr in labels_result[micro_num_index]], dim=0
+            )
+            cu_seqlens.append(seq_len * micro_bsz)
+            cu_seqlens_list.append(torch.IntTensor(cu_seqlens))
+            cu_seqlens = []
+            indexes_list.append(torch.IntTensor(indexes))
+            indexes = []
+
+    input_ids = torch.stack(input_ids_result)
+    labels = torch.stack(labels_result)
+    indexes = torch.stack(indexes_list)
+
+    return {
+        "input_ids": input_ids,
+        "cu_seqlens": cu_seqlens_list,
+        "indexes": indexes,
+        "type_ids": torch.zeros(micro_num, micro_bsz * seq_len, dtype=torch.int64),
+    }, labels
diff --git a/internlm/data/megatron/dataset.py b/internlm/data/megatron/dataset.py
new file mode 100644
index 000000000..a7e2d04ca
--- /dev/null
+++ b/internlm/data/megatron/dataset.py
@@ -0,0 +1,843 @@
+import hashlib
+import os
+import struct
+import time
+from functools import lru_cache
+from itertools import accumulate
+
+import numpy as np
+import torch
+
+from internlm.core.context import ParallelMode
+from internlm.core.context import
global_context as gpc + +dtypes = { + 1: np.uint8, + 2: np.int8, + 3: np.int16, + 4: np.int32, + 5: np.int64, + 6: np.float64, + 7: np.float32, + 8: np.uint16, +} + + +def print_rank_0(message): + """If distributed is initialized, print only on rank 0.""" + if gpc.is_rank_for_log(): + print(message, flush=True) + + +def code(dtype): + for k, v in dtypes.items(): + if v == dtype: + return k + raise ValueError(dtype) + + +def index_file_path(prefix_path): + return prefix_path + ".idx" + + +def data_file_path(prefix_path): + return prefix_path + ".bin" + + +def read_longs(f, n): + a = np.empty(n, dtype=np.int64) + f.readinto(a) + return a + + +def _warmup_mmap_file(path): + with open(path, "rb") as stream: + while stream.read(100 * 1024 * 1024): + pass + + +def _build_shuffle_idx(num_samples, total_size, np_rng): + """Build the range [0, size) and shuffle.""" + dtype_ = np.uint32 + if total_size >= (np.iinfo(np.uint32).max - 1): + dtype_ = np.int64 + + shuffle_idx_first = np.arange(start=0, stop=num_samples, step=1, dtype=dtype_) + np_rng.shuffle(shuffle_idx_first) + if num_samples == total_size: + return shuffle_idx_first + + shuffle_idx_last = np.arange(start=num_samples, stop=total_size, step=1, dtype=dtype_) + np_rng.shuffle(shuffle_idx_last) + + return np.concatenate((shuffle_idx_first, shuffle_idx_last)) + + +def _build_doc_idx(documents, num_epochs, np_rng, separate_last_epoch): + """Build an array with length = number-of-epochs * number-of-dcuments. + Each index is mapped to a corresponding document.""" + if not separate_last_epoch or num_epochs == 1: + doc_idx = np.mgrid[0:num_epochs, 0 : len(documents)][1] + doc_idx[:] = documents + doc_idx = doc_idx.reshape(-1) + doc_idx = doc_idx.astype(np.int32) + np_rng.shuffle(doc_idx) + return doc_idx + + doc_idx_first = _build_doc_idx(documents, num_epochs - 1, np_rng, False) + doc_idx_last = _build_doc_idx(documents, 1, np_rng, False) + return np.concatenate((doc_idx_first, doc_idx_last)) + + +def _num_tokens(documents, sizes): + """Total number of tokens in the dataset.""" + return np.sum(sizes[documents]) + + +def _num_epochs(tokens_per_epoch, seq_length, num_samples): + """Based on number of samples and sequence lenght, calculate how many + epochs will be needed.""" + num_epochs = 0 + total_tokens = 0 + while True: + num_epochs += 1 + total_tokens += tokens_per_epoch + # -1 is because we need to retrieve seq_length + 1 token each time + # but the last token will overlap with the first token of the next + # sample except for the last sample. + if ((total_tokens - 1) // seq_length) >= num_samples: + return num_epochs + + +def _build_index_mappings( + name, data_prefix, documents, sizes, splits_string, num_samples, seq_length, seed, *, data_cache_path +): + """Build doc-idx, sample-idx, and shuffle-idx. + doc-idx: is an array (ordered) of documents to be used in training. + sample-idx: is the start document index and document offset for each + training sample. + shuffle-idx: maps the sample index into a random index into sample-idx. + """ + # Number of tokens in each epoch and number of required epochs. + tokens_per_epoch = _num_tokens(documents, sizes) + num_epochs = _num_epochs(tokens_per_epoch, seq_length, num_samples) + + # rng state + np_rng = np.random.RandomState(seed=seed) + + # Filename of the index mappings. 
+ desc = "GPT Dataset\n\n" + desc += f"Data prefix {data_prefix}\n" + desc += f"Dataset name {name}\n" + desc += f"Number of samples {num_samples}\n" + desc += f"Sequence length {seq_length}\n" + desc += f"Random seed {seed}\n" + desc += f"Split {splits_string}\n" + desc_hash = hashlib.md5(desc.encode("utf-8")).hexdigest() + desc_filename = desc_hash + ".dsc" + doc_idx_filename = desc_hash + "_doc_idx.npy" + sample_idx_filename = desc_hash + "_sample_idx.npy" + shuffle_idx_filename = desc_hash + "_shuffle_idx.npy" + + # Look for cache in main data dir first to avoid unnecessary + # duplication, then look in data-cache-path if specified, + # If nothing is found, use the last path looked in + build_indices = True + prefixes = [os.path.join(os.path.dirname(data_prefix), "index-cache")] + if data_cache_path is not None: + prefixes.append(data_cache_path) + for prefix in prefixes: + idx_path = { + "desc": os.path.join(prefix, desc_filename), + "doc": os.path.join(prefix, doc_idx_filename), + "sample": os.path.join(prefix, sample_idx_filename), + "shuffle": os.path.join(prefix, shuffle_idx_filename), + } + for f in idx_path.values(): + if not os.path.isfile(f): + break + else: + # Found our files! + build_indices = False + break + data_cache_dir = os.path.dirname(idx_path["desc"]) + data_cache_success = True + + # Build the indexed mapping if not exist. + if build_indices and gpc.is_rank_for_log(): + print_rank_0(" > WARNING: could not find index map files, building " "the indices on rank 0 ...") + + # For the last epoch, decide whether include the entire epoch + # in the global shuffle or not. + + # If we need only one epoch, then separating last epoch does + # not mean anything. + if num_epochs == 1: + separate_last_epoch = False + print(" > only one epoch required, setting " "separate_last_epoch to False", flush=True) + + else: + # Get the number of samples for the last epoch + num_samples_from_epochs_minus_one = ((num_epochs - 1) * tokens_per_epoch - 1) // seq_length + last_epoch_num_samples = num_samples - num_samples_from_epochs_minus_one + assert last_epoch_num_samples >= 0, "last epoch number of samples should be non-negative." + num_samples_per_epoch = (tokens_per_epoch - 1) // seq_length + assert last_epoch_num_samples <= ( + num_samples_per_epoch + 1 + ), "last epoch number of samples exceeded max value." + # If we have less than 80% of the samples for the last epoch, + # seperate out the epoch and treat it differently. + # Note: the 80% number is just based on common sense and can + # be adjusted if needed. + separate_last_epoch = last_epoch_num_samples < int(0.80 * num_samples_per_epoch) + if separate_last_epoch: + string = ( + " > last epoch number of samples ({}) is smaller " + "than 80% of number of samples per epoch ({}), " + "setting separate_last_epoch to True" + ) + else: + string = ( + " > last epoch number of samples ({}) is larger " + "than 80% of number of samples per epoch ({}), " + "setting separate_last_epoch to False" + ) + print(string.format(last_epoch_num_samples, num_samples_per_epoch), flush=True) + + try: + os.makedirs(data_cache_dir, exist_ok=True) + + # description + with open(idx_path["desc"], "wt") as fd: + fd.write(desc) + + # doc-idx. + start_time = time.time() + doc_idx = _build_doc_idx(documents, num_epochs, np_rng, separate_last_epoch) + np.save(idx_path["doc"], doc_idx, allow_pickle=True) + print_rank_0( + " > elasped time to build and save doc-idx mapping " "(seconds): {:4f}".format(time.time() - start_time) + ) + # sample-idx. 
+ start_time = time.time() + # Use C++ implementation for speed. + # First compile and then import. + from internlm.data.megatron import helpers + + assert doc_idx.dtype == np.int32 + assert sizes.dtype == np.int32 + sample_idx = helpers.build_sample_idx(sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch) + np.save(idx_path["sample"], sample_idx, allow_pickle=True) + print_rank_0( + " > elasped time to build and save sample-idx mapping " + "(seconds): {:4f}".format(time.time() - start_time) + ) + # shuffle-idx. + start_time = time.time() + # -1 is due to data structure used to retieve the index: + # sample i --> [sample_idx[i], sample_idx[i+1]) + if separate_last_epoch: + num_samples_ = num_samples_from_epochs_minus_one + else: + num_samples_ = sample_idx.shape[0] - 1 + shuffle_idx = _build_shuffle_idx(num_samples_, sample_idx.shape[0] - 1, np_rng) + np.save(idx_path["shuffle"], shuffle_idx, allow_pickle=True) + print_rank_0( + " > elasped time to build and save shuffle-idx mapping" + " (seconds): {:4f}".format(time.time() - start_time) + ) + except OSError: + print(f"There was an error trying to create the data cache directory ({data_cache_dir})") + print('or a file in it. This defaults to a directory "index-cache" within the directory') + print("the data files are in and can be set with the --data-cache-path argument. Please") + print("ensure you have write access to this directory or specify one that you do have") + print("write access to.") + data_cache_success = False + + counts = torch.cuda.LongTensor([data_cache_success]) + + if gpc.is_using_parallel_mode(ParallelMode.DATA): + torch.distributed.all_reduce(counts, group=gpc.get_group(ParallelMode.DATA)) + + if gpc.is_using_parallel_mode(ParallelMode.PIPELINE): + torch.distributed.all_reduce(counts, group=gpc.get_group(ParallelMode.PIPELINE)) + + assert counts[0].item() == ( + gpc.get_world_size(ParallelMode.GLOBAL) // gpc.get_world_size(ParallelMode.TENSOR) + ), "Data index creation unsuccessful!" + + # Load mappings. + start_time = time.time() + print_rank_0(f" > loading doc-idx mapping from {idx_path['doc']}") + doc_idx = np.load(idx_path["doc"], allow_pickle=True, mmap_mode="r") + + print_rank_0(f" > loading sample-idx mapping from {idx_path['sample']}") + sample_idx = np.load(idx_path["sample"], allow_pickle=True, mmap_mode="r") + + print_rank_0(f" > loading shuffle-idx mapping from {idx_path['shuffle']}") + shuffle_idx = np.load(idx_path["shuffle"], allow_pickle=True, mmap_mode="r") + + print_rank_0(" loaded indexed file in {:3.3f} seconds".format(time.time() - start_time)) + print_rank_0(" total number of samples: {}".format(sample_idx.shape[0])) + print_rank_0(" total number of epochs: {}".format(num_epochs)) + + return doc_idx, sample_idx, shuffle_idx, desc, desc_hash + + +class GPTDataset(torch.utils.data.Dataset): + """ + GPTDataset + """ + + def __init__( + self, + name, + data_prefix, + documents, + indexed_dataset, + splits_string, + num_samples, + seq_length, + seed, + return_doc_ids=False, + *, + data_cache_path=None, + ): + + self.name = name + self.indexed_dataset = indexed_dataset + self.return_doc_ids = return_doc_ids + + # Checks + assert np.min(documents) >= 0 + assert np.max(documents) < indexed_dataset.sizes.shape[0] + + # Build index mappings. 
+ self.doc_idx, self.sample_idx, self.shuffle_idx, self.desc, self.desc_hash = _build_index_mappings( + self.name, + data_prefix, + documents, + self.indexed_dataset.sizes, + splits_string, + num_samples, + seq_length, + seed, + data_cache_path=data_cache_path, + ) + + def __len__(self): + # -1 is due to data structure used to retieve the index: + # sample i --> [sample_idx[i], sample_idx[i+1]) + return self.sample_idx.shape[0] - 1 + + def __getitem__(self, idx): + # Get the shuffled index. + idx = self.shuffle_idx[idx] + # Start and end documents and offsets. + doc_index_f = self.sample_idx[idx][0] + doc_index_l = self.sample_idx[idx + 1][0] + offset_f = self.sample_idx[idx][1] + offset_l = self.sample_idx[idx + 1][1] + # If we are within the same document, just extract the chunk. + doc_ids = [] + if doc_index_f == doc_index_l: + doc_ids.append(self.doc_idx[doc_index_f]) + sample = self.indexed_dataset.get( + self.doc_idx[doc_index_f], offset=offset_f, length=offset_l - offset_f + 1 + ) + else: + # Otherwise, get the rest of the initial document. + doc_ids.append(self.doc_idx[doc_index_f]) + sample_list = [self.indexed_dataset.get(self.doc_idx[doc_index_f], offset=offset_f)] + # Loop over all in between documents and add the entire document. + for i in range(doc_index_f + 1, doc_index_l): + doc_ids.append(self.doc_idx[i]) + sample_list.append(self.indexed_dataset.get(self.doc_idx[i])) + # And finally add the relevant portion of last document. + doc_ids.append(self.doc_idx[doc_index_l]) + sample_list.append(self.indexed_dataset.get(self.doc_idx[doc_index_l], length=offset_l + 1)) + sample = np.concatenate(sample_list) + + if self.return_doc_ids: # for retro preprocessing + return {"text": np.array(sample, dtype=np.int64), "doc_ids": np.array(doc_ids, dtype=np.int64)} + else: + return {"text": np.array(sample, dtype=np.int64)} + + +def infer_dataset_impl(path): + if IndexedDataset.exists(path): + with open(index_file_path(path), "rb") as f: + magic = f.read(8) + if magic == IndexedDataset._HDR_MAGIC: + return "cached" + elif magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]: + return "mmap" + else: + return None + else: + print(f"Dataset does not exist: {path}") + print("Path should be a basename that both .idx and .bin can be appended to get full filenames.") + return None + + +def make_dataset(path, impl, skip_warmup=False): + if not IndexedDataset.exists(path): + print(f"Dataset does not exist: {path}") + print("Path should be a basename that both .idx and .bin can be appended to get full filenames.") + return None + if impl == "infer": + impl = infer_dataset_impl(path) + if impl == "lazy" and IndexedDataset.exists(path): + return IndexedDataset(path) + elif impl == "cached" and IndexedDataset.exists(path): + return IndexedCachedDataset(path) + elif impl == "mmap" and MMapIndexedDataset.exists(path): + return MMapIndexedDataset(path, skip_warmup) + print(f"Unknown dataset implementation: {impl}") + return None + + +class IndexedDataset(torch.utils.data.Dataset): + """Loader for IndexedDataset""" + + _HDR_MAGIC = b"TNTIDX\x00\x00" + + def __init__(self, path): + super().__init__() + self.path = path + self.data_file = None + self.read_index(path) + + def read_index(self, path): + with open(index_file_path(path), "rb") as f: + magic = f.read(8) + assert magic == self._HDR_MAGIC, ( + "Index file doesn't match expected format. " "Make sure that --dataset-impl is configured properly." 
+ ) + version = f.read(8) + assert struct.unpack("= self._len: + raise IndexError("index out of range") + + def __del__(self): + if self.data_file: + self.data_file.close() + + # @lru_cache(maxsize=8) + def __getitem__(self, idx): + if not self.data_file: + self.read_data(self.path) + if isinstance(idx, int): + i = idx + self.check_index(i) + tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]] + a = np.empty(tensor_size, dtype=self.dtype) + self.data_file.seek(self.data_offsets[i] * self.element_size) + self.data_file.readinto(a) + return a + elif isinstance(idx, slice): + start, stop, step = idx.indices(len(self)) + if step != 1: + raise ValueError("Slices into indexed_dataset must be contiguous") + sizes = self.sizes[self.dim_offsets[start] : self.dim_offsets[stop]] + size = sum(sizes) + a = np.empty(size, dtype=self.dtype) + self.data_file.seek(self.data_offsets[start] * self.element_size) + self.data_file.readinto(a) + offsets = list(accumulate(sizes)) + sents = np.split(a, offsets[:-1]) + return sents + + def __len__(self): + return self._len + + def num_tokens(self, index): + return self.sizes[index] + + def size(self, index): + return self.sizes[index] + + @staticmethod + def exists(path): + return os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path)) + + @property + def supports_prefetch(self): + return False # avoid prefetching to save memory + + +class IndexedCachedDataset(IndexedDataset): + """ + IndexedCachedDataset + """ + + def __init__(self, path): + super().__init__(path) + self.cache = None + self.cache_index = {} + + @property + def supports_prefetch(self): + return True + + def prefetch(self, indices): + if all(i in self.cache_index for i in indices): + return + if not self.data_file: + self.read_data(self.path) + indices = sorted(set(indices)) + total_size = 0 + for i in indices: + total_size += self.data_offsets[i + 1] - self.data_offsets[i] + self.cache = np.empty(total_size, dtype=self.dtype) + ptx = 0 + self.cache_index.clear() + for i in indices: + self.cache_index[i] = ptx + size = self.data_offsets[i + 1] - self.data_offsets[i] + a = self.cache[ptx : ptx + size] + self.data_file.seek(self.data_offsets[i] * self.element_size) + self.data_file.readinto(a) + ptx += size + if self.data_file: + # close and delete data file after prefetch so we can pickle + self.data_file.close() + self.data_file = None + + # @lru_cache(maxsize=8) + def __getitem__(self, idx): + if isinstance(idx, int): + i = idx + self.check_index(i) + tensor_size = self.sizes[self.dim_offsets[i] : self.dim_offsets[i + 1]] + a = np.empty(tensor_size, dtype=self.dtype) + ptx = self.cache_index[i] + np.copyto(a, self.cache[ptx : ptx + a.size]) + return a + elif isinstance(idx, slice): + # Hack just to make this work, can optimizer later if necessary + sents = [] + for i in range(*idx.indices(len(self))): + sents.append(self[i]) + return sents + + +class MMapIndexedDataset(torch.utils.data.Dataset): + """ + MMapIndexedDataset + """ + + class Index(object): + """ + Index + """ + + _HDR_MAGIC = b"MMIDIDX\x00\x00" + + @classmethod + def writer(cls, path, dtype): + class _Writer(object): + """ + _Writer + """ + + def __enter__(self): + self._file = open(path, "wb") + + self._file.write(cls._HDR_MAGIC) + self._file.write(struct.pack(" building dataset index ...") + + start_time = time.time() + indexed_dataset = make_dataset(data_prefix, data_impl, skip_warmup) + print_rank_0(" > finished creating indexed dataset in {:4f} " "seconds".format(time.time() - 
start_time)) + print_rank_0(" number of documents: {}".format(indexed_dataset.sizes.shape[0])) + + return indexed_dataset + + +def get_train_valid_test_split_(splits_string, size): + """Get dataset splits from comma or '/' separated string list.""" + + splits = [] + if splits_string.find(",") != -1: + splits = [float(s) for s in splits_string.split(",")] + elif splits_string.find("/") != -1: + splits = [float(s) for s in splits_string.split("/")] + else: + splits = [float(splits_string)] + while len(splits) < 3: + splits.append(0.0) + splits = splits[:3] + splits_sum = sum(splits) + assert splits_sum > 0.0 + splits = [split / splits_sum for split in splits] + splits_index = [0] + for index, split in enumerate(splits): + splits_index.append(splits_index[index] + int(round(split * float(size)))) + diff = splits_index[-1] - size + for index in range(1, len(splits_index)): + splits_index[index] -= diff + assert len(splits_index) == 4 + assert splits_index[-1] == size + return splits_index + + +def build_megatron_dataset( + data_prefix, + data_impl, + splits_string, + train_valid_test_num_samples, + seq_len, + seed, + skip_warmup, + return_doc_ids=False, + *, + data_cache_path=None, +): + + # Indexed dataset. + indexed_dataset = get_indexed_dataset_(data_prefix, data_impl, skip_warmup) + + total_num_of_documents = indexed_dataset.sizes.shape[0] + splits = get_train_valid_test_split_(splits_string, total_num_of_documents) + + # Print stats about the splits. + print_rank_0(" > dataset split:") + + def print_split_stats(index, name): + print_rank_0(" {}:".format(name)) + print_rank_0( + " document indices in [{}, {}) total of {} " + "documents".format(splits[index], splits[index + 1], splits[index + 1] - splits[index]) + ) + + print_split_stats(0, "train") + + def build_dataset(index, name): + dataset = None + if splits[index + 1] > splits[index]: + documents = np.arange(start=splits[index], stop=splits[index + 1], step=1, dtype=np.int32) + dataset = GPTDataset( + name, + data_prefix, + documents, + indexed_dataset, + splits_string, + train_valid_test_num_samples[index], + seq_len, + seed, + return_doc_ids, + data_cache_path=data_cache_path, + ) + return dataset + + train_dataset = build_dataset(0, "train") + + return train_dataset diff --git a/internlm/data/megatron/helpers.cpp b/internlm/data/megatron/helpers.cpp new file mode 100644 index 000000000..09f5f9762 --- /dev/null +++ b/internlm/data/megatron/helpers.cpp @@ -0,0 +1,701 @@ +/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */ + +/* Helper methods for fast index mapping builds */ + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace py = pybind11; +using namespace std; + +const int32_t LONG_SENTENCE_LEN = 512; + + +void build_blending_indices(py::array_t& dataset_index, + py::array_t& dataset_sample_index, + const py::array_t& weights, + const int32_t num_datasets, + const int64_t size, const bool verbose) { + /* Given multiple datasets and a weighting array, build samples + such that it follows those wieghts.*/ + + if (verbose) { + std::cout << "> building indices for blendable datasets ..." << std::endl; + } + + // Get the pointer access without the checks. + auto dataset_index_ptr = dataset_index.mutable_unchecked<1>(); + auto dataset_sample_index_ptr = dataset_sample_index.mutable_unchecked<1>(); + auto weights_ptr = weights.unchecked<1>(); + + // Initialize buffer for number of samples used for each dataset. 
+ int64_t current_samples[num_datasets]; + for(int64_t i = 0; i < num_datasets; ++i) { + current_samples[i] = 0; + } + + // For each sample: + for(int64_t sample_idx = 0; sample_idx < size; ++sample_idx) { + + // Determine where the max error in sampling is happening. + auto sample_idx_double = std::max(static_cast(sample_idx), 1.0); + int64_t max_error_index = 0; + double max_error = weights_ptr[0] * sample_idx_double - + static_cast(current_samples[0]); + for (int64_t dataset_idx = 1; dataset_idx < num_datasets; ++dataset_idx) { + double error = weights_ptr[dataset_idx] * sample_idx_double - + static_cast(current_samples[dataset_idx]); + if (error > max_error) { + max_error = error; + max_error_index = dataset_idx; + } + } + + // Populate the indices. + dataset_index_ptr[sample_idx] = static_cast(max_error_index); + dataset_sample_index_ptr[sample_idx] = current_samples[max_error_index]; + + // Update the total samples. + current_samples[max_error_index] += 1; + + } + + // print info + if (verbose) { + std::cout << " > sample ratios:" << std::endl; + for (int64_t dataset_idx = 0; dataset_idx < num_datasets; ++dataset_idx) { + auto ratio = static_cast(current_samples[dataset_idx]) / + static_cast(size); + std::cout << " dataset " << dataset_idx << ", input: " << + weights_ptr[dataset_idx] << ", achieved: " << ratio << std::endl; + } + } + +} + + +py::array build_sample_idx(const py::array_t& sizes_, + const py::array_t& doc_idx_, + const int32_t seq_length, + const int32_t num_epochs, + const int64_t tokens_per_epoch) { + /* Sample index (sample_idx) is used for gpt2 like dataset for which + the documents are flattened and the samples are built based on this + 1-D flatten array. It is a 2D array with sizes [number-of-samples + 1, 2] + where [..., 0] contains the index into `doc_idx` and [..., 1] is the + starting offset in that document.*/ + + // Consistency checks. + assert(seq_length > 1); + assert(num_epochs > 0); + assert(tokens_per_epoch > 1); + + // Remove bound checks. + auto sizes = sizes_.unchecked<1>(); + auto doc_idx = doc_idx_.unchecked<1>(); + + // Mapping and it's length (1D). + int64_t num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length; + int32_t* sample_idx = new int32_t[2*(num_samples+1)]; + + cout << " using:" << endl << std::flush; + cout << " number of documents: " << + doc_idx_.shape(0) / num_epochs << endl << std::flush; + cout << " number of epochs: " << num_epochs << + endl << std::flush; + cout << " sequence length: " << seq_length << + endl << std::flush; + cout << " total number of samples: " << num_samples << + endl << std::flush; + + // Index into sample_idx. + int64_t sample_index = 0; + // Index into doc_idx. + int64_t doc_idx_index = 0; + // Begining offset for each document. + int32_t doc_offset = 0; + // Start with first document and no offset. + sample_idx[2 * sample_index] = doc_idx_index; + sample_idx[2 * sample_index + 1] = doc_offset; + ++sample_index; + + while (sample_index <= num_samples) { + // Start with a fresh sequence. + int32_t remaining_seq_length = seq_length + 1; + while (remaining_seq_length != 0) { + // Get the document length. + auto doc_id = doc_idx[doc_idx_index]; + auto doc_length = sizes[doc_id] - doc_offset; + // And add it to the current sequence. + remaining_seq_length -= doc_length; + // If we have more than a full sequence, adjust offset and set + // remaining length to zero so we return from the while loop. + // Note that -1 here is for the same reason we have -1 in + // `_num_epochs` calculations. 
+ if (remaining_seq_length <= 0) { + doc_offset += (remaining_seq_length + doc_length - 1); + remaining_seq_length = 0; + } else { + // Otherwise, start from the begining of the next document. + ++doc_idx_index; + doc_offset = 0; + } + } + // Record the sequence. + sample_idx[2 * sample_index] = doc_idx_index; + sample_idx[2 * sample_index + 1] = doc_offset; + ++sample_index; + } + + // Method to deallocate memory. + py::capsule free_when_done(sample_idx, [](void *mem_) { + int32_t *mem = reinterpret_cast(mem_); + delete[] mem; + }); + + // Return the numpy array. + const auto byte_size = sizeof(int32_t); + return py::array(std::vector{num_samples+1, 2}, // shape + {2*byte_size, byte_size}, // C-style contiguous strides + sample_idx, // the data pointer + free_when_done); // numpy array references + +} + + +inline int32_t get_target_sample_len(const int32_t short_seq_ratio, + const int32_t max_length, + std::mt19937& rand32_gen) { + /* Training sample length. */ + if (short_seq_ratio == 0) { + return max_length; + } + const auto random_number = rand32_gen(); + if ((random_number % short_seq_ratio) == 0) { + return 2 + random_number % (max_length - 1); + } + return max_length; +} + + +template +py::array build_mapping_impl(const py::array_t& docs_, + const py::array_t& sizes_, + const int32_t num_epochs, + const uint64_t max_num_samples, + const int32_t max_seq_length, + const double short_seq_prob, + const int32_t seed, + const bool verbose, + const int32_t min_num_sent) { + /* Build a mapping of (start-index, end-index, sequence-length) where + start and end index are the indices of the sentences in the sample + and sequence-length is the target sequence length. + */ + + // Consistency checks. + assert(num_epochs > 0); + assert(max_seq_length > 1); + assert(short_seq_prob >= 0.0); + assert(short_seq_prob <= 1.0); + assert(seed > 0); + + // Remove bound checks. + auto docs = docs_.unchecked<1>(); + auto sizes = sizes_.unchecked<1>(); + + // For efficiency, convert probability to ratio. Note: rand() generates int. + int32_t short_seq_ratio = 0; + if (short_seq_prob > 0) { + short_seq_ratio = static_cast(round(1.0 / short_seq_prob)); + } + + if (verbose) { + const auto sent_start_index = docs[0]; + const auto sent_end_index = docs[docs_.shape(0) - 1]; + const auto num_sentences = sent_end_index - sent_start_index; + cout << " using:" << endl << std::flush; + cout << " number of documents: " << docs_.shape(0) - 1 << + endl << std::flush; + cout << " sentences range: [" << sent_start_index << + ", " << sent_end_index << ")" << endl << std::flush; + cout << " total number of sentences: " << num_sentences << + endl << std::flush; + cout << " number of epochs: " << num_epochs << + endl << std::flush; + cout << " maximum number of samples: " << max_num_samples << + endl << std::flush; + cout << " maximum sequence length: " << max_seq_length << + endl << std::flush; + cout << " short sequence probability: " << short_seq_prob << + endl << std::flush; + cout << " short sequence ration (1/prob): " << short_seq_ratio << + endl << std::flush; + cout << " seed: " << seed << endl << + std::flush; + } + + // Mapping and it's length (1D). + int64_t num_samples = -1; + DocIdx* maps = NULL; + + // Perform two iterations, in the first iteration get the size + // and allocate memory and in the second iteration populate the map. + bool second = false; + for (int32_t iteration=0; iteration<2; ++iteration) { + + // Set the seed so both iterations produce the same results. 
+ std::mt19937 rand32_gen(seed); + + // Set the flag on second iteration. + second = (iteration == 1); + + // Counters: + uint64_t empty_docs = 0; + uint64_t one_sent_docs = 0; + uint64_t long_sent_docs = 0; + + // Current map index. + uint64_t map_index = 0; + + // For each epoch: + for (int32_t epoch=0; epoch= max_num_samples) { + if (verbose && (!second)) { + cout << " reached " << max_num_samples << " samples after " + << epoch << " epochs ..." << endl << std::flush; + } + break; + } + // For each document: + for (int32_t doc=0; doc<(docs.shape(0) - 1); ++doc) { + + // Document sentences are in [sent_index_first, sent_index_last) + const auto sent_index_first = docs[doc]; + const auto sent_index_last = docs[doc + 1]; + + // At the begining of the document previous index is the + // start index. + auto prev_start_index = sent_index_first; + + // Remaining documents. + auto num_remain_sent = sent_index_last - sent_index_first; + + // Some bookkeeping + if ((epoch == 0) && (!second)) { + if (num_remain_sent == 0) { + ++empty_docs; + } + if (num_remain_sent == 1) { + ++one_sent_docs; + } + } + + // Detect documents with long sentences. + bool contains_long_sentence = false; + if (num_remain_sent > 1) { + for (auto sent_index=sent_index_first; + sent_index < sent_index_last; ++sent_index) { + if (sizes[sent_index] > LONG_SENTENCE_LEN){ + if ((epoch == 0) && (!second)) { + ++long_sent_docs; + } + contains_long_sentence = true; + break; + } + } + } + + // If we have more than two sentences. + if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) { + + // Set values. + auto seq_len = int32_t{0}; + auto num_sent = int32_t{0}; + auto target_seq_len = get_target_sample_len(short_seq_ratio, + max_seq_length, + rand32_gen); + + // Loop through sentences. + for (auto sent_index=sent_index_first; + sent_index < sent_index_last; ++sent_index) { + + // Add the size and number of sentences. + seq_len += sizes[sent_index]; + ++num_sent; + --num_remain_sent; + + // If we have reached the target length. + // and if not only one sentence is left in the document. + // and if we have at least two sentneces. + // and if we have reached end of the document. + if (((seq_len >= target_seq_len) && + (num_remain_sent > 1) && + (num_sent >= min_num_sent) ) || (num_remain_sent == 0)) { + + // Check for overflow. + if ((3 * map_index + 2) > + std::numeric_limits::max()) { + cout << "number of samples exceeded maximum " + << "allowed by type int64: " + << std::numeric_limits::max() + << endl; + throw std::overflow_error("Number of samples"); + } + + // Populate the map. + if (second) { + const auto map_index_0 = 3 * map_index; + maps[map_index_0] = static_cast(prev_start_index); + maps[map_index_0 + 1] = static_cast(sent_index + 1); + maps[map_index_0 + 2] = static_cast(target_seq_len); + } + + // Update indices / counters. + ++map_index; + prev_start_index = sent_index + 1; + target_seq_len = get_target_sample_len(short_seq_ratio, + max_seq_length, + rand32_gen); + seq_len = 0; + num_sent = 0; + } + + } // for (auto sent_index=sent_index_first; ... 
+ } // if (num_remain_sent > 1) { + } // for (int doc=0; doc < num_docs; ++doc) { + } // for (int epoch=0; epoch < num_epochs; ++epoch) { + + if (!second) { + if (verbose) { + cout << " number of empty documents: " << empty_docs << + endl << std::flush; + cout << " number of documents with one sentence: " << + one_sent_docs << endl << std::flush; + cout << " number of documents with long sentences: " << + long_sent_docs << endl << std::flush; + cout << " will create mapping for " << map_index << + " samples" << endl << std::flush; + } + assert(maps == NULL); + assert(num_samples < 0); + maps = new DocIdx[3*map_index]; + num_samples = static_cast(map_index); + } + + } // for (int iteration=0; iteration < 2; ++iteration) { + + // Shuffle. + // We need a 64 bit random number generator as we might have more + // than 2 billion samples. + std::mt19937_64 rand64_gen(seed + 1); + for (auto i=(num_samples - 1); i > 0; --i) { + const auto j = static_cast(rand64_gen() % (i + 1)); + const auto i0 = 3 * i; + const auto j0 = 3 * j; + // Swap values. + swap(maps[i0], maps[j0]); + swap(maps[i0 + 1], maps[j0 + 1]); + swap(maps[i0 + 2], maps[j0 + 2]); + } + + // Method to deallocate memory. + py::capsule free_when_done(maps, [](void *mem_) { + DocIdx *mem = reinterpret_cast(mem_); + delete[] mem; + }); + + // Return the numpy array. + const auto byte_size = sizeof(DocIdx); + return py::array(std::vector{num_samples, 3}, // shape + {3*byte_size, byte_size}, // C-style contiguous strides + maps, // the data pointer + free_when_done); // numpy array references + +} + + +py::array build_mapping(const py::array_t& docs_, + const py::array_t& sizes_, + const int num_epochs, + const uint64_t max_num_samples, + const int max_seq_length, + const double short_seq_prob, + const int seed, + const bool verbose, + const int32_t min_num_sent) { + + if (sizes_.size() > std::numeric_limits::max()) { + if (verbose) { + cout << " using uint64 for data mapping..." << endl << std::flush; + } + return build_mapping_impl(docs_, sizes_, num_epochs, + max_num_samples, max_seq_length, + short_seq_prob, seed, verbose, + min_num_sent); + } else { + if (verbose) { + cout << " using uint32 for data mapping..." << endl << std::flush; + } + return build_mapping_impl(docs_, sizes_, num_epochs, + max_num_samples, max_seq_length, + short_seq_prob, seed, verbose, + min_num_sent); + } +} + +template +py::array build_blocks_mapping_impl(const py::array_t& docs_, + const py::array_t& sizes_, + const py::array_t& titles_sizes_, + const int32_t num_epochs, + const uint64_t max_num_samples, + const int32_t max_seq_length, + const int32_t seed, + const bool verbose, + const bool use_one_sent_blocks) { + /* Build a mapping of (start-index, end-index, sequence-length) where + start and end index are the indices of the sentences in the sample + and sequence-length is the target sequence length. + */ + + // Consistency checks. + assert(num_epochs > 0); + assert(max_seq_length > 1); + assert(seed > 0); + + // Remove bound checks. 
+ auto docs = docs_.unchecked<1>(); + auto sizes = sizes_.unchecked<1>(); + auto titles_sizes = titles_sizes_.unchecked<1>(); + + if (verbose) { + const auto sent_start_index = docs[0]; + const auto sent_end_index = docs[docs_.shape(0) - 1]; + const auto num_sentences = sent_end_index - sent_start_index; + cout << " using:" << endl << std::flush; + cout << " number of documents: " << docs_.shape(0) - 1 << + endl << std::flush; + cout << " sentences range: [" << sent_start_index << + ", " << sent_end_index << ")" << endl << std::flush; + cout << " total number of sentences: " << num_sentences << + endl << std::flush; + cout << " number of epochs: " << num_epochs << + endl << std::flush; + cout << " maximum number of samples: " << max_num_samples << + endl << std::flush; + cout << " maximum sequence length: " << max_seq_length << + endl << std::flush; + cout << " seed: " << seed << endl << + std::flush; + } + + // Mapping and its length (1D). + int64_t num_samples = -1; + DocIdx* maps = NULL; + + // Acceptable number of sentences per block. + int min_num_sent = 2; + if (use_one_sent_blocks) { + min_num_sent = 1; + } + + // Perform two iterations, in the first iteration get the size + // and allocate memory and in the second iteration populate the map. + bool second = false; + for (int32_t iteration=0; iteration<2; ++iteration) { + + // Set the flag on second iteration. + second = (iteration == 1); + + // Current map index. + uint64_t map_index = 0; + + uint64_t empty_docs = 0; + uint64_t one_sent_docs = 0; + uint64_t long_sent_docs = 0; + // For each epoch: + for (int32_t epoch=0; epoch= max_num_samples) { + if (verbose && (!second)) { + cout << " reached " << max_num_samples << " samples after " + << epoch << " epochs ..." << endl << std::flush; + } + break; + } + // For each document: + for (int32_t doc=0; doc<(docs.shape(0) - 1); ++doc) { + + // Document sentences are in [sent_index_first, sent_index_last) + const auto sent_index_first = docs[doc]; + const auto sent_index_last = docs[doc + 1]; + const auto target_seq_len = max_seq_length - titles_sizes[doc]; + + // At the begining of the document previous index is the + // start index. + auto prev_start_index = sent_index_first; + + // Remaining documents. + auto num_remain_sent = sent_index_last - sent_index_first; + + // Some bookkeeping + if ((epoch == 0) && (!second)) { + if (num_remain_sent == 0) { + ++empty_docs; + } + if (num_remain_sent == 1) { + ++one_sent_docs; + } + } + // Detect documents with long sentences. + bool contains_long_sentence = false; + if (num_remain_sent >= min_num_sent) { + for (auto sent_index=sent_index_first; + sent_index < sent_index_last; ++sent_index) { + if (sizes[sent_index] > LONG_SENTENCE_LEN){ + if ((epoch == 0) && (!second)) { + ++long_sent_docs; + } + contains_long_sentence = true; + break; + } + } + } + // If we have enough sentences and no long sentences. + if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) { + + // Set values. + auto seq_len = int32_t{0}; + auto num_sent = int32_t{0}; + + // Loop through sentences. + for (auto sent_index=sent_index_first; + sent_index < sent_index_last; ++sent_index) { + + // Add the size and number of sentences. + seq_len += sizes[sent_index]; + ++num_sent; + --num_remain_sent; + + // If we have reached the target length. + // and there are an acceptable number of sentences left + // and if we have at least the minimum number of sentences. + // or if we have reached end of the document. 
+ if (((seq_len >= target_seq_len) && + (num_remain_sent >= min_num_sent) && + (num_sent >= min_num_sent) ) || (num_remain_sent == 0)) { + + // Populate the map. + if (second) { + const auto map_index_0 = 4 * map_index; + // Each sample has 4 items: the starting sentence index, ending sentence index, + // the index of the document from which the block comes (used for fetching titles) + // and the unique id of the block (used for creating block indexes) + + maps[map_index_0] = static_cast(prev_start_index); + maps[map_index_0 + 1] = static_cast(sent_index + 1); + maps[map_index_0 + 2] = static_cast(doc); + maps[map_index_0 + 3] = static_cast(block_id); + } + + // Update indices / counters. + ++map_index; + ++block_id; + prev_start_index = sent_index + 1; + seq_len = 0; + num_sent = 0; + } + } // for (auto sent_index=sent_index_first; ... + } // if (num_remain_sent > 1) { + } // for (int doc=0; doc < num_docs; ++doc) { + } // for (int epoch=0; epoch < num_epochs; ++epoch) { + + if (!second) { + if (verbose) { + cout << " number of empty documents: " << empty_docs << + endl << std::flush; + cout << " number of documents with one sentence: " << + one_sent_docs << endl << std::flush; + cout << " number of documents with long sentences: " << + long_sent_docs << endl << std::flush; + cout << " will create mapping for " << map_index << + " samples" << endl << std::flush; + } + assert(maps == NULL); + assert(num_samples < 0); + maps = new DocIdx[4*map_index]; + num_samples = static_cast(map_index); + } + + } // for (int iteration=0; iteration < 2; ++iteration) { + + // Shuffle. + // We need a 64 bit random number generator as we might have more + // than 2 billion samples. + std::mt19937_64 rand64_gen(seed + 1); + for (auto i=(num_samples - 1); i > 0; --i) { + const auto j = static_cast(rand64_gen() % (i + 1)); + const auto i0 = 4 * i; + const auto j0 = 4 * j; + // Swap values. + swap(maps[i0], maps[j0]); + swap(maps[i0 + 1], maps[j0 + 1]); + swap(maps[i0 + 2], maps[j0 + 2]); + swap(maps[i0 + 3], maps[j0 + 3]); + } + + // Method to deallocate memory. + py::capsule free_when_done(maps, [](void *mem_) { + DocIdx *mem = reinterpret_cast(mem_); + delete[] mem; + }); + + // Return the numpy array. + const auto byte_size = sizeof(DocIdx); + return py::array(std::vector{num_samples, 4}, // shape + {4*byte_size, byte_size}, // C-style contiguous strides + maps, // the data pointer + free_when_done); // numpy array references + +} + +py::array build_blocks_mapping(const py::array_t& docs_, + const py::array_t& sizes_, + const py::array_t& titles_sizes_, + const int num_epochs, + const uint64_t max_num_samples, + const int max_seq_length, + const int seed, + const bool verbose, + const bool use_one_sent_blocks) { + + if (sizes_.size() > std::numeric_limits::max()) { + if (verbose) { + cout << " using uint64 for data mapping..." << endl << std::flush; + } + return build_blocks_mapping_impl(docs_, sizes_, titles_sizes_, + num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks); + } else { + if (verbose) { + cout << " using uint32 for data mapping..." 
<< endl << std::flush; + } + return build_blocks_mapping_impl(docs_, sizes_, titles_sizes_, + num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks); + } +} + +PYBIND11_MODULE(helpers, m) { + m.def("build_mapping", &build_mapping); + m.def("build_blocks_mapping", &build_blocks_mapping); + m.def("build_sample_idx", &build_sample_idx); + m.def("build_blending_indices", &build_blending_indices); +} diff --git a/internlm/data/mocked/__init__.py b/internlm/data/mocked/__init__.py new file mode 100644 index 000000000..b57361252 --- /dev/null +++ b/internlm/data/mocked/__init__.py @@ -0,0 +1,7 @@ +from .batch_sampler import MockedSequentialBatchSampler +from .dataset import MockedDataset + +__all__ = [ + "MockedSequentialBatchSampler", + "MockedDataset", +] diff --git a/internlm/data/mocked/batch_sampler.py b/internlm/data/mocked/batch_sampler.py new file mode 100644 index 000000000..de768386d --- /dev/null +++ b/internlm/data/mocked/batch_sampler.py @@ -0,0 +1,24 @@ +import copy + + +class MockedSequentialBatchSampler: + """ + MockedSequentialBatchSampler + """ + + def __init__(self, data_source, micro_num): + self.data_source = data_source + self.micro_num = micro_num + + def __iter__(self): + num_samples = len(self.data_source) + for start in range(0, num_samples, self.micro_num): + end = min(start + self.micro_num, num_samples) + yield list(range(start, end)) + + def __len__(self): + return (len(self.data_source) + self.micro_num - 1) // self.micro_num + + # TODO: implement copy method that compatible with InternEvo trainstate + def copy(self): + return copy.deepcopy(self) diff --git a/internlm/data/mocked/dataset.py b/internlm/data/mocked/dataset.py new file mode 100644 index 000000000..a7e9df717 --- /dev/null +++ b/internlm/data/mocked/dataset.py @@ -0,0 +1,91 @@ +import glob + +import torch +from torch.utils.data import Dataset + +from internlm.core.context import ParallelMode +from internlm.core.context import global_context as gpc + + +def merge_tensors(file_pattern): + files = sorted(glob.glob(file_pattern)) + tensors = [] + for file in files: + tensor = torch.load(file) + tensors.append(tensor) + merged_tensor = torch.cat(tensors, dim=0) + return merged_tensor + + +def process_raw_data(raw_data, micro_bsz): + num_groups = len(raw_data) // micro_bsz + result = [] + for i in range(num_groups): + start_idx = i * micro_bsz + end_idx = start_idx + micro_bsz + group = raw_data[start_idx:end_idx] + concatenated = torch.cat(group, dim=0) + result.append(concatenated) + return result + + +class MockedDataset(Dataset): + """ + MockedDataset + """ + + def __init__(self, data_dir, micro_bsz, seq_len, mocked_steps): + db_input_ids = [] + db_labels = [] + + # load all saved data + for i in range(mocked_steps): + # define load pattern + input_ids_pattern = data_dir + f"_tokens_step{i+1}_dp*" + labels_pattern = data_dir + f"_labels_step{i+1}_dp*" + # merge input_ids, labels, and then chunk across dp + input_ids = torch.chunk(merge_tensors(input_ids_pattern), gpc.get_world_size(ParallelMode.DATA))[ + gpc.get_local_rank(ParallelMode.DATA) + ] + labels = torch.chunk(merge_tensors(labels_pattern), gpc.get_world_size(ParallelMode.DATA))[ + gpc.get_local_rank(ParallelMode.DATA) + ] + # load one step + db_input_ids.append(input_ids) + db_labels.append(labels) + + # transform db + db_input_ids = torch.concat(db_input_ids, dim=0) + db_labels = torch.concat(db_labels, dim=0) + db_input_ids = [db_input_ids[i] for i in range(db_input_ids.size(0))] + db_labels = [db_labels[i] for i in 
range(db_labels.size(0))] + + # gen data for internevo format + db_input_ids = process_raw_data(db_input_ids, micro_bsz) + db_labels = process_raw_data(db_labels, micro_bsz) + self.db_input_ids = [item.tolist() for item in db_input_ids] + self.db_labels = [item.tolist() for item in db_labels] + + assert len(self.db_input_ids) == len(self.db_labels) + self.dataset_len = len(self.db_input_ids) + self.micro_bsz = micro_bsz + self.seq_len = seq_len + + def __len__(self): + return self.dataset_len + + def __getitem__(self, idx): + tokens = self.db_input_ids[idx] + cu_seqlens = list(range(self.micro_bsz + 1)) + cu_seqlens = [i * self.seq_len for i in cu_seqlens] + indexes = list(range(self.seq_len)) * self.micro_bsz + labels = self.db_labels[idx] + type_ids = [0] * self.micro_bsz * self.seq_len + + return { + "tokens": tokens, + "cu_seqlens": cu_seqlens, + "indexes": indexes, + "labels": labels, + "type_ids": type_ids, + } diff --git a/internlm/data/train_state.py b/internlm/data/train_state.py index aa061b822..2c678b05b 100644 --- a/internlm/data/train_state.py +++ b/internlm/data/train_state.py @@ -6,7 +6,12 @@ def get_train_state(dataloader): # initialize and resume train state - if gpc.config.data.type in [DataType.tokenized.name, DataType.streaming.name]: + if gpc.config.data.type in [ + DataType.tokenized.name, + DataType.streaming.name, + DataType.megatron.name, + DataType.mocked.name, + ]: train_state = TrainState(gpc.config, dataloader.batch_sampler) else: raise ValueError(f"dataset type {gpc.config.data.type} is not supported") diff --git a/internlm/utils/utils.py b/internlm/utils/utils.py index 03dee6df3..682aef199 100644 --- a/internlm/utils/utils.py +++ b/internlm/utils/utils.py @@ -52,6 +52,8 @@ class ModelType(Enum): class DataType(Enum): streaming = 1 tokenized = 2 + megatron = 3 + mocked = 4 class TensorParallelMode(Enum):
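
Usage sketch: with this change the new loaders are selected purely through gpc.config.data.type, so switching to them is a config-level change. Below is a minimal, hypothetical excerpt of a training config (assuming the usual InternEvo "data = dict(...)" layout); the path and size values are placeholders, and only the keys actually read by get_megatron_train_loader_items / get_mock_train_loader_items are shown.

# Hypothetical config excerpt -- placeholder values, assuming the standard InternEvo config layout.
data = dict(
    type="megatron",                   # DataType.megatron; set to "mocked" to use MockedDataset instead
    train_folder="/path/to/gpt_data",  # megatron: basename of the .bin/.idx pair; mocked: prefix of the saved *_tokens_step*/ *_labels_step* tensor files
    data_impl="infer",                 # optional; defaults to "infer"
    seq_len=4096,
    micro_num=4,
    micro_bsz=2,
    seed=1024,                         # optional; defaults to 1024
    # mocked_steps=10,                 # read only when type == "mocked"
)

Note that get_megatron_train_loader_items currently hard-codes splits_string="1.0, 0.0, 0.0" and train_valid_test_num_samples=[9600000, 0, 0], and that on an ImportError it compiles internlm/data/megatron/helpers.cpp with g++ into helpers.so; the GPT dataset index files are cached in an index-cache directory next to the data unless data_cache_path is supplied.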