torchtune/datasets/_text_completion.py: 44 changes (20 additions, 24 deletions)
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Any, Dict, List, Mapping, Optional, Union
+from typing import Any, Dict, List, Mapping, Union
 
 from datasets import load_dataset
 from torch.utils.data import Dataset
@@ -29,9 +29,6 @@ class TextCompletionDataset(Dataset):
             for Hugging Face datasets or tabular data. For local datasets with a single column
             (e.g. unstructured txt files), use the default "text" which is used by Hugging Face datasets
             when loaded into memory. Default is "text".
-        max_seq_len (Optional[int]): Maximum number of tokens in the returned input and label token id lists.
-            Default is None, disabling truncation. We recommend setting this to the highest you can fit in memory
-            and is supported by the model. For example, llama2-7B supports up to 4096 for sequence length.
         add_eos (bool): Whether to add an EOS token to the end of the sequence. Default is True.
         **load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``,
             such as ``data_files`` or ``split``.
@@ -42,13 +39,11 @@ def __init__(
         tokenizer: ModelTokenizer,
         source: str,
         column: str = "text",
-        max_seq_len: Optional[int] = None,
         add_eos: bool = True,
         **load_dataset_kwargs: Dict[str, Any],
     ) -> None:
         self._tokenizer = tokenizer
         self._data = load_dataset(source, **load_dataset_kwargs)
-        self.max_seq_len = max_seq_len
         self._column = column
         self.add_eos = add_eos
 
@@ -64,8 +59,8 @@ def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, List[int]]:
         tokens = self._tokenizer.encode(text=prompt, add_bos=True, add_eos=self.add_eos)
 
         # Truncate if needed, but don't coerce EOS id
-        if self.max_seq_len is not None:
-            tokens = truncate(tokens, self.max_seq_len - 1)
+        if self._tokenizer.max_seq_len is not None:
+            tokens = truncate(tokens, self._tokenizer.max_seq_len - 1)
 
         # No need to offset labels by 1 - happens in the recipe
         labels = tokens.copy()
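
For context on the hunk above: after this change, truncation is driven by the tokenizer's max_seq_len rather than a dataset-level setting. Below is a minimal, self-contained sketch (not torchtune code) of that flow; the stub tokenizer and token ids are hypothetical, and the list slice only mirrors the behavior of truncate here when no EOS id is passed.

from typing import List, Optional

class StubTokenizer:
    # Hypothetical stand-in for a ModelTokenizer: max_seq_len now lives here,
    # not on TextCompletionDataset.
    max_seq_len: Optional[int] = 8

    def encode(self, text: str, add_bos: bool = True, add_eos: bool = True) -> List[int]:
        # Fake token ids, purely for illustration.
        return list(range(1, 12))

tokenizer = StubTokenizer()
tokens = tokenizer.encode("some raw text")

# Mirrors the updated _prepare_sample: clip against tokenizer.max_seq_len,
# keeping at most max_seq_len - 1 ids and leaving the last kept id untouched.
if tokenizer.max_seq_len is not None:
    tokens = tokens[: tokenizer.max_seq_len - 1]

labels = tokens.copy()  # the label shift by one happens later, in the recipe
print(tokens)  # [1, 2, 3, 4, 5, 6, 7]
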
@@ -77,10 +72,10 @@ def text_completion_dataset(
     tokenizer: ModelTokenizer,
     source: str,
     column: str = "text",
-    max_seq_len: Optional[int] = None,
     add_eos: bool = True,
     packed: bool = False,
     split_across_pack: bool = True,
+    split: str = "train",
     **load_dataset_kwargs: Dict[str, Any],
 ) -> Union[TextCompletionDataset, PackedDataset]:
"""
@@ -100,16 +95,15 @@ def text_completion_dataset(
             for Hugging Face datasets or tabular data. For local datasets with a single column
             (e.g. unstructured txt files), use the default "text" which is used by Hugging Face datasets
             when loaded into memory. Default is "text".
-        max_seq_len (Optional[int]): Maximum number of tokens in the returned input and label token id lists.
-            Default is None, disabling truncation. We recommend setting this to the highest you can fit in memory
-            and is supported by the model. For example, llama2-7B supports up to 4096 for sequence length.
         add_eos (bool): Whether to add an EOS token to the end of the sequence. Default is True.
         packed (bool): Whether or not to pack the dataset to ``max_seq_len`` prior to training. Default is False.
         split_across_pack (bool): if the last sample in a pack does not fit in ``max_seq_len``,
             split the sample into the next pack, or move it entirely to the beginning of the next pack.
             For pre-training, typically this is set to True for general text completion. For
             fine-tuning, typically this is set to False to avoid truncating sentences in instruct
             tuning. This argument is ignored if ``packed=False``. Default is True.
+        split (str): ``split`` argument for ``datasets.load_dataset``. You can use this argument to load a subset
+            of a given split, e.g. ``split="train[:10%]"``. Default is "train".
         **load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``.
 
     Examples:
@@ -118,9 +112,9 @@ def text_completion_dataset(
         ...     tokenizer=tokenizer,
         ...     source="allenai/c4",
         ...     column="text",
-        ...     max_seq_len=2096,
         ...     data_dir="realnewslike",
         ...     packed=False,
+        ...     split="train",
         ... )
 
     This can also be accomplished via the yaml config::
@@ -129,29 +123,31 @@ def text_completion_dataset(
         _component_: torchtune.datasets.text_completion_dataset
         source: allenai/c4
         column: text
-        max_seq_len: 2096
         data_dir: realnewslike
         packed: False
+        split: train
 
     Returns:
         Union[TextCompletionDataset, PackedDataset]: the configured :class:`~torchtune.datasets.TextCompletionDataset`
             or :class:`~torchtune.datasets.PackedDataset` if ``packed=True``
+
+    Raises:
+        ValueError: If ``packed=True`` and ``tokenizer.max_seq_len`` is not set.
     """
     ds = TextCompletionDataset(
         tokenizer=tokenizer,
         source=source,
         column=column,
-        max_seq_len=max_seq_len,
         add_eos=add_eos,
+        split=split,
         **load_dataset_kwargs,
     )
-    return (
-        PackedDataset(
-            ds,
-            max_seq_len=max_seq_len,
-            padding_idx=tokenizer.pad_id,
-            split_across_pack=split_across_pack,
-        )
-        if packed
-        else ds
-    )
+    if packed:
+        if tokenizer.max_seq_len is None:
+            raise ValueError(
+                "PackedDataset requires a max_seq_len to be set on the tokenizer."
+            )
+        return PackedDataset(
+            ds, max_seq_len=tokenizer.max_seq_len, split_across_pack=split_across_pack
+        )
+    return ds
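
To illustrate the calling convention after this change, here is a hedged usage sketch. max_seq_len is no longer an argument of text_completion_dataset; it must be set on the tokenizer, which also gates packing. The llama2_tokenizer builder accepting max_seq_len and the tokenizer path below are assumptions for illustration, not something this diff establishes.

from torchtune.datasets import text_completion_dataset
from torchtune.models.llama2 import llama2_tokenizer

# Assumption: this tokenizer builder accepts max_seq_len and stores it on the
# returned ModelTokenizer; the path is a placeholder.
tokenizer = llama2_tokenizer("/tmp/tokenizer.model", max_seq_len=4096)

# Truncation and packing now both key off tokenizer.max_seq_len, and the new
# split argument is forwarded to datasets.load_dataset.
ds = text_completion_dataset(
    tokenizer=tokenizer,
    source="allenai/c4",
    column="text",
    data_dir="realnewslike",
    packed=True,
    split_across_pack=True,
    split="train[:1%]",
)

# With packed=True and tokenizer.max_seq_len left as None, the call above would
# instead raise:
# ValueError: PackedDataset requires a max_seq_len to be set on the tokenizer.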