
Commit 2c6df5c

fix(huggingface): fix huggingface dataloader when using some huggingface third-party tokenizers (#277)
1 parent 19d00ac commit 2c6df5c

File tree

3 files changed: +41 / -22 lines changed

3 files changed

+41
-22
lines changed

internlm/data/build_dataloader.py

Lines changed: 9 additions & 2 deletions
@@ -125,15 +125,22 @@ def get_hf_train_loader_items(data_cfg):
         model_max_length=data_cfg.seq_len,
         subset_name=data_cfg.get("subset_name", None),
     )
+    pad_token_id = gpc.config.model.get("pad_token_id", 0)
     if gpc.config.model_type == "hf" and not data_cfg.use_packed_dataset:
         train_sampler = StreamingStaticBatchSampler(
             batch_size=data_cfg.micro_num * data_cfg.micro_bsz, rampup_batch_size=data_cfg.rampup_batch_size
         )
         train_collate_fn = partial(
-            nopack_collate_fn, micro_num=data_cfg.micro_num, micro_bsz=data_cfg.micro_bsz, seq_len=data_cfg.seq_len
+            nopack_collate_fn,
+            micro_num=data_cfg.micro_num,
+            micro_bsz=data_cfg.micro_bsz,
+            seq_len=data_cfg.seq_len,
+            pad_token_id=pad_token_id,
         )
     else:
-        train_ds = HuggingFacePackedDataset(dataset=train_ds, seq_len=data_cfg.seq_len, micro_bsz=data_cfg.micro_bsz)
+        train_ds = HuggingFacePackedDataset(
+            dataset=train_ds, seq_len=data_cfg.seq_len, micro_bsz=data_cfg.micro_bsz, pad_token_id=pad_token_id
+        )
         train_sampler = StreamingStaticBatchSampler(
             batch_size=data_cfg.micro_num, rampup_batch_size=data_cfg.rampup_batch_size
         )
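
The new pad_token_id is read from gpc.config.model.get("pad_token_id", 0), so a model whose tokenizer does not use 0 as its padding id can declare it in the training config. A minimal sketch of where that value might live, assuming an InternLM-style Python config; pad_token_id is the only new knob, and the other fields are illustrative stand-ins for what the code above reads from data_cfg:

# Hypothetical config fragment (field values are examples, not from this commit)
model_type = "hf"

model = dict(
    pad_token_id=2,  # e.g. a third-party tokenizer whose padding id is not 0
)

data = dict(
    seq_len=2048,
    micro_num=4,
    micro_bsz=2,
    rampup_batch_size=None,
    use_packed_dataset=False,  # False -> the nopack_collate_fn path above
)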

internlm/data/streaming/collaters.py

Lines changed: 21 additions & 12 deletions
@@ -1,25 +1,34 @@
 import torch
 
 
-def nopack_collate_fn(batch, micro_num, micro_bsz, seq_len):
+def nopack_collate_fn(batch, micro_num, micro_bsz, seq_len, pad_token_id=0):
     input_ids_list = []
     attention_mask_list = []
     labels_list = []
+
     for b in batch:
-        attention_mask = torch.tensor(b["attention_mask"])
-        input_ids = torch.LongTensor(b["input_ids"])
-        input_ids = torch.abs(input_ids * attention_mask)
-        input_ids = torch.nn.functional.pad(input_ids, (0, seq_len - len(input_ids)), mode="constant", value=0)
-        attention_mask = torch.nn.functional.pad(
-            attention_mask, (0, seq_len - len(attention_mask)), mode="constant", value=0
-        )
-        label = torch.LongTensor([w if w > 0 else -100 for w in input_ids.tolist()][1:] + [-100])
-        input_ids_list.append(input_ids)
-        attention_mask_list.append(attention_mask)
-        labels_list.append(label)
+        assert len(b["input_ids"]) > 0
+
+        if "attention_mask" in b:
+            assert len(b["input_ids"]) == len(
+                b["attention_mask"]
+            ), "input_ids and attention_mask should be equal length"
+        else:
+            b["attention_mask"] = [True] * len(b["input_ids"])
+
+        input_ids = b["input_ids"] + [pad_token_id] * (seq_len - len(b["input_ids"]))
+        attention_mask = b["attention_mask"] + [False] * (seq_len - len(b["attention_mask"]))
+        labels = [w if w > 0 else -100 for w in b["input_ids"]][1:] + [-100]
+        labels = labels + [-100] * (seq_len - len(b["input_ids"]))
+
+        input_ids_list.append(torch.LongTensor(input_ids))
+        attention_mask_list.append(torch.BoolTensor(attention_mask))
+        labels_list.append(torch.LongTensor(labels))
+
     input_ids = torch.stack(input_ids_list)
     attention_mask = torch.stack(attention_mask_list)
     labels = torch.stack(labels_list)
+
     return {
         "input_ids": input_ids,
         "attention_mask": attention_mask,

internlm/data/streaming/dataset.py

Lines changed: 11 additions & 8 deletions
@@ -47,22 +47,24 @@ def _tokenize(self, samples):
         texts = [sample["text"] for sample in samples]
         tokenized_outputs = self.tokenizer(texts, truncation=True)
         for i in range(len(samples)):
-            yield {key: tokenized_outputs[key][i] for key in tokenized_outputs}
+            assert "input_ids" in tokenized_outputs, "huggingface tokenizer should generate input_ids"
+            if len(tokenized_outputs["input_ids"][i]) > 0:
+                yield {key: tokenized_outputs[key][i] for key in tokenized_outputs}
 
     def __getitem__(self, _):
         return next(self.senior_iterator)
 
 
 class HuggingFacePackedDataset(Dataset):
     """
-    Simple packed dataset for huggingface.
+    Simple packed dataset for huggingface
     """
 
-    def __init__(self, dataset, seq_len, micro_bsz):
+    def __init__(self, dataset, seq_len, micro_bsz, pad_token_id=0):
         self.dataset = dataset
         self.seq_len = seq_len
         self.micro_bsz = micro_bsz
-
+        self.pad_token_id = pad_token_id
         self.senior_iterator = iter(self)
 
     def __iter__(self):
@@ -72,7 +74,7 @@ def __iter__(self):
         for sample in self.dataset:
             if len(input_ids + sample["input_ids"]) > self.micro_bsz * self.seq_len:
                 assert cu_seqlens[-1] <= self.micro_bsz * self.seq_len
-                input_ids = input_ids + [0] * (self.micro_bsz * self.seq_len - len(input_ids))
+                input_ids = input_ids + [self.pad_token_id] * (self.micro_bsz * self.seq_len - len(input_ids))
                 cu_seqlens = (
                     cu_seqlens + [self.micro_bsz * self.seq_len]
                     if cu_seqlens[-1] < self.micro_bsz * self.seq_len
@@ -89,14 +91,15 @@ def __iter__(self):
                 }
                 input_ids = sample["input_ids"]
                 cu_seqlens = [0, len(sample["input_ids"])]
-                labels = sample["input_ids"][1:] + [-100]
+                labels = [w if w > 0 else -100 for w in sample["input_ids"]][1:] + [-100]
             else:
                 input_ids = input_ids + sample["input_ids"]
                 cu_seqlens.append(len(sample["input_ids"]) + cu_seqlens[-1])
-                labels = labels + sample["input_ids"][1:] + [-100]
+                labels = labels + [w if w > 0 else -100 for w in sample["input_ids"]][1:] + [-100]
+
         if input_ids:
             assert cu_seqlens[-1] <= self.micro_bsz * self.seq_len
-            input_ids = input_ids + [0] * (self.micro_bsz * self.seq_len - len(input_ids))
+            input_ids = input_ids + [self.pad_token_id] * (self.micro_bsz * self.seq_len - len(input_ids))
             cu_seqlens = (
                 cu_seqlens + [self.micro_bsz * self.seq_len]
                 if cu_seqlens[-1] < self.micro_bsz * self.seq_len
