@@ -1,6 +1,6 @@
 import json
 from pathlib import Path
-from typing import Dict, List
+from typing import Dict, List, Union

 from transformers import PreTrainedTokenizer

@@ -29,12 +29,18 @@ class ChatDataset(Dataset):
     https://platform.openai.com/docs/guides/fine-tuning/example-format
     """

-    def __init__(self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer):
+    def __init__(
+        self,
+        data: List[Dict[str, str]],
+        tokenizer: PreTrainedTokenizer,
+        chat_key: str = "messages",
+    ):
         super().__init__(data)
         self._tokenizer = tokenizer
+        self._chat_key = chat_key

     def __getitem__(self, idx: int):
-        messages = self._data[idx]["messages"]
+        messages = self._data[idx][self._chat_key]
         text = self._tokenizer.apply_chat_template(
             messages,
             tools=self._data[idx].get("tools", None),
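A quick sketch of what the new `chat_key` hook enables. The dataset rows, field name, and model are placeholders, and the import path for `ChatDataset` is an assumption about where this file lives:

```python
from transformers import AutoTokenizer

from mlx_lm.tuner.datasets import ChatDataset  # module path assumed

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

# Hypothetical rows that store the conversation under "conversations"
# instead of the default "messages" key; no renaming is needed now.
data = [
    {
        "conversations": [
            {"role": "user", "content": "Hi!"},
            {"role": "assistant", "content": "Hello! How can I help?"},
        ]
    }
]

dataset = ChatDataset(data, tokenizer, chat_key="conversations")
print(dataset[0])  # the conversation rendered through the tokenizer's chat template
```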
@@ -76,6 +82,23 @@ def __getitem__(self, idx: int):
         return text


+class CompletionsDatasetCollection:
+    def __init__(self, data: List[Union[ChatDataset, CompletionsDataset]]):
+        self.collection = data
+
+    def __getitem__(self, idx: int):
+        # Walk the member datasets in order, offsetting the index into each one.
+        curr_idx = idx
+        for dataset in self.collection:
+            if curr_idx < len(dataset):
+                return dataset[curr_idx]
+            curr_idx -= len(dataset)
+        raise IndexError(idx)
+
+    def __len__(self):
+        return sum(map(len, self.collection))
+
+
 def create_dataset(data, tokenizer: PreTrainedTokenizer = None):
     sample = data[0]

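The indexing semantics of the collection, sketched with plain lists; any sequence that supports `len()` and integer indexing behaves the same way, so the tokenizer-backed datasets are not needed to see the behavior:

```python
collection = CompletionsDatasetCollection([["a", "b"], ["c"]])

assert len(collection) == 3
assert collection[0] == "a"
assert collection[2] == "c"  # index 2 falls through into the second dataset

try:
    collection[3]
except IndexError:
    pass  # out-of-range indices raise IndexError, as sequences should
```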
@@ -127,40 +150,108 @@ def load_hf_dataset(data_id: str, tokenizer: PreTrainedTokenizer):
 def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
     import datasets

-    hf_args = args.hf_dataset
-    dataset_name = hf_args["name"]
-    print(f"Loading Hugging Face dataset {dataset_name}.")
-    text_feature = hf_args.get("text_feature")
-    prompt_feature = hf_args.get("prompt_feature")
-    completion_feature = hf_args.get("completion_feature")
-
-    def create_hf_dataset(split: str = None):
+    def create_hf_dataset(
+        dataset_name: Union[None, str],
+        text_feature: Union[None, str],
+        prompt_feature: Union[None, str],
+        completion_feature: Union[None, str],
+        chat_feature: Union[None, str],
+        split: str = None,
+    ):
         ds = datasets.load_dataset(
             dataset_name,
             split=split,
             **hf_args.get("config", {}),
         )
         if prompt_feature and completion_feature:
             return CompletionsDataset(ds, tokenizer, prompt_feature, completion_feature)
+        elif chat_feature:
+            return ChatDataset(ds, tokenizer, chat_key=chat_feature)
         elif text_feature:
-            return Dataset(train_ds, text_key=text_feature)
+            return Dataset(ds, text_key=text_feature)
         else:
             raise ValueError(
                 "Specify either a prompt and completion feature or a text "
                 "feature for the Hugging Face dataset."
             )

-    if args.train:
+    def get_hf_custom_features(hf_args):
+        return (
+            hf_args.get("text_feature"),
+            hf_args.get("prompt_feature"),
+            hf_args.get("completion_feature"),
+            hf_args.get("chat_feature"),
+        )
+
+    def get_train_and_valid_splits(hf_args, ds_name):
         train_split = hf_args.get("train_split", "train[:80%]")
         valid_split = hf_args.get("valid_split", "train[-10%:]")
-        train = create_hf_dataset(split=train_split)
-        valid = create_hf_dataset(split=valid_split)
-    else:
-        train, valid = [], []
-    if args.test:
-        test = create_hf_dataset(split=hf_args.get("test_split"))
+        text_f, prompt_f, completion_f, chat_f = get_hf_custom_features(hf_args)
+        train = create_hf_dataset(
+            ds_name, text_f, prompt_f, completion_f, chat_f, split=train_split
+        )
+        valid = create_hf_dataset(
+            ds_name, text_f, prompt_f, completion_f, chat_f, split=valid_split
+        )
+        return train, valid
+
+    if args.datasets:
+        dataset_collection = args.hf_datasets
+        train_collection = []
+        valid_collection = []
+        test_collection = []
+        for ds in dataset_collection:
+            hf_args = ds["hf_dataset"]
+            dataset_name = hf_args["name"]
+            print(f"Loading Hugging Face dataset {dataset_name}.")
+            text_feature, prompt_feature, completion_feature, chat_feature = (
+                get_hf_custom_features(hf_args)
+            )
+            if args.train:
+                train, valid = get_train_and_valid_splits(hf_args, dataset_name)
+            else:
+                train, valid = [], []
+            if args.test:
+                test = create_hf_dataset(
+                    dataset_name,
+                    text_feature,
+                    prompt_feature,
+                    completion_feature,
+                    chat_feature,
+                    split=hf_args.get("test_split"),
+                )
+            else:
+                test = []
+            train_collection.append(train)
+            valid_collection.append(valid)
+            test_collection.append(test)
+        return (
+            CompletionsDatasetCollection(train_collection),
+            CompletionsDatasetCollection(valid_collection),
+            CompletionsDatasetCollection(test_collection),
+        )
     else:
-        test = []
+        hf_args = args.hf_dataset
+        dataset_name = hf_args["name"]
+        print(f"Loading Hugging Face dataset {dataset_name}.")
+        text_feature, prompt_feature, completion_feature, chat_feature = (
+            get_hf_custom_features(hf_args)
+        )
+        if args.train:
+            train, valid = get_train_and_valid_splits(hf_args, dataset_name)
+        else:
+            train, valid = [], []
+        if args.test:
+            test = create_hf_dataset(
+                dataset_name,
+                text_feature,
+                prompt_feature,
+                completion_feature,
+                chat_feature,
+                split=hf_args.get("test_split"),
+            )
+        else:
+            test = []

     return train, valid, test
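A hedged sketch of an argument namespace that exercises the new multi-dataset branch. The dataset names and feature names are placeholders, not taken from this PR; note that the branch is gated on `args.datasets` but iterates `args.hf_datasets`, so both attributes are set here:

```python
from types import SimpleNamespace

# Placeholder namespace mirroring only the attributes this loader reads.
args = SimpleNamespace(
    train=True,
    test=False,
    datasets=True,  # truthy value selects the multi-dataset branch
    hf_datasets=[
        {"hf_dataset": {"name": "org/chat-data", "chat_feature": "conversations"}},
        {
            "hf_dataset": {
                "name": "org/qa-pairs",
                "prompt_feature": "question",
                "completion_feature": "answer",
            }
        },
    ],
)

# tokenizer as in the ChatDataset sketch above.
train, valid, test = load_custom_hf_dataset(args, tokenizer)
print(len(train))  # total examples across both per-dataset train splits
```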