From 4ebf618d95c3e9ec5b8388878499eab52984c2e1 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Tue, 17 Dec 2024 11:54:40 +0000 Subject: [PATCH 1/4] fixing linter --- .pre-commit-config.yaml | 10 ++++----- pyproject.toml | 2 +- torchtune/data/_collate.py | 22 ++++++++++--------- torchtune/data/_messages.py | 7 +++--- torchtune/data/_utils.py | 5 +++-- torchtune/models/clip/_position_embeddings.py | 20 +++++++++-------- torchtune/models/gemma2/_attention.py | 16 +++++++++----- torchtune/models/phi3/_tokenizer.py | 1 + .../modules/_export/_position_embeddings.py | 20 +++++++++-------- torchtune/modules/_export/attention.py | 9 ++++---- torchtune/modules/_export/kv_cache.py | 2 +- torchtune/modules/attention.py | 9 ++++---- torchtune/modules/kv_cache.py | 6 ++++- torchtune/modules/peft/_utils.py | 9 ++++---- torchtune/modules/transformer.py | 14 +++++++----- torchtune/modules/vision_transformer.py | 7 +++--- torchtune/training/_activation_offloading.py | 8 +++---- .../training/checkpointing/_checkpointer.py | 1 - torchtune/training/checkpointing/_utils.py | 5 +++-- torchtune/utils/_device.py | 2 +- 20 files changed, 100 insertions(+), 75 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3f8e3149fc..854ee0e97c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,7 +5,7 @@ default_language_version: repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: 6306a48f7dae5861702d573c9c247e4e9498e867 + rev: v5.0.0 hooks: - id: trailing-whitespace - id: check-ast @@ -18,7 +18,7 @@ repos: exclude: '^(.*\.svg)$' - repo: https://github.com/Lucas-C/pre-commit-hooks - rev: v1.5.4 + rev: v1.5.5 hooks: - id: insert-license files: \.py$|\.sh$ @@ -27,7 +27,7 @@ repos: - docs/license_header.txt - repo: https://github.com/pycqa/flake8 - rev: 34cbf8ef3950f43d09b85e2e45c15ae5717dc37b + rev: 7.1.1 hooks: - id: flake8 additional_dependencies: @@ -37,7 +37,7 @@ repos: args: ['--config=.flake8'] - repo: https://github.com/omnilib/ufmt - rev: v2.3.0 + rev: v2.8.0 hooks: - id: ufmt additional_dependencies: @@ -45,7 +45,7 @@ repos: - usort == 1.0.5 - repo: https://github.com/jsh9/pydoclint - rev: 94efc5f989adbea30f3534b476b2931a02c1af90 + rev: 0.5.12 hooks: - id: pydoclint args: [--config=pyproject.toml] diff --git a/pyproject.toml b/pyproject.toml index 87ed1fb89e..f94732b58a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -87,7 +87,7 @@ target-version = ["py38"] [tool.pydoclint] style = 'google' check-return-types = 'False' -exclude = 'tests/torchtune/models/(\w+)/scripts/' +exclude = 'tests/torchtune/models/(\w+)/scripts/|recipes/|torchtune/modules/_export' [tool.pytest.ini_options] addopts = ["--showlocals", "--import-mode=prepend", "--without-integration", "--without-slow-integration"] diff --git a/torchtune/data/_collate.py b/torchtune/data/_collate.py index 5157f4a7fa..410ad49376 100644 --- a/torchtune/data/_collate.py +++ b/torchtune/data/_collate.py @@ -81,13 +81,14 @@ def padded_collate( padding values. Returns: - torch.Tensor: The padded tensor of input ids with shape [batch_size, max_seq_len]. + torch.Tensor: The padded tensor of input ids with shape ``[batch_size, max_seq_len]``. Raises: - ValueError: if ``pad_direction`` is not one of "left" or "right". - ValueError: if ``keys_to_pad`` is empty, or is not a list, or is not a subset of keys in the batch. - ValueError: if ``padding_idx`` is provided as a dictionary, but the keys are not identical to - ``keys_to_pad``. 
+ ValueError: + If ``pad_direction`` is not one of "left" or "right", **or** + if ``keys_to_pad`` is empty, or is not a list, **or** + if ``keys_to_pad`` is not a subset of keys in the batch, **or** + if ``padding_idx`` is provided as a dictionary, but the keys are not identical to ``keys_to_pad`` Example: >>> a = [1, 2, 3] @@ -149,9 +150,9 @@ def padded_collate( output_dict[k] = pad_fn( [torch.tensor(x[k]) for x in batch], batch_first=True, - padding_value=padding_idx[k] - if isinstance(padding_idx, dict) - else padding_idx, + padding_value=( + padding_idx[k] if isinstance(padding_idx, dict) else padding_idx + ), ) return output_dict @@ -274,8 +275,9 @@ def padded_collate_tiled_images_and_mask( - aspect_ratio: Tensor of shape (bsz, max_num_images, 2) Raises: - ValueError: if ``pad_direction`` is not one of "left" or "right". - ValueError: if pad_max_tiles is set to a value less than the largest number of tiles in an image. + ValueError: + If ``pad_direction`` is not one of "left" or "right", **or** + if pad_max_tiles is set to a value less than the largest number of tiles in an image. Example: >>> image_id = 1 diff --git a/torchtune/data/_messages.py b/torchtune/data/_messages.py index bbd3ae5981..a4e00834c2 100644 --- a/torchtune/data/_messages.py +++ b/torchtune/data/_messages.py @@ -168,9 +168,10 @@ class InputOutputToMessages(Transform): on a remote url. For text-only, leave as None. Default is None. Raises: - ValueError: If ``column_map`` is provided and ``input`` not in ``column_map``, or - ``output`` not in ``column_map``. - ValueError: If ``image_dir`` is provided but ``image`` not in ``column_map``. + ValueError: + If ``column_map`` is provided and ``input`` not in ``column_map``, or + ``output`` not in ``column_map``, **or** + if ``image_dir`` is provided but ``image`` not in ``column_map``. """ def __init__( diff --git a/torchtune/data/_utils.py b/torchtune/data/_utils.py index 812d1617a1..6a266cebbe 100644 --- a/torchtune/data/_utils.py +++ b/torchtune/data/_utils.py @@ -57,8 +57,9 @@ def load_image(image_loc: Union[Path, str]) -> "PIL.Image.Image": to start with "http" or "https" e.g. "https://www.wikipedia.org/en/bird.jpg". Raises: - ValueError: If the image cannot be loaded from remote source. - ValueError: If the image cannot be opened as a :class:`~PIL.Image.Image`. + ValueError: + If the image cannot be loaded from remote source, **or** + if the image cannot be opened as a :class:`~PIL.Image.Image`. Examples: >>> # Load from remote source diff --git a/torchtune/models/clip/_position_embeddings.py b/torchtune/models/clip/_position_embeddings.py index 09a98862e1..50488a5380 100644 --- a/torchtune/models/clip/_position_embeddings.py +++ b/torchtune/models/clip/_position_embeddings.py @@ -126,12 +126,13 @@ def _load_state_dict_hook( **kwargs (Dict[str, Any]): Additional keyword arguments. Raises: - ValueError: if loaded local or global embedding n_tokens_per_tile is not derived - from a squared grid. - ValueError: if after interpolation, the shape of the loaded local embedding - is not compatible with the current embedding. - ValueError: if after interpolation, the shape of the loaded global embedding - is not compatible with the current embedding. 
+ ValueError: + If loaded local or global embedding n_tokens_per_tile is not derived + from a squared grid, **or** + if after interpolation, the shape of the loaded local embedding + is not compatible with the current embedding, **or** + if after interpolation, the shape of the loaded global embedding + is not compatible with the current embedding. """ # process local_token_positional_embedding @@ -530,9 +531,10 @@ def _load_state_dict_hook( **kwargs (Dict[str, Any]): Additional keyword arguments. Raises: - ValueError: if the shape of the loaded embedding is not compatible with the current embedding. - ValueError: if max_num_tiles_x, max_num_tiles_y are not equal. - ValueError: if after interpolation, the shape of the loaded embedding is not compatible with the current embedding. + ValueError: + If the shape of the loaded embedding is not compatible with the current embedding, **or** + if ``max_num_tiles_x``, ``max_num_tiles_y`` are not equal, **or** + if after interpolation, the shape of the loaded embedding is not compatible with the current embedding. """ embedding = state_dict.get(prefix + "embedding") diff --git a/torchtune/models/gemma2/_attention.py b/torchtune/models/gemma2/_attention.py index 1cd3bfdc12..0ed71b34d6 100644 --- a/torchtune/models/gemma2/_attention.py +++ b/torchtune/models/gemma2/_attention.py @@ -50,10 +50,11 @@ class Gemma2Attention(nn.Module): softcapping (Optional[float]): capping value used for soft caping, if None no capping is performed query_pre_attn_scalar (Optional[int]): value used for pre attention normalisation, if None head_dim is used instead Raises: - ValueError: If ``num_heads % num_kv_heads != 0`` - ValueError: If ``embed_dim % num_heads != 0`` - ValueError: If ``attn_dropout < 0`` or ``attn_dropout > 1`` - ValueError: if q_norm is defined without k_norm or vice versa + ValueError: + If ``num_heads % num_kv_heads != 0``, **or** + if ``embed_dim % num_heads != 0``, **or** + if ``attn_dropout < 0`` or ``attn_dropout > 1``, **or** + if ``q_norm`` is defined without k_norm or vice versa """ def __init__( @@ -156,7 +157,11 @@ def setup_cache( self.cache_enabled = True def reset_cache(self): - """Reset the key value caches.""" + """Reset the key value caches. + + Raises: + RuntimeError: if key value caches are not already setup. + """ if self.kv_cache is None: raise RuntimeError( "Key value caches are not setup. Call ``setup_caches()`` first." @@ -196,6 +201,7 @@ def forward( If none, assume the index of the token is its position id. Default is None. Raises: + NotImplementedError: If ``mask`` is provided, but mask is not an instance of ``torch.Tensor``. ValueError: If no ``y`` input and ``kv_cache`` is not enabled. Returns: diff --git a/torchtune/models/phi3/_tokenizer.py b/torchtune/models/phi3/_tokenizer.py index b48b1d93a3..38707bf26e 100644 --- a/torchtune/models/phi3/_tokenizer.py +++ b/torchtune/models/phi3/_tokenizer.py @@ -157,6 +157,7 @@ def tokenize_messages( Raises: ValueError: If the role is not "user", "assistant", or "system". + RuntimeError: If ``message["type"] != "text``. Returns: Tuple[List[int], List[bool]]: The tokenized messages diff --git a/torchtune/modules/_export/_position_embeddings.py b/torchtune/modules/_export/_position_embeddings.py index 0489b7f345..bd4d14e516 100644 --- a/torchtune/modules/_export/_position_embeddings.py +++ b/torchtune/modules/_export/_position_embeddings.py @@ -73,9 +73,10 @@ def _load_state_dict_hook( **kwargs (Dict[str, Any]): Additional keyword arguments. 
Raises: - ValueError: if the shape of the loaded embedding is not compatible with the current embedding. - ValueError: if max_num_tiles_x, max_num_tiles_y are not equal. - ValueError: if after interpolation, the shape of the loaded embedding is not compatible with the current embedding. + ValueError: + If the shape of the loaded embedding is not compatible with the current embedding, **or** + if max_num_tiles_x, max_num_tiles_y are not equal, **or** + if after interpolation, the shape of the loaded embedding is not compatible with the current embedding. """ embedding = state_dict.get(prefix + "embedding") @@ -302,12 +303,13 @@ def _load_state_dict_hook( **kwargs (Dict[str, Any]): Additional keyword arguments. Raises: - ValueError: if loaded local or global embedding n_tokens_per_tile is not derived - from a squared grid. - ValueError: if after interpolation, the shape of the loaded local embedding - is not compatible with the current embedding. - ValueError: if after interpolation, the shape of the loaded global embedding - is not compatible with the current embedding. + ValueError: + If loaded local or global embedding n_tokens_per_tile is not derived + from a squared grid, **or** + if after interpolation, the shape of the loaded local embedding + is not compatible with the current embedding, **or** + if after interpolation, the shape of the loaded global embedding + is not compatible with the current embedding. """ # process local_token_positional_embedding diff --git a/torchtune/modules/_export/attention.py b/torchtune/modules/_export/attention.py index bb3fe4a94b..352f97d9c1 100644 --- a/torchtune/modules/_export/attention.py +++ b/torchtune/modules/_export/attention.py @@ -93,10 +93,11 @@ class MultiHeadAttention(nn.Module): Default value is 0.0. Raises: - ValueError: If ``num_heads % num_kv_heads != 0`` - ValueError: If ``embed_dim % num_heads != 0`` - ValueError: If ``attn_dropout < 0`` or ``attn_dropout > 1`` - ValueError: if q_norm is defined without k_norm or vice versa + ValueError: + If ``num_heads % num_kv_heads != 0``, **or** + If ``embed_dim % num_heads != 0``, **or** + If ``attn_dropout < 0`` or ``attn_dropout > 1``, **or** + if q_norm is defined without k_norm or vice versa """ def __init__( diff --git a/torchtune/modules/_export/kv_cache.py b/torchtune/modules/_export/kv_cache.py index 8e0b7047e5..ad41de8859 100644 --- a/torchtune/modules/_export/kv_cache.py +++ b/torchtune/modules/_export/kv_cache.py @@ -95,7 +95,7 @@ def update( Tuple[torch.Tensor, torch.Tensor]: Updated key and value cache tensors, respectively. Raises: - AssertionError: if the sequence length of ``k_val`` is longer than the maximum cache sequence length. + AssertionError: if the sequence length of ``k_val`` is longer than the maximum cache sequence length. #noqa ValueError: if the batch size of the new key (or value) tensor is greater than the batch size used during cache setup. """ diff --git a/torchtune/modules/attention.py b/torchtune/modules/attention.py index f4cfa142f7..ff6faccb5d 100644 --- a/torchtune/modules/attention.py +++ b/torchtune/modules/attention.py @@ -73,10 +73,11 @@ class MultiHeadAttention(nn.Module): Default value is 0.0. 
Raises: - ValueError: If ``num_heads % num_kv_heads != 0`` - ValueError: If ``embed_dim % num_heads != 0`` - ValueError: If ``attn_dropout < 0`` or ``attn_dropout > 1`` - ValueError: if q_norm is defined without k_norm or vice versa + ValueError: + If ``num_heads % num_kv_heads != 0``, **or** + if ``embed_dim % num_heads != 0``, **or** + if ``attn_dropout < 0`` or ``attn_dropout > 1``, **or** + if q_norm is defined without k_norm or vice versa """ def __init__( diff --git a/torchtune/modules/kv_cache.py b/torchtune/modules/kv_cache.py index e96491c22a..22207d1b4f 100644 --- a/torchtune/modules/kv_cache.py +++ b/torchtune/modules/kv_cache.py @@ -84,9 +84,13 @@ def update( Tuple[torch.Tensor, torch.Tensor]: Updated key and value cache tensors, respectively. Raises: - AssertionError: if the sequence length of ``k_val`` is longer than the maximum cache sequence length. ValueError: if the batch size of the new key (or value) tensor is greater than the batch size used during cache setup. + + Note: + This function will raise an ``AssertionError`` if the sequence length of ``k_val`` + is longer than the maximum cache sequence length. + """ bsz, _, seq_len, _ = k_val.shape if bsz > self.k_cache.shape[0]: diff --git a/torchtune/modules/peft/_utils.py b/torchtune/modules/peft/_utils.py index 1d0f1047b6..d9be9667ea 100644 --- a/torchtune/modules/peft/_utils.py +++ b/torchtune/modules/peft/_utils.py @@ -285,10 +285,11 @@ def validate_missing_and_unexpected_for_lora( None Raises: - AssertionError: if base_missing contains any base model keys. - AssertionError: if base_unexpected is nonempty. - AssertionError: if lora_missing contains any LoRA keys. - AssertionError: if lora_unexpected is nonempty. + AssertionError: + If base_missing contains any base model keys, **or** + if base_unexpected is nonempty, **or** + if lora_missing contains any LoRA keys, **or** + if lora_unexpected is nonempty. """ lora_modules = get_lora_module_names( lora_attn_modules, apply_lora_to_mlp, apply_lora_to_output diff --git a/torchtune/modules/transformer.py b/torchtune/modules/transformer.py index 66ac92002f..bf3bd93454 100644 --- a/torchtune/modules/transformer.py +++ b/torchtune/modules/transformer.py @@ -344,8 +344,9 @@ class TransformerDecoder(nn.Module): output_hidden_states (Optional[List[int]]): List of layers (indices) to include in the output Raises: - AssertionError: num_layers is set and layer is a list - AssertionError: num_layers is not set and layer is an nn.Module + AssertionError: + If ``num_layers`` is set and layer is a list, **or** + ``num_layers`` is not set and layer is an ``nn.Module``. Note: Arg values are checked for correctness (eg: ``attn_dropout`` belongs to [0,1]) @@ -519,10 +520,11 @@ def _validate_inputs( input_pos (Optional[torch.Tensor]): Input tensor position IDs. Raises: - ValueError: if seq_len of x is bigger than max_seq_len - ValueError: if the model has caches which have been setup with self-attention layers and ``mask`` is not provided. - ValueError: if the model has caches which have been setup with encoder layers and ``encoder_mask`` is not provided. - ValueError: if the model has caches which have been setup ``input_pos`` is not provided. 
+ ValueError: + If seq_len of x is bigger than max_seq_len, **or** + if the model has caches which have been setup with self-attention layers and ``mask`` is not provided, **or** + if the model has caches which have been setup with encoder layers and ``encoder_mask`` is not provided, **or** + if the model has caches which have been setup ``input_pos`` is not provided. """ if seq_len > self.max_seq_len: diff --git a/torchtune/modules/vision_transformer.py b/torchtune/modules/vision_transformer.py index d44d0c930f..8ba14dfe45 100644 --- a/torchtune/modules/vision_transformer.py +++ b/torchtune/modules/vision_transformer.py @@ -190,9 +190,10 @@ class VisionTransformer(nn.Module): Default is False, which adds CLS token to the beginning of the sequence. Raises: - ValueError: If `tile_size` is not greater than 0. - ValueError: If `patch_size` is not greater than 0. - ValueError: If `len(out_indices)` is greater than `num_layers`. + ValueError: + If `tile_size` is not greater than 0, **or** + if `patch_size` is not greater than 0, **or** + if `len(out_indices)` is greater than `num_layers`. """ def __init__( diff --git a/torchtune/training/_activation_offloading.py b/torchtune/training/_activation_offloading.py index a802ce98d8..6ca64c5c32 100644 --- a/torchtune/training/_activation_offloading.py +++ b/torchtune/training/_activation_offloading.py @@ -55,7 +55,6 @@ class OffloadActivations(saved_tensors_hooks): Raises: ValueError: if max_fwd_stash_size is not at least 1. - RuntimeError: if use_streams but torch installation is earlier than torch-2.5.0.dev20240907 Example: >>> with OffloadActivations(): @@ -134,9 +133,10 @@ def get_num_bytes_tensor(x: torch.Tensor) -> int: def pack_tensor(activation: torch.Tensor) -> int: # activations are passed in during forward pass - from here we take over and return a unique id if self.is_first_forward_call: - assert ( - len(self.tracker) == 0 - ), "backward pass should have cleared tracker of all tensors" + if len(self.tracker) == 0: + raise AssertionError( + "backward pass should have cleared tracker of all tensors" + ) # set training phase trackers self.is_first_forward_call = False diff --git a/torchtune/training/checkpointing/_checkpointer.py b/torchtune/training/checkpointing/_checkpointer.py index a5d72af320..141b513a2f 100644 --- a/torchtune/training/checkpointing/_checkpointer.py +++ b/torchtune/training/checkpointing/_checkpointer.py @@ -932,7 +932,6 @@ class FullModelMetaCheckpointer(_CheckpointerInterface): Raises: ValueError: If ``checkpoint_files`` is not a list of length 1 - ValueError: If ``should_load_recipe_state`` is True but ``recipe_checkpoint`` is None """ def __init__( diff --git a/torchtune/training/checkpointing/_utils.py b/torchtune/training/checkpointing/_utils.py index 1d8a63daab..ed6e819076 100644 --- a/torchtune/training/checkpointing/_utils.py +++ b/torchtune/training/checkpointing/_utils.py @@ -280,8 +280,9 @@ def update_state_dict_for_classifier( if ``output.weight != model.output.weight``. Raises: - AssertionError: if ``state_dict`` does not contain ``output.weight``. - AssertionError: if ``model_named_parameters`` does not contain ``output.weight``. + AssertionError: + If ``state_dict`` does not contain ``output.weight``, **or** + if ``model_named_parameters`` does not contain ``output.weight``. 
""" output_weight = dict(model_named_parameters).get("output.weight", None) diff --git a/torchtune/utils/_device.py b/torchtune/utils/_device.py index 10d5e62a05..4899ee7a42 100644 --- a/torchtune/utils/_device.py +++ b/torchtune/utils/_device.py @@ -180,7 +180,7 @@ def batch_to_device(batch: dict, device: torch.device) -> None: device (torch.device): torch device to move the tensor's too Raises: - AttributeError: if batch dict contains anything other than tensors + ValueError: if batch dict contains anything other than ``torch.Tensors``s. """ for k, v in batch.items(): if isinstance(v, dict): From 409d3afaff167265ff6b281c8f25a7292f905df4 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Tue, 17 Dec 2024 12:00:48 +0000 Subject: [PATCH 2/4] fixing docs --- torchtune/utils/_device.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchtune/utils/_device.py b/torchtune/utils/_device.py index 4899ee7a42..1d6defefae 100644 --- a/torchtune/utils/_device.py +++ b/torchtune/utils/_device.py @@ -177,10 +177,11 @@ def batch_to_device(batch: dict, device: torch.device) -> None: Args: batch (dict): dict of Tensors or more nested dicts of tensors. - device (torch.device): torch device to move the tensor's too + device (torch.device): torch device to move the tensors to. Raises: - ValueError: if batch dict contains anything other than ``torch.Tensors``s. + ValueError: if batch dict contains anything other than ``torch.Tensor``. + """ for k, v in batch.items(): if isinstance(v, dict): From 0ea49e49d51901277679709511e5bf95624c58d2 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Tue, 17 Dec 2024 12:23:07 +0000 Subject: [PATCH 3/4] reverting act off --- torchtune/training/_activation_offloading.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/torchtune/training/_activation_offloading.py b/torchtune/training/_activation_offloading.py index 6ca64c5c32..8d32083b54 100644 --- a/torchtune/training/_activation_offloading.py +++ b/torchtune/training/_activation_offloading.py @@ -133,10 +133,9 @@ def get_num_bytes_tensor(x: torch.Tensor) -> int: def pack_tensor(activation: torch.Tensor) -> int: # activations are passed in during forward pass - from here we take over and return a unique id if self.is_first_forward_call: - if len(self.tracker) == 0: - raise AssertionError( - "backward pass should have cleared tracker of all tensors" - ) + assert ( + len(self.tracker) == 0 + ), "backward pass should have cleared tracker of all tensors" # set training phase trackers self.is_first_forward_call = False From f75c7cd15253fb9561aa46bc5b38d685d34208b7 Mon Sep 17 00:00:00 2001 From: Salman Mohammadi Date: Tue, 17 Dec 2024 12:24:39 +0000 Subject: [PATCH 4/4] reverting export --- .../modules/_export/_position_embeddings.py | 20 +++++++++---------- torchtune/modules/_export/attention.py | 9 ++++----- torchtune/modules/_export/kv_cache.py | 2 +- 3 files changed, 14 insertions(+), 17 deletions(-) diff --git a/torchtune/modules/_export/_position_embeddings.py b/torchtune/modules/_export/_position_embeddings.py index bd4d14e516..0489b7f345 100644 --- a/torchtune/modules/_export/_position_embeddings.py +++ b/torchtune/modules/_export/_position_embeddings.py @@ -73,10 +73,9 @@ def _load_state_dict_hook( **kwargs (Dict[str, Any]): Additional keyword arguments. 
Raises: - ValueError: - If the shape of the loaded embedding is not compatible with the current embedding, **or** - if max_num_tiles_x, max_num_tiles_y are not equal, **or** - if after interpolation, the shape of the loaded embedding is not compatible with the current embedding. + ValueError: if the shape of the loaded embedding is not compatible with the current embedding. + ValueError: if max_num_tiles_x, max_num_tiles_y are not equal. + ValueError: if after interpolation, the shape of the loaded embedding is not compatible with the current embedding. """ embedding = state_dict.get(prefix + "embedding") @@ -303,13 +302,12 @@ def _load_state_dict_hook( **kwargs (Dict[str, Any]): Additional keyword arguments. Raises: - ValueError: - If loaded local or global embedding n_tokens_per_tile is not derived - from a squared grid, **or** - if after interpolation, the shape of the loaded local embedding - is not compatible with the current embedding, **or** - if after interpolation, the shape of the loaded global embedding - is not compatible with the current embedding. + ValueError: if loaded local or global embedding n_tokens_per_tile is not derived + from a squared grid. + ValueError: if after interpolation, the shape of the loaded local embedding + is not compatible with the current embedding. + ValueError: if after interpolation, the shape of the loaded global embedding + is not compatible with the current embedding. """ # process local_token_positional_embedding diff --git a/torchtune/modules/_export/attention.py b/torchtune/modules/_export/attention.py index 352f97d9c1..bb3fe4a94b 100644 --- a/torchtune/modules/_export/attention.py +++ b/torchtune/modules/_export/attention.py @@ -93,11 +93,10 @@ class MultiHeadAttention(nn.Module): Default value is 0.0. Raises: - ValueError: - If ``num_heads % num_kv_heads != 0``, **or** - If ``embed_dim % num_heads != 0``, **or** - If ``attn_dropout < 0`` or ``attn_dropout > 1``, **or** - if q_norm is defined without k_norm or vice versa + ValueError: If ``num_heads % num_kv_heads != 0`` + ValueError: If ``embed_dim % num_heads != 0`` + ValueError: If ``attn_dropout < 0`` or ``attn_dropout > 1`` + ValueError: if q_norm is defined without k_norm or vice versa """ def __init__( diff --git a/torchtune/modules/_export/kv_cache.py b/torchtune/modules/_export/kv_cache.py index ad41de8859..8e0b7047e5 100644 --- a/torchtune/modules/_export/kv_cache.py +++ b/torchtune/modules/_export/kv_cache.py @@ -95,7 +95,7 @@ def update( Tuple[torch.Tensor, torch.Tensor]: Updated key and value cache tensors, respectively. Raises: - AssertionError: if the sequence length of ``k_val`` is longer than the maximum cache sequence length. #noqa + AssertionError: if the sequence length of ``k_val`` is longer than the maximum cache sequence length. ValueError: if the batch size of the new key (or value) tensor is greater than the batch size used during cache setup. """
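
For reference, the docstring convention applied throughout this series, sketched on a hypothetical helper (not part of torchtune; illustration only): where a function can raise the same exception type for several reasons, the separate ``Raises:`` entries are merged into a single block under that type, with the individual conditions joined by **or**, so each exception class is listed only once.

import math


def scale(value: float, factor: float) -> float:
    """Scale ``value`` by ``factor``.

    Args:
        value (float): value to scale.
        factor (float): multiplier; must be finite and non-zero.

    Returns:
        float: the scaled value.

    Raises:
        ValueError:
            If ``factor`` is zero, **or**
            if ``factor`` is not finite.
    """
    # Both failure modes raise the same exception type, so they share the
    # single consolidated ``Raises:`` entry above, mirroring the diffs in
    # this series.
    if factor == 0:
        raise ValueError("factor must be non-zero")
    if not math.isfinite(factor):
        raise ValueError("factor must be finite")
    return value * factor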