10 changes: 5 additions & 5 deletions .pre-commit-config.yaml
@@ -5,7 +5,7 @@ default_language_version:

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: 6306a48f7dae5861702d573c9c247e4e9498e867
rev: v5.0.0
hooks:
- id: trailing-whitespace
- id: check-ast
@@ -18,7 +18,7 @@ repos:
exclude: '^(.*\.svg)$'

- repo: https://github.com/Lucas-C/pre-commit-hooks
rev: v1.5.4
rev: v1.5.5
hooks:
- id: insert-license
files: \.py$|\.sh$
@@ -27,7 +27,7 @@ repos:
- docs/license_header.txt

- repo: https://github.com/pycqa/flake8
rev: 34cbf8ef3950f43d09b85e2e45c15ae5717dc37b
rev: 7.1.1
hooks:
- id: flake8
additional_dependencies:
@@ -37,15 +37,15 @@ repos:
args: ['--config=.flake8']

- repo: https://github.com/omnilib/ufmt
rev: v2.3.0
rev: v2.8.0
hooks:
- id: ufmt
additional_dependencies:
- black == 22.12.0
- usort == 1.0.5

- repo: https://github.com/jsh9/pydoclint
rev: 94efc5f989adbea30f3534b476b2931a02c1af90
rev: 0.5.12
hooks:
- id: pydoclint
args: [--config=pyproject.toml]
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -87,7 +87,7 @@ target-version = ["py38"]
[tool.pydoclint]
style = 'google'
check-return-types = 'False'
exclude = 'tests/torchtune/models/(\w+)/scripts/'
exclude = 'tests/torchtune/models/(\w+)/scripts/|recipes/|torchtune/modules/_export'

[tool.pytest.ini_options]
addopts = ["--showlocals", "--import-mode=prepend", "--without-integration", "--without-slow-integration"]
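As a quick illustration of what the widened ``exclude`` pattern above skips, here is a small snippet (not part of the PR) that applies the regex as a plain regular-expression search over file paths, which is how pydoclint is assumed to use it; the file names are hypothetical.

import re

exclude = r"tests/torchtune/models/(\w+)/scripts/|recipes/|torchtune/modules/_export"
paths = [
    "recipes/lora_finetune.py",                      # newly excluded via recipes/
    "torchtune/modules/_export/attention.py",        # newly excluded via _export
    "tests/torchtune/models/llama2/scripts/gen.py",  # already excluded before this PR
    "torchtune/data/_collate.py",                    # still linted
]
for path in paths:
    print(path, "->", "skipped" if re.search(exclude, path) else "linted")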
22 changes: 12 additions & 10 deletions torchtune/data/_collate.py
@@ -81,13 +81,14 @@ def padded_collate(
padding values.

Returns:
torch.Tensor: The padded tensor of input ids with shape [batch_size, max_seq_len].
torch.Tensor: The padded tensor of input ids with shape ``[batch_size, max_seq_len]``.

Raises:
ValueError: if ``pad_direction`` is not one of "left" or "right".
ValueError: if ``keys_to_pad`` is empty, or is not a list, or is not a subset of keys in the batch.
ValueError: if ``padding_idx`` is provided as a dictionary, but the keys are not identical to
``keys_to_pad``.
ValueError:
Contributor Author commented:

The docstring linter lints based on unique raises, so we document a single value error and each condition it will raise for. Annoyingly, I couldn't for the life of me get lists to properly render in the raises section.

This is what it looks like instead:

[screenshot: the rendered Raises section in the built docs]
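As a minimal, hypothetical illustration of the pattern described in this comment and applied throughout the PR, here is a Google-style docstring that documents a single ``ValueError`` whose conditions are joined with **or**:

import math


def divide(a: float, b: float) -> float:
    """Divide ``a`` by ``b``.

    Args:
        a (float): Numerator.
        b (float): Denominator.

    Returns:
        float: The quotient.

    Raises:
        ValueError:
            If ``b`` is zero, **or**
            if either argument is not finite.
    """
    if b == 0:
        raise ValueError("b must be nonzero")
    if not (math.isfinite(a) and math.isfinite(b)):
        raise ValueError("arguments must be finite")
    return a / b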

If ``pad_direction`` is not one of "left" or "right", **or**
if ``keys_to_pad`` is empty, or is not a list, **or**
if ``keys_to_pad`` is not a subset of keys in the batch, **or**
if ``padding_idx`` is provided as a dictionary, but the keys are not identical to ``keys_to_pad``

Example:
>>> a = [1, 2, 3]
@@ -149,9 +150,9 @@ def padded_collate(
output_dict[k] = pad_fn(
[torch.tensor(x[k]) for x in batch],
batch_first=True,
padding_value=padding_idx[k]
if isinstance(padding_idx, dict)
else padding_idx,
padding_value=(
padding_idx[k] if isinstance(padding_idx, dict) else padding_idx
),
)
return output_dict
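To make the per-key padding concrete, here is a self-contained sketch (assumptions, not torchtune's implementation) of what the reformatted ``padding_value=(...)`` expression above selects when ``padding_idx`` is a dict; the toy batch, keys, and pad values are made up.

import torch
from torch.nn.utils.rnn import pad_sequence

batch = [
    {"tokens": [1, 2, 3], "labels": [10, 20, 30]},
    {"tokens": [4, 5], "labels": [40, 50]},
]
keys_to_pad = ["tokens", "labels"]
padding_idx = {"tokens": 0, "labels": -100}  # hypothetical per-key pad values

output = {}
for k in keys_to_pad:
    output[k] = pad_sequence(
        [torch.tensor(x[k]) for x in batch],
        batch_first=True,
        padding_value=padding_idx[k] if isinstance(padding_idx, dict) else padding_idx,
    )
# output["tokens"] -> tensor([[1, 2, 3], [4, 5, 0]])
# output["labels"] -> tensor([[10, 20, 30], [40, 50, -100]])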

@@ -274,8 +275,9 @@ def padded_collate_tiled_images_and_mask(
- aspect_ratio: Tensor of shape (bsz, max_num_images, 2)

Raises:
ValueError: if ``pad_direction`` is not one of "left" or "right".
ValueError: if pad_max_tiles is set to a value less than the largest number of tiles in an image.
ValueError:
If ``pad_direction`` is not one of "left" or "right", **or**
if pad_max_tiles is set to a value less than the largest number of tiles in an image.

Example:
>>> image_id = 1
7 changes: 4 additions & 3 deletions torchtune/data/_messages.py
@@ -168,9 +168,10 @@ class InputOutputToMessages(Transform):
on a remote url. For text-only, leave as None. Default is None.

Raises:
ValueError: If ``column_map`` is provided and ``input`` not in ``column_map``, or
``output`` not in ``column_map``.
ValueError: If ``image_dir`` is provided but ``image`` not in ``column_map``.
ValueError:
If ``column_map`` is provided and ``input`` not in ``column_map``, or
``output`` not in ``column_map``, **or**
if ``image_dir`` is provided but ``image`` not in ``column_map``.
"""

def __init__(
5 changes: 3 additions & 2 deletions torchtune/data/_utils.py
@@ -57,8 +57,9 @@ def load_image(image_loc: Union[Path, str]) -> "PIL.Image.Image":
to start with "http" or "https" e.g. "https://www.wikipedia.org/en/bird.jpg".

Raises:
ValueError: If the image cannot be loaded from remote source.
ValueError: If the image cannot be opened as a :class:`~PIL.Image.Image`.
ValueError:
If the image cannot be loaded from remote source, **or**
if the image cannot be opened as a :class:`~PIL.Image.Image`.

Examples:
>>> # Load from remote source
20 changes: 11 additions & 9 deletions torchtune/models/clip/_position_embeddings.py
@@ -126,12 +126,13 @@ def _load_state_dict_hook(
**kwargs (Dict[str, Any]): Additional keyword arguments.

Raises:
ValueError: if loaded local or global embedding n_tokens_per_tile is not derived
from a squared grid.
ValueError: if after interpolation, the shape of the loaded local embedding
is not compatible with the current embedding.
ValueError: if after interpolation, the shape of the loaded global embedding
is not compatible with the current embedding.
ValueError:
If loaded local or global embedding n_tokens_per_tile is not derived
from a squared grid, **or**
if after interpolation, the shape of the loaded local embedding
is not compatible with the current embedding, **or**
if after interpolation, the shape of the loaded global embedding
is not compatible with the current embedding.
"""

# process local_token_positional_embedding
@@ -530,9 +531,10 @@ def _load_state_dict_hook(
**kwargs (Dict[str, Any]): Additional keyword arguments.

Raises:
ValueError: if the shape of the loaded embedding is not compatible with the current embedding.
ValueError: if max_num_tiles_x, max_num_tiles_y are not equal.
ValueError: if after interpolation, the shape of the loaded embedding is not compatible with the current embedding.
ValueError:
If the shape of the loaded embedding is not compatible with the current embedding, **or**
if ``max_num_tiles_x``, ``max_num_tiles_y`` are not equal, **or**
if after interpolation, the shape of the loaded embedding is not compatible with the current embedding.
"""

embedding = state_dict.get(prefix + "embedding")
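Both hooks above describe the same operation, checking that the loaded token count forms a square grid and interpolating the positional embedding to the current grid size, so here is a rough standalone sketch of that step under assumed shapes; it is illustrative, not the torchtune hook itself.

import math

import torch
import torch.nn.functional as F

n_tokens_per_tile, embed_dim = 16, 32  # hypothetical checkpoint: 4x4 grid
pos_emb = torch.randn(n_tokens_per_tile, embed_dim)

grid = math.isqrt(n_tokens_per_tile)
if grid * grid != n_tokens_per_tile:
    # Mirrors the first ValueError documented above.
    raise ValueError("n_tokens_per_tile is not derived from a squared grid")

target_grid = 6  # hypothetical current model: 6x6 grid
# (tokens, dim) -> (1, dim, grid, grid) so the embedding can be resized spatially.
pos_emb = pos_emb.reshape(grid, grid, embed_dim).permute(2, 0, 1).unsqueeze(0)
pos_emb = F.interpolate(
    pos_emb, size=(target_grid, target_grid), mode="bilinear", align_corners=False
)
pos_emb = pos_emb.squeeze(0).permute(1, 2, 0).reshape(target_grid**2, embed_dim)
print(pos_emb.shape)  # torch.Size([36, 32])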
16 changes: 11 additions & 5 deletions torchtune/models/gemma2/_attention.py
@@ -50,10 +50,11 @@ class Gemma2Attention(nn.Module):
softcapping (Optional[float]): capping value used for soft capping, if None no capping is performed
query_pre_attn_scalar (Optional[int]): value used for pre attention normalisation, if None head_dim is used instead
Raises:
ValueError: If ``num_heads % num_kv_heads != 0``
ValueError: If ``embed_dim % num_heads != 0``
ValueError: If ``attn_dropout < 0`` or ``attn_dropout > 1``
ValueError: if q_norm is defined without k_norm or vice versa
ValueError:
If ``num_heads % num_kv_heads != 0``, **or**
if ``embed_dim % num_heads != 0``, **or**
if ``attn_dropout < 0`` or ``attn_dropout > 1``, **or**
if ``q_norm`` is defined without k_norm or vice versa
"""

def __init__(
@@ -156,7 +157,11 @@ def setup_cache(
self.cache_enabled = True

def reset_cache(self):
"""Reset the key value caches."""
"""Reset the key value caches.

Raises:
RuntimeError: if key value caches are not already setup.
"""
if self.kv_cache is None:
raise RuntimeError(
"Key value caches are not setup. Call ``setup_caches()`` first."
@@ -196,6 +201,7 @@ def forward(
If none, assume the index of the token is its position id. Default is None.

Raises:
NotImplementedError: If ``mask`` is provided, but mask is not an instance of ``torch.Tensor``.
ValueError: If no ``y`` input and ``kv_cache`` is not enabled.

Returns:
1 change: 1 addition & 0 deletions torchtune/models/phi3/_tokenizer.py
@@ -157,6 +157,7 @@ def tokenize_messages(

Raises:
ValueError: If the role is not "user", "assistant", or "system".
RuntimeError: If ``message["type"] != "text"``.

Returns:
Tuple[List[int], List[bool]]: The tokenized messages
9 changes: 5 additions & 4 deletions torchtune/modules/attention.py
@@ -73,10 +73,11 @@ class MultiHeadAttention(nn.Module):
Default value is 0.0.

Raises:
ValueError: If ``num_heads % num_kv_heads != 0``
ValueError: If ``embed_dim % num_heads != 0``
ValueError: If ``attn_dropout < 0`` or ``attn_dropout > 1``
ValueError: if q_norm is defined without k_norm or vice versa
ValueError:
If ``num_heads % num_kv_heads != 0``, **or**
if ``embed_dim % num_heads != 0``, **or**
if ``attn_dropout < 0`` or ``attn_dropout > 1``, **or**
if q_norm is defined without k_norm or vice versa
"""

def __init__(
6 changes: 5 additions & 1 deletion torchtune/modules/kv_cache.py
@@ -84,9 +84,13 @@ def update(
Tuple[torch.Tensor, torch.Tensor]: Updated key and value cache tensors, respectively.

Raises:
AssertionError: if the sequence length of ``k_val`` is longer than the maximum cache sequence length.
ValueError: if the batch size of the new key (or value) tensor is greater than the batch size
used during cache setup.

Note:
This function will raise an ``AssertionError`` if the sequence length of ``k_val``
is longer than the maximum cache sequence length.

"""
bsz, _, seq_len, _ = k_val.shape
if bsz > self.k_cache.shape[0]:
9 changes: 5 additions & 4 deletions torchtune/modules/peft/_utils.py
@@ -285,10 +285,11 @@ def validate_missing_and_unexpected_for_lora(
None

Raises:
AssertionError: if base_missing contains any base model keys.
AssertionError: if base_unexpected is nonempty.
AssertionError: if lora_missing contains any LoRA keys.
AssertionError: if lora_unexpected is nonempty.
AssertionError:
If base_missing contains any base model keys, **or**
if base_unexpected is nonempty, **or**
if lora_missing contains any LoRA keys, **or**
if lora_unexpected is nonempty.
"""
lora_modules = get_lora_module_names(
lora_attn_modules, apply_lora_to_mlp, apply_lora_to_output
14 changes: 8 additions & 6 deletions torchtune/modules/transformer.py
@@ -344,8 +344,9 @@ class TransformerDecoder(nn.Module):
output_hidden_states (Optional[List[int]]): List of layers (indices) to include in the output

Raises:
AssertionError: num_layers is set and layer is a list
AssertionError: num_layers is not set and layer is an nn.Module
AssertionError:
If ``num_layers`` is set and layer is a list, **or**
if ``num_layers`` is not set and layer is an ``nn.Module``.

Note:
Arg values are checked for correctness (eg: ``attn_dropout`` belongs to [0,1])
@@ -519,10 +520,11 @@ def _validate_inputs(
input_pos (Optional[torch.Tensor]): Input tensor position IDs.

Raises:
ValueError: if seq_len of x is bigger than max_seq_len
ValueError: if the model has caches which have been setup with self-attention layers and ``mask`` is not provided.
ValueError: if the model has caches which have been setup with encoder layers and ``encoder_mask`` is not provided.
ValueError: if the model has caches which have been setup ``input_pos`` is not provided.
ValueError:
If seq_len of x is bigger than max_seq_len, **or**
if the model has caches which have been setup with self-attention layers and ``mask`` is not provided, **or**
if the model has caches which have been setup with encoder layers and ``encoder_mask`` is not provided, **or**
if the model has caches which have been setup and ``input_pos`` is not provided.
"""

if seq_len > self.max_seq_len:
7 changes: 4 additions & 3 deletions torchtune/modules/vision_transformer.py
@@ -190,9 +190,10 @@ class VisionTransformer(nn.Module):
Default is False, which adds CLS token to the beginning of the sequence.

Raises:
ValueError: If `tile_size` is not greater than 0.
ValueError: If `patch_size` is not greater than 0.
ValueError: If `len(out_indices)` is greater than `num_layers`.
ValueError:
If `tile_size` is not greater than 0, **or**
if `patch_size` is not greater than 0, **or**
if `len(out_indices)` is greater than `num_layers`.
"""

def __init__(
1 change: 0 additions & 1 deletion torchtune/training/_activation_offloading.py
@@ -55,7 +55,6 @@ class OffloadActivations(saved_tensors_hooks):

Raises:
ValueError: if max_fwd_stash_size is not at least 1.
RuntimeError: if use_streams but torch installation is earlier than torch-2.5.0.dev20240907

Example:
>>> with OffloadActivations():
1 change: 0 additions & 1 deletion torchtune/training/checkpointing/_checkpointer.py
@@ -932,7 +932,6 @@ class FullModelMetaCheckpointer(_CheckpointerInterface):

Raises:
ValueError: If ``checkpoint_files`` is not a list of length 1
ValueError: If ``should_load_recipe_state`` is True but ``recipe_checkpoint`` is None
"""

def __init__(
5 changes: 3 additions & 2 deletions torchtune/training/checkpointing/_utils.py
@@ -280,8 +280,9 @@ def update_state_dict_for_classifier(
if ``output.weight != model.output.weight``.

Raises:
AssertionError: if ``state_dict`` does not contain ``output.weight``.
AssertionError: if ``model_named_parameters`` does not contain ``output.weight``.
AssertionError:
If ``state_dict`` does not contain ``output.weight``, **or**
if ``model_named_parameters`` does not contain ``output.weight``.

"""
output_weight = dict(model_named_parameters).get("output.weight", None)
5 changes: 3 additions & 2 deletions torchtune/utils/_device.py
@@ -177,10 +177,11 @@ def batch_to_device(batch: dict, device: torch.device) -> None:

Args:
batch (dict): dict of Tensors or more nested dicts of tensors.
device (torch.device): torch device to move the tensor's too
device (torch.device): torch device to move the tensors to.

Raises:
AttributeError: if batch dict contains anything other than tensors
ValueError: if batch dict contains anything other than ``torch.Tensor``.

"""
for k, v in batch.items():
if isinstance(v, dict):
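Since the hunk above is truncated, here is a minimal self-contained sketch of the recursive move-to-device behavior the docstring describes, including the new ``ValueError`` for non-tensor values; it is an approximation, not the function's exact body.

import torch


def batch_to_device_sketch(batch: dict, device: torch.device) -> None:
    """Illustrative approximation of ``batch_to_device``."""
    for k, v in batch.items():
        if isinstance(v, dict):
            batch_to_device_sketch(v, device)
        elif isinstance(v, torch.Tensor):
            batch[k] = v.to(device)
        else:
            raise ValueError(f"Expected a tensor or dict at key '{k}', got {type(v)}")


batch = {"tokens": torch.tensor([1, 2, 3]), "extras": {"mask": torch.ones(3)}}
batch_to_device_sketch(batch, torch.device("cpu"))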