10 changes: 10 additions & 0 deletions tests/torchtune/datasets/test_stack_exchange_paired_dataset.py
@@ -100,6 +100,16 @@ def test_dataset_get_item(self, mock_load_dataset, train_on_input):
# Check that the input is masked
assert sample["rejected_labels"].count(CROSS_ENTROPY_IGNORE_IDX) == 52

def test_dataset_fails_with_packed(self):
with pytest.raises(
ValueError, match="StackExchangePairedDataset does not support packing"
):
stack_exchange_paired_dataset(
tokenizer=DummyTokenizer(),
train_on_input=True,
packed=True,
)


class TestStackExchangePairedToMessages:
@pytest.fixture
9 changes: 9 additions & 0 deletions torchtune/datasets/_stack_exchange_paired.py
@@ -78,6 +78,7 @@ def stack_exchange_paired_dataset(
source: str = "lvwerra/stack-exchange-paired",
column_map: Optional[Dict[str, str]] = None,
train_on_input: bool = False,
packed: bool = False,
filter_fn: Optional[Callable] = None,
split: str = "train",
**load_dataset_kwargs: Dict[str, Any],
@@ -101,17 +102,25 @@
Keys should be "prompt", "chosen", and "rejected" and values should be the actual column names.
Default is None, keeping the default column names.
train_on_input (bool): Whether the model is trained on the prompt or not. Default is False.
packed (bool): Whether or not to pack the dataset to ``max_seq_len`` prior to training. Default is False. Packing is
    currently not supported for this dataset; a ValueError is raised if this is set to True.
filter_fn (Optional[Callable]): callable used to filter the dataset prior to any pre-processing. See
the Hugging Face `docs <https://huggingface.co/docs/datasets/v2.20.0/process#select-and-filter>`_ for more
details.
split (str): ``split`` argument for ``datasets.load_dataset``. You can use this argument to load a subset
of a given split, e.g. ``split="train[:10%]"``. Default is "train".
**load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``.

Raises:
ValueError: If ``packed`` is True; packing is not yet supported for preference datasets.

Returns:
PreferenceDataset: The preference dataset built from source paired data.
"""

if packed:
raise ValueError("StackExchangePairedDataset does not support packing.")

column_map = column_map or {
"prompt": "question",
"chosen": "response_j",
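
As a quick caller-side illustration of the new guard (separate from the unit test above), here is a minimal sketch; it assumes stack_exchange_paired_dataset is importable from torchtune.datasets, and since the ValueError is raised at the top of the builder before the tokenizer or any data is touched, a placeholder tokenizer is enough for this check:

from torchtune.datasets import stack_exchange_paired_dataset


def check_packed_guard() -> None:
    # packed=True should fail fast, before any data is downloaded.
    # tokenizer=None is a placeholder: the guard fires before the
    # tokenizer is ever used, so no real tokenizer is needed here.
    try:
        stack_exchange_paired_dataset(tokenizer=None, packed=True)
    except ValueError as err:
        print(f"Guard triggered as expected: {err}")
    else:
        raise AssertionError("Expected ValueError when packed=True")


check_packed_guard()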