@@ -47,22 +47,24 @@ def _tokenize(self, samples):
         texts = [sample["text"] for sample in samples]
         tokenized_outputs = self.tokenizer(texts, truncation=True)
         for i in range(len(samples)):
-            yield {key: tokenized_outputs[key][i] for key in tokenized_outputs}
+            assert "input_ids" in tokenized_outputs, "huggingface tokenizer should generate input_ids"
+            if len(tokenized_outputs["input_ids"][i]) > 0:
+                yield {key: tokenized_outputs[key][i] for key in tokenized_outputs}
 
     def __getitem__(self, _):
         return next(self.senior_iterator)
 
 
 class HuggingFacePackedDataset(Dataset):
     """
-    Simple packed dataset for huggingface.
+    Simple packed dataset for huggingface
     """
 
-    def __init__(self, dataset, seq_len, micro_bsz):
+    def __init__(self, dataset, seq_len, micro_bsz, pad_token_id=0):
         self.dataset = dataset
         self.seq_len = seq_len
         self.micro_bsz = micro_bsz
-
+        self.pad_token_id = pad_token_id
         self.senior_iterator = iter(self)
 
     def __iter__(self):
@@ -72,7 +74,7 @@ def __iter__(self):
         for sample in self.dataset:
             if len(input_ids + sample["input_ids"]) > self.micro_bsz * self.seq_len:
                 assert cu_seqlens[-1] <= self.micro_bsz * self.seq_len
-                input_ids = input_ids + [0] * (self.micro_bsz * self.seq_len - len(input_ids))
+                input_ids = input_ids + [self.pad_token_id] * (self.micro_bsz * self.seq_len - len(input_ids))
                 cu_seqlens = (
                     cu_seqlens + [self.micro_bsz * self.seq_len]
                     if cu_seqlens[-1] < self.micro_bsz * self.seq_len
@@ -89,14 +91,15 @@ def __iter__(self):
                 }
                 input_ids = sample["input_ids"]
                 cu_seqlens = [0, len(sample["input_ids"])]
-                labels = sample["input_ids"][1:] + [-100]
+                labels = [w if w > 0 else -100 for w in sample["input_ids"]][1:] + [-100]
             else:
                 input_ids = input_ids + sample["input_ids"]
                 cu_seqlens.append(len(sample["input_ids"]) + cu_seqlens[-1])
-                labels = labels + sample["input_ids"][1:] + [-100]
+                labels = labels + [w if w > 0 else -100 for w in sample["input_ids"]][1:] + [-100]
+
         if input_ids:
             assert cu_seqlens[-1] <= self.micro_bsz * self.seq_len
-            input_ids = input_ids + [0] * (self.micro_bsz * self.seq_len - len(input_ids))
+            input_ids = input_ids + [self.pad_token_id] * (self.micro_bsz * self.seq_len - len(input_ids))
             cu_seqlens = (
                 cu_seqlens + [self.micro_bsz * self.seq_len]
                 if cu_seqlens[-1] < self.micro_bsz * self.seq_len
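
Below is a brief usage sketch of the constructor change; the tokenizer name, the streaming_dataset variable, and the hyperparameter values are illustrative assumptions, not part of this commit. It shows how the new pad_token_id argument lets padding follow the tokenizer instead of the previous hard-coded 0:

from transformers import AutoTokenizer

# Hypothetical setup: a tokenizer plus any upstream dataset that yields
# {"input_ids": [...]} samples (e.g. the streaming dataset defined earlier in this file).
tokenizer = AutoTokenizer.from_pretrained("gpt2")
pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0

packed = HuggingFacePackedDataset(
    dataset=streaming_dataset,  # assumed iterable of tokenized samples
    seq_len=2048,
    micro_bsz=1,
    pad_token_id=pad_id,  # new argument; defaults to 0 for backward compatibility
)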