2 changes: 1 addition & 1 deletion internlm/core/trainer.py
@@ -57,7 +57,7 @@ def __init__(self, config, batch_sampler) -> None:
        self.lr = config.adam.lr

        # sampler state
-        if batch_sampler:
+        if batch_sampler is not None:
            self.init_batch_sampler(batch_sampler)

        # tgs statistic
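
Note: the `is not None` check matters because a sampler object can be falsy while still being valid. A minimal sketch of the pitfall, using a hypothetical EmptySampler purely for illustration:

class EmptySampler:
    # Hypothetical sampler whose __len__ is 0, e.g. before any data is registered.
    def __len__(self):
        return 0

sampler = EmptySampler()
print(bool(sampler))        # False -> the old `if batch_sampler:` would skip initialization
print(sampler is not None)  # True  -> the new check still calls init_batch_sampler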
79 changes: 77 additions & 2 deletions internlm/data/build_dataloader.py
@@ -1,11 +1,17 @@
# Copyright (c) InternLM. All rights reserved.
import subprocess
from functools import partial

import torch.distributed as dist
from torch.utils.data import ConcatDataset, DataLoader

from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.data.megatron.batch_sampler import MegatronBatchSampler
from internlm.data.megatron.collaters import megatron_collate_fn
from internlm.data.megatron.dataset import build_megatron_dataset
from internlm.data.mocked.batch_sampler import MockedSequentialBatchSampler
from internlm.data.mocked.dataset import MockedDataset
from internlm.data.streaming.batch_sampler import StreamingStaticBatchSampler
from internlm.data.streaming.collaters import streaming_packed_collate_fn
from internlm.data.streaming.dataset import (
@@ -139,6 +145,62 @@ def get_streaming_train_loader_items(data_cfg):
return train_ds, train_sampler, streaming_packed_collate_fn


def get_megatron_train_loader_items(data_cfg):
    try:
        from internlm.data.megatron import helpers  # noqa # pylint: disable=W0611
    except ImportError:
        if gpc.is_rank_for_log():
            # The $(...) include lookups require shell expansion, so the compile
            # command is passed as a single string with shell=True.
            subprocess.run(  # noqa # pylint: disable=W1510
                "g++ -O3 -shared -std=c++11 -fPIC -fdiagnostics-color "
                "-I$(python3-config --includes) -I$(python3 -m pybind11 --includes) "
                "internlm/data/megatron/helpers.cpp "
                "-o internlm/data/megatron/helpers.so",
                shell=True,
            )

    train_ds = build_megatron_dataset(
        data_prefix=data_cfg.train_folder,
        data_impl=data_cfg.get("data_impl", "infer"),
        splits_string="1.0, 0.0, 0.0",
        train_valid_test_num_samples=[9600000, 0, 0],
        seq_len=data_cfg.seq_len,
        seed=data_cfg.get("seed", 1024),
        skip_warmup=True,
    )

    train_sampler = MegatronBatchSampler(
        total_samples=len(train_ds),
        consumed_samples=0,
        batch_size=data_cfg.micro_num * data_cfg.micro_bsz,
        drop_last=True,
    )

    train_collate_fn = partial(
        megatron_collate_fn, micro_num=data_cfg.micro_num, micro_bsz=data_cfg.micro_bsz, seq_len=data_cfg.seq_len
    )

    return train_ds, train_sampler, train_collate_fn


def get_mock_train_loader_items(data_cfg):
    train_ds = MockedDataset(
        data_dir=data_cfg.train_folder,  # path to the mocked data
        micro_bsz=data_cfg.micro_bsz,
        seq_len=data_cfg.seq_len,
        mocked_steps=data_cfg.mocked_steps,  # number of steps of mocked data
    )
    train_sampler = MockedSequentialBatchSampler(train_ds, data_cfg.micro_num)
    train_collate_fn = partial(packed_collate_fn, packed_length=data_cfg.seq_len * data_cfg.micro_bsz)
    return train_ds, train_sampler, train_collate_fn


def build_train_loader_with_data_type():
    """
    Build and return the training data loader based on data type.
@@ -155,6 +217,14 @@ def build_train_loader_with_data_type():
        train_ds, train_sampler, train_collate_fn = get_streaming_train_loader_items(data_cfg)
        # TODO: support more dataset_types
        dataset_types = ["en"]
    elif data_cfg.type == DataType.megatron.name:
        train_ds, train_sampler, train_collate_fn = get_megatron_train_loader_items(data_cfg)
        # TODO: support more dataset_types
        dataset_types = ["en"]
    elif data_cfg.type == DataType.mocked.name:
        train_ds, train_sampler, train_collate_fn = get_mock_train_loader_items(data_cfg)
        # TODO: support more dataset_types
        dataset_types = ["en"]
    else:
        raise ValueError(f"dataset type {data_cfg.type} is not supported")

@@ -176,8 +246,13 @@ def build_valid_loader_with_data_type():

    data_cfg = gpc.config.data

-    # TODO: support streaming dataset for validation
-    if data_cfg.type in [DataType.tokenized.name, DataType.streaming.name]:
+    # TODO: For validation, we currently only support the dummy dataset for the streaming/megatron/mocked DataTypes.
+    if data_cfg.type in [
+        DataType.tokenized.name,
+        DataType.streaming.name,
+        DataType.megatron.name,
+        DataType.mocked.name,
+    ]:
        valid_ds, valid_collate_fn = get_tokenized_valid_loader_items(data_cfg)
    else:
        raise ValueError(f"dataset type {data_cfg.type} is not supported")
9 changes: 9 additions & 0 deletions internlm/data/megatron/__init__.py
@@ -0,0 +1,9 @@
from .batch_sampler import MegatronBatchSampler
from .collaters import megatron_collate_fn
from .dataset import build_megatron_dataset

__all__ = [
"MegatronBatchSampler",
"build_megatron_dataset",
"megatron_collate_fn",
]
62 changes: 62 additions & 0 deletions internlm/data/megatron/batch_sampler.py
@@ -0,0 +1,62 @@
import copy
import math

from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc


class MegatronBatchSampler:
    """
    MegatronBatchSampler
    """

    def __init__(self, total_samples, consumed_samples, batch_size, drop_last=True):
        # Keep a copy of input params for later use.
        self.total_samples = total_samples
        self.consumed_samples = consumed_samples
        self.batch_size = batch_size
        self.drop_last = drop_last

        self.dp_rank = gpc.get_local_rank(ParallelMode.DATA)
        self.dp_size = gpc.get_world_size(ParallelMode.DATA)

        # Sanity checks.
        assert self.total_samples > 0, "no sample to consume: {}".format(self.total_samples)
        assert self.consumed_samples < self.total_samples, "no samples left to consume: {}, {}".format(
            self.consumed_samples, self.total_samples
        )
        assert self.batch_size > 0
        assert self.dp_size > 0
        assert self.dp_rank < self.dp_size, "dp_rank should be smaller than dp_size: {}, {}".format(
            self.dp_rank, self.dp_size
        )

    def __len__(self):
        # Mirrors torch.utils.data.DistributedSampler: with drop_last, samples that
        # do not divide evenly across data-parallel ranks are dropped.
        if self.drop_last and self.total_samples % self.dp_size != 0:
            return math.ceil((self.total_samples - self.dp_size) / self.dp_size)
        else:
            return math.ceil(self.total_samples / self.dp_size)

    def get_start_end_idx(self):
        start_idx = self.dp_rank * self.batch_size
        end_idx = start_idx + self.batch_size
        return start_idx, end_idx

    def __iter__(self):
        batch = []
        # The last partial batch is dropped unless drop_last is set to False
        for idx in range(self.consumed_samples, self.total_samples):
            batch.append(idx)
            if len(batch) == self.batch_size * self.dp_size:
                start_idx, end_idx = self.get_start_end_idx()
                yield batch[start_idx:end_idx]
                batch = []

        # Yield the last partial batch only if drop_last is set to False
        if len(batch) > 0 and not self.drop_last:
            start_idx, end_idx = self.get_start_end_idx()
            yield batch[start_idx:end_idx]

    # TODO: implement a copy method that is compatible with the InternEvo train state
    def copy(self):
        return copy.deepcopy(self)
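
To see what the iterator yields per rank, here is a minimal standalone sketch of the same sharding logic, with gpc replaced by explicit dp_size/dp_rank arguments and the values made up for illustration:

def shard_batches(total_samples, consumed_samples, batch_size, dp_size):
    # Mirrors MegatronBatchSampler.__iter__ with drop_last=True: collect
    # batch_size * dp_size indices, then give each rank one contiguous slice.
    batch = []
    for idx in range(consumed_samples, total_samples):
        batch.append(idx)
        if len(batch) == batch_size * dp_size:
            yield {rank: batch[rank * batch_size : (rank + 1) * batch_size] for rank in range(dp_size)}
            batch = []

for shard in shard_batches(total_samples=10, consumed_samples=0, batch_size=2, dp_size=2):
    print(shard)
# {0: [0, 1], 1: [2, 3]}
# {0: [4, 5], 1: [6, 7]}
# the trailing indices [8, 9] are dropped, as with drop_last=True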
49 changes: 49 additions & 0 deletions internlm/data/megatron/collaters.py
@@ -0,0 +1,49 @@
import torch


def megatron_collate_fn(batch, micro_num, micro_bsz, seq_len):
    """Pack micro_num * micro_bsz samples of seq_len + 1 tokens each into
    (micro_num, micro_bsz * seq_len) input_ids/labels tensors, along with the
    cu_seqlens and indexes metadata used for packed sequences."""

    input_ids_result = [[] for _ in range(micro_num)]
    labels_result = [[] for _ in range(micro_num)]
    cu_seqlens = []
    cu_seqlens_list = []
    indexes = []
    indexes_list = []

    for i, item in enumerate(batch):
        assert i < micro_num * micro_bsz
        token_ids = item["text"]
        assert len(token_ids) == seq_len + 1

        micro_bsz_index = i % micro_bsz
        micro_num_index = i // micro_bsz

        # Next-token prediction: inputs are tokens[:-1], labels are tokens[1:].
        input_ids_result[micro_num_index].append(token_ids[:-1])
        labels_result[micro_num_index].append(token_ids[1:])

        cu_seqlens.append(seq_len * micro_bsz_index)
        indexes = indexes + list(range(seq_len))

        # Once a micro batch is full, flatten its samples into one packed row.
        if micro_bsz_index == micro_bsz - 1:
            input_ids_result[micro_num_index] = torch.cat(
                [torch.from_numpy(arr).long() for arr in input_ids_result[micro_num_index]], dim=0
            )
            labels_result[micro_num_index] = torch.cat(
                [torch.from_numpy(arr).long() for arr in labels_result[micro_num_index]], dim=0
            )
            cu_seqlens.append(seq_len * micro_bsz)
            cu_seqlens_list.append(torch.IntTensor(cu_seqlens))
            cu_seqlens = []
            indexes_list.append(torch.IntTensor(indexes))
            indexes = []

    input_ids = torch.stack(input_ids_result)
    labels = torch.stack(labels_result)
    indexes = torch.stack(indexes_list)

    return {
        "input_ids": input_ids,
        "cu_seqlens": cu_seqlens_list,
        "indexes": indexes,
        "type_ids": torch.zeros(micro_num, micro_bsz * seq_len, dtype=torch.int64),
    }, labels
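
A quick shape check, assuming each dataset item carries a numpy token array under the "text" key, as the torch.from_numpy calls above imply (values made up):

import numpy as np

from internlm.data.megatron.collaters import megatron_collate_fn

micro_num, micro_bsz, seq_len = 2, 2, 4
batch = [{"text": np.arange(seq_len + 1)} for _ in range(micro_num * micro_bsz)]
data, labels = megatron_collate_fn(batch, micro_num, micro_bsz, seq_len)
print(data["input_ids"].shape)  # torch.Size([2, 8]) == (micro_num, micro_bsz * seq_len)
print(data["cu_seqlens"][0])    # tensor([0, 4, 8], dtype=torch.int32)
print(labels.shape)             # torch.Size([2, 8])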