3 changes: 3 additions & 0 deletions docs/source/api_ref_data.rst
@@ -73,6 +73,9 @@ Collaters used to collect samples into batches and handle any padding.
:nosignatures:

padded_collate
padded_collate_sft
padded_collate_dpo
left_pad_sequence

Helper functions
----------------
2 changes: 0 additions & 2 deletions docs/source/api_ref_modules.rst
@@ -116,8 +116,6 @@ Components for RLHF algorithms like PPO.
rlhf.estimate_advantages
rlhf.get_rewards_ppo
rlhf.truncate_sequence_at_first_stop_token
rlhf.left_padded_collate
rlhf.padded_collate_dpo

Losses
^^^^^^
4 changes: 2 additions & 2 deletions recipes/dev/lora_finetune_fsdp2.py
@@ -25,7 +25,7 @@
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from torchtune import config, modules, training, utils
from torchtune.data import padded_collate
from torchtune.data import padded_collate_sft
from torchtune.datasets import ConcatDataset
from torchtune.modules.peft import (
get_adapter_params,
@@ -468,7 +468,7 @@ def _setup_data(
sampler=sampler,
collate_fn=(
partial(
padded_collate,
padded_collate_sft,
padding_idx=self._tokenizer.pad_id,
ignore_idx=self._loss_fn.ignore_index,
)
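For context on the rename above, here is a minimal sketch of how the SFT collater gets wired into a DataLoader, mirroring this recipe change. `my_sft_dataset`, `pad_id`, and `ignore_index` are hypothetical placeholders; the keyword arguments are taken from the diff, not from separate documentation.

```python
# Minimal sketch, not the recipe itself: `my_sft_dataset`, `pad_id`, and
# `ignore_index` are hypothetical placeholders standing in for the recipe's
# dataset, tokenizer.pad_id, and loss_fn.ignore_index.
from functools import partial

from torch.utils.data import DataLoader
from torchtune.data import padded_collate_sft  # renamed from padded_collate

dataloader = DataLoader(
    my_sft_dataset,
    batch_size=2,
    collate_fn=partial(
        padded_collate_sft,
        padding_idx=pad_id,        # tokenizer's pad token id
        ignore_idx=ignore_index,   # loss function's ignore index
    ),
)
```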
9 changes: 3 additions & 6 deletions recipes/eleuther_eval.py
@@ -13,9 +13,8 @@
from omegaconf import DictConfig

from torch import nn
from torch.nn.utils.rnn import pad_sequence

from torchtune import config, training, utils
from torchtune.data import left_pad_sequence
from torchtune.modules import TransformerDecoder
from torchtune.modules.tokenizers import ModelTokenizer
from torchtune.recipe_interfaces import EvalRecipeInterface
@@ -112,15 +111,13 @@ def tok_batch_encode(
tokenized_text = [self.tok_encode(x) for x in text]

# pad left
x = pad_sequence(
x = left_pad_sequence(
[
torch.tensor(x[::-1]) for x in tokenized_text
], # first flip each sequence and pad
batch_first=True,
padding_value=self._tokenizer.pad_id,
).flip(
dims=[1]
) # flip back to correct order
)

return x, torch.ones_like(x) # return 'mask' b/c it's expected by the harness

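The eleuther_eval change above swaps the flip / right-pad / flip-back trick for the new `left_pad_sequence` helper. Below is a small sketch of the intended equivalence, assuming `left_pad_sequence` left-pads a list of variable-length tensors with the same `batch_first` / `padding_value` arguments it is called with in the diff; the toy sequences are made up.

```python
# Toy comparison of the old and new left-padding approaches. Equivalence is an
# assumption based on the code this PR replaces, not a documented guarantee.
import torch
from torch.nn.utils.rnn import pad_sequence
from torchtune.data import left_pad_sequence

seqs = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]  # made-up token ids

# Old: reverse each sequence, right-pad, then flip back to get left padding.
old = pad_sequence(
    [s.flip(dims=[0]) for s in seqs], batch_first=True, padding_value=0
).flip(dims=[1])

# New: pad on the left directly.
new = left_pad_sequence(seqs, batch_first=True, padding_value=0)

print(old)  # tensor([[1, 2, 3], [0, 4, 5]])
print(new)  # expected to match `old`
```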
18 changes: 10 additions & 8 deletions recipes/full_finetune_distributed.py
@@ -21,7 +21,7 @@
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from torchtune import config, modules, training, utils
from torchtune.data import padded_collate
from torchtune.data import padded_collate_sft
from torchtune.datasets import ConcatDataset
from torchtune.recipe_interfaces import FTRecipeInterface
from torchtune.utils import DummyProfiler, PROFILER_KEY
@@ -495,13 +495,15 @@ def _setup_data(
dataset=ds,
batch_size=batch_size,
sampler=sampler,
collate_fn=partial(
padded_collate,
padding_idx=self._tokenizer.pad_id,
ignore_idx=self._loss_fn.ignore_index,
)
if not packed
else None,
collate_fn=(
partial(
padded_collate_sft,
padding_idx=self._tokenizer.pad_id,
ignore_idx=self._loss_fn.ignore_index,
)
if not packed
else None
),
)

if self._is_rank_zero:
18 changes: 10 additions & 8 deletions recipes/full_finetune_single_device.py
@@ -19,7 +19,7 @@
from torch.utils.data import DataLoader, DistributedSampler

from torchtune import config, modules, training, utils
from torchtune.data import padded_collate
from torchtune.data import padded_collate_sft
from torchtune.datasets import ConcatDataset
from torchtune.recipe_interfaces import FTRecipeInterface
from torchtune.utils import DummyProfiler, PROFILER_KEY
@@ -457,13 +457,15 @@ def _setup_data(
dataset=ds,
batch_size=batch_size,
sampler=sampler,
collate_fn=partial(
padded_collate,
padding_idx=self._tokenizer.pad_id,
ignore_idx=self._loss_fn.ignore_index,
)
if not packed
else None,
collate_fn=(
partial(
padded_collate_sft,
padding_idx=self._tokenizer.pad_id,
ignore_idx=self._loss_fn.ignore_index,
)
if not packed
else None
),
)

log.info("Dataset and Sampler are initialized.")
4 changes: 2 additions & 2 deletions recipes/lora_dpo_distributed.py
@@ -25,7 +25,7 @@
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from torchtune import config, modules, training, utils
from torchtune.data import CROSS_ENTROPY_IGNORE_IDX
from torchtune.data import CROSS_ENTROPY_IGNORE_IDX, padded_collate_dpo
from torchtune.datasets import ConcatDataset
from torchtune.modules import rlhf
from torchtune.modules.peft import (
@@ -457,7 +457,7 @@ def _setup_data(
batch_size=batch_size,
sampler=sampler,
collate_fn=partial(
rlhf.padded_collate_dpo,
padded_collate_dpo,
padding_idx=self._tokenizer.pad_id,
ignore_idx=CROSS_ENTROPY_IGNORE_IDX,
),
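The DPO recipes now import the collater from torchtune.data instead of reaching into torchtune.modules.rlhf. A minimal sketch of the resulting wiring follows; `dpo_dataset` and `pad_id` are hypothetical placeholders, while the keyword arguments and the CROSS_ENTROPY_IGNORE_IDX constant come straight from the diff above.

```python
# Sketch only: `dpo_dataset` and `pad_id` are hypothetical placeholders.
from functools import partial

from torch.utils.data import DataLoader
from torchtune.data import CROSS_ENTROPY_IGNORE_IDX, padded_collate_dpo

dataloader = DataLoader(
    dpo_dataset,
    batch_size=4,
    collate_fn=partial(
        padded_collate_dpo,
        padding_idx=pad_id,
        ignore_idx=CROSS_ENTROPY_IGNORE_IDX,
    ),
)
```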
4 changes: 2 additions & 2 deletions recipes/lora_dpo_single_device.py
@@ -19,7 +19,7 @@
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from torchtune import config, modules, training, utils
from torchtune.data import CROSS_ENTROPY_IGNORE_IDX
from torchtune.data import CROSS_ENTROPY_IGNORE_IDX, padded_collate_dpo
from torchtune.datasets import ConcatDataset
from torchtune.modules import rlhf
from torchtune.modules.peft import (
@@ -375,7 +375,7 @@ def _setup_data(
sampler=sampler,
batch_size=batch_size,
collate_fn=partial(
rlhf.padded_collate_dpo,
padded_collate_dpo,
padding_idx=self._tokenizer.pad_id,
ignore_idx=CROSS_ENTROPY_IGNORE_IDX,
),
4 changes: 2 additions & 2 deletions recipes/lora_finetune_distributed.py
@@ -26,7 +26,7 @@
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from torchtune import config, modules, training, utils
from torchtune.data import padded_collate
from torchtune.data import padded_collate_sft
from torchtune.datasets import ConcatDataset
from torchtune.modules.peft import (
get_adapter_params,
@@ -553,7 +553,7 @@ def _setup_data(
sampler=sampler,
collate_fn=(
partial(
padded_collate,
padded_collate_sft,
padding_idx=self._tokenizer.pad_id,
ignore_idx=self._loss_fn.ignore_index,
)
4 changes: 2 additions & 2 deletions recipes/lora_finetune_single_device.py
@@ -19,7 +19,7 @@
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from torchtune import config, modules, training, utils
from torchtune.data import padded_collate
from torchtune.data import padded_collate_sft
from torchtune.datasets import ConcatDataset
from torchtune.modules.peft import (
get_adapter_params,
@@ -486,7 +486,7 @@ def _setup_data(
batch_size=batch_size,
collate_fn=(
partial(
padded_collate,
padded_collate_sft,
padding_idx=self._tokenizer.pad_id,
ignore_idx=self._loss_fn.ignore_index,
)
7 changes: 5 additions & 2 deletions recipes/ppo_full_finetune_single_device.py
@@ -18,6 +18,7 @@
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from torchtune import config, modules, training, utils
from torchtune.data import left_pad_sequence, padded_collate
from torchtune.datasets import ConcatDataset
from torchtune.modules import rlhf
from torchtune.modules.rlhf import PPOStats, Trajectory
@@ -581,7 +582,9 @@ def _setup_data(
sampler=sampler,
batch_size=batch_size,
collate_fn=partial(
rlhf.left_padded_collate,
padded_collate,
pad_fn=left_pad_sequence,
keys_to_pad=["tokens"],
padding_idx=self._tokenizer.pad_id,
),
drop_last=True,
@@ -829,7 +832,7 @@ def train(self) -> None:
self._sampler.set_epoch(curr_epoch)

for _, batch in enumerate(self._dataloader):
batch = batch.to(self._device)
batch = batch["tokens"].to(self._device)
_, context_length = batch.shape

# step 1. generate the trajectory using:
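The PPO recipe drops the removed rlhf.left_padded_collate in favor of the generic padded_collate configured with left_pad_sequence, and the collated batch is now indexed as a dict (batch["tokens"]) in the training loop. A rough sketch under those assumptions follows; `prompt_dataset` and `pad_id` are hypothetical placeholders, and the `pad_fn` / `keys_to_pad` semantics are inferred from this diff rather than documented here.

```python
# Sketch only, mirroring the PPO recipe change above. Assumption: padded_collate
# pads every key listed in `keys_to_pad` using `pad_fn` and returns a dict,
# which is why the training loop now reads batch["tokens"].
from functools import partial

from torch.utils.data import DataLoader
from torchtune.data import left_pad_sequence, padded_collate

dataloader = DataLoader(
    prompt_dataset,                # hypothetical prompt dataset
    batch_size=8,
    collate_fn=partial(
        padded_collate,
        pad_fn=left_pad_sequence,  # left-pad prompts ahead of generation
        keys_to_pad=["tokens"],
        padding_idx=pad_id,        # hypothetical pad token id
    ),
    drop_last=True,
)

for batch in dataloader:
    tokens = batch["tokens"]  # dict access replaces the old bare tensor batch
    break
```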
18 changes: 10 additions & 8 deletions recipes/qat_distributed.py
@@ -21,7 +21,7 @@
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from torchtune import config, modules, training, utils
from torchtune.data import padded_collate
from torchtune.data import padded_collate_sft
from torchtune.datasets import ConcatDataset
from torchtune.recipe_interfaces import FTRecipeInterface
from torchtune.utils import DummyProfiler, PROFILER_KEY
@@ -523,13 +523,15 @@ def _setup_data(
dataset=ds,
batch_size=batch_size,
sampler=sampler,
collate_fn=partial(
padded_collate,
padding_idx=self._tokenizer.pad_id,
ignore_idx=self._loss_fn.ignore_index,
)
if not packed
else None,
collate_fn=(
partial(
padded_collate_sft,
padding_idx=self._tokenizer.pad_id,
ignore_idx=self._loss_fn.ignore_index,
)
if not packed
else None
),
)

if self._is_rank_zero: