Skip to content

Commit 9df7bbb

Browse files
committed
Generalize HF datasets to a collection of HF datasets via datasets, add support for custom chat HF datasets (ml-explore#1088), and fix (ml-explore#1087)
1 parent 331148d commit 9df7bbb

File tree

1 file changed

+117
-20
lines changed

1 file changed

+117
-20
lines changed

llms/mlx_lm/tuner/datasets.py

Lines changed: 117 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import json
22
from pathlib import Path
3-
from typing import Dict, List
3+
from typing import Dict, List, Union
44

55
from transformers import PreTrainedTokenizer
66

@@ -29,12 +29,18 @@ class ChatDataset(Dataset):
2929
https://platform.openai.com/docs/guides/fine-tuning/example-format
3030
"""
3131

32-
def __init__(self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer):
32+
def __init__(
33+
self,
34+
data: List[Dict[str, str]],
35+
tokenizer: PreTrainedTokenizer,
36+
chat_key: str = "messages",
37+
):
3338
super().__init__(data)
3439
self._tokenizer = tokenizer
40+
self._chat_key = chat_key
3541

3642
def __getitem__(self, idx: int):
37-
messages = self._data[idx]["messages"]
43+
messages = self._data[idx][self._chat_key]
3844
text = self._tokenizer.apply_chat_template(
3945
messages,
4046
tools=self._data[idx].get("tools", None),
@@ -76,6 +82,29 @@ def __getitem__(self, idx: int):
7682
return text
7783

7884

85+
class CompletionsDatasetCollection:
    """Present several datasets as one concatenated, indexable sequence.

    A global index ``idx`` is resolved against the sub-datasets in order:
    it lands in the first dataset whose cumulative length exceeds it.

    Fixes two defects in the original implementation: ``next()`` was called
    on a plain list (lists are not iterators, so every ``__getitem__`` call
    raised ``TypeError``), and the ``(curr_idx + 1) < len(item)`` bound was
    off by one, skipping the last element of every sub-dataset.
    """

    def __init__(self, data: "List[Union[ChatDataset, CompletionsDataset]]"):
        # The sub-datasets are concatenated logically, never copied.
        self.collection = data

    def __getitem__(self, idx: int):
        # Discount each sub-dataset's length until the global index lands
        # inside the current one.
        curr_idx = idx
        for dataset in self.collection:
            size = len(dataset)
            if curr_idx < size:
                return dataset[curr_idx]
            curr_idx -= size
        # Past the end of every sub-dataset; IndexError also lets plain
        # `for` iteration over the collection terminate correctly.
        raise IndexError(idx)

    def __len__(self):
        return sum(len(dataset) for dataset in self.collection)
106+
107+
79108
def create_dataset(data, tokenizer: PreTrainedTokenizer = None):
80109
sample = data[0]
81110

@@ -127,40 +156,108 @@ def load_hf_dataset(data_id: str, tokenizer: PreTrainedTokenizer):
127156
def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
    """Load (train, valid, test) splits from custom Hugging Face datasets.

    Supports either a single dataset described by ``args.hf_dataset`` or a
    collection described by ``args.hf_datasets`` (a list of dicts, each with
    an ``"hf_dataset"`` entry); in the collection case each split is wrapped
    in a ``CompletionsDatasetCollection``.  Splits are only loaded when
    ``args.train`` / ``args.test`` request them; otherwise ``[]`` is
    returned for the corresponding entries.
    """
    import datasets

    def create_hf_dataset(hf_args, split=None):
        # Load one split of one HF dataset and wrap it according to which
        # feature keys the config declares.  `hf_args` is passed explicitly
        # instead of being captured as a closure so the loaded config can
        # never silently come from a different dataset's arguments.
        ds = datasets.load_dataset(
            hf_args["name"],
            split=split,
            **hf_args.get("config", {}),
        )
        prompt_feature = hf_args.get("prompt_feature")
        completion_feature = hf_args.get("completion_feature")
        chat_feature = hf_args.get("chat_feature")
        text_feature = hf_args.get("text_feature")
        if prompt_feature and completion_feature:
            return CompletionsDataset(ds, tokenizer, prompt_feature, completion_feature)
        elif chat_feature:
            return ChatDataset(ds, tokenizer, chat_key=chat_feature)
        elif text_feature:
            return Dataset(ds, text_key=text_feature)
        else:
            # Message updated to mention the chat option this code supports.
            raise ValueError(
                "Specify either a prompt and completion feature, a chat "
                "feature, or a text feature for the Hugging Face dataset."
            )

    def load_splits(hf_args):
        # Return (train, valid, test) for a single hf_dataset config dict.
        print(f"Loading Hugging Face dataset {hf_args['name']}.")
        if args.train:
            train = create_hf_dataset(
                hf_args, split=hf_args.get("train_split", "train[:80%]")
            )
            valid = create_hf_dataset(
                hf_args, split=hf_args.get("valid_split", "train[-10%:]")
            )
        else:
            train, valid = [], []
        if args.test:
            test = create_hf_dataset(hf_args, split=hf_args.get("test_split"))
        else:
            test = []
        return train, valid, test

    # NOTE(review): the original guarded on `args.datasets` but then read
    # `args.hf_datasets`.  Use the attribute that is actually consumed, and
    # tolerate its absence so single-dataset configs keep working.
    dataset_collection = getattr(args, "hf_datasets", None)
    if dataset_collection:
        train_collection = []
        valid_collection = []
        test_collection = []
        for ds in dataset_collection:
            train, valid, test = load_splits(ds["hf_dataset"])
            train_collection.append(train)
            valid_collection.append(valid)
            test_collection.append(test)
        return (
            CompletionsDatasetCollection(train_collection),
            CompletionsDatasetCollection(valid_collection),
            CompletionsDatasetCollection(test_collection),
        )

    return load_splits(args.hf_dataset)
166263

0 commit comments

Comments
 (0)