5 changes: 3 additions & 2 deletions docs/source/tutorials/lora_finetune.rst
@@ -84,7 +84,8 @@ Let's take a look at a minimal implementation of LoRA in native PyTorch.

.. code-block:: python

-from torch import nn, Tensor
+import torch
+from torch import nn

class LoRALinear(nn.Module):
def __init__(
@@ -114,7 +115,7 @@ Let's take a look at a minimal implementation of LoRA in native PyTorch.
self.lora_a.weight.requires_grad = True
self.lora_b.weight.requires_grad = True

-def forward(self, x: Tensor) -> Tensor:
+def forward(self, x: torch.Tensor) -> torch.Tensor:
# This would be the output of the original model
frozen_out = self.linear(x)

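Taken together with the surrounding tutorial code, the updated layer stays small; below is a self-contained sketch of a LoRA linear in the new annotation style. The rank, alpha, and dropout handling are illustrative assumptions, not the tutorial code copied verbatim.

import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """Frozen linear layer plus a trainable low-rank update."""

    def __init__(self, in_dim: int, out_dim: int, rank: int = 8, alpha: float = 16.0, dropout: float = 0.0) -> None:
        super().__init__()
        # Frozen pretrained projection
        self.linear = nn.Linear(in_dim, out_dim, bias=False)
        self.linear.weight.requires_grad = False
        # Trainable low-rank factors
        self.lora_a = nn.Linear(in_dim, rank, bias=False)
        self.lora_b = nn.Linear(rank, out_dim, bias=False)
        self.lora_a.weight.requires_grad = True
        self.lora_b.weight.requires_grad = True
        self.dropout = nn.Dropout(p=dropout)
        self.scaling = alpha / rank

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # This would be the output of the original (frozen) model
        frozen_out = self.linear(x)
        # Low-rank update, scaled by alpha / rank
        lora_out = self.lora_b(self.lora_a(self.dropout(x))) * self.scaling
        return frozen_out + lora_out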
5 changes: 3 additions & 2 deletions docs/source/tutorials/qlora_finetune.rst
@@ -217,7 +217,8 @@ a vanilla minimal LoRA layer, taken from :ref:`the LoRA tutorial <lora_finetune_
.. code-block:: python
:emphasize-lines: 3, 13, 19, 20, 39, 40, 41

-from torch import nn, Tensor
+import torch
+from torch import nn
import torch.nn.functional as F
from torchao.dtypes.nf4tensor import linear_nf4, to_nf4

@@ -253,7 +254,7 @@ a vanilla minimal LoRA layer, taken from :ref:`the LoRA tutorial <lora_finetune_
self.lora_a.weight.requires_grad = True
self.lora_b.weight.requires_grad = True

-def forward(self, x: Tensor) -> Tensor:
+def forward(self, x: torch.Tensor) -> torch.Tensor:
# frozen_out would be the output of the original model
if quantize_base:
# Call into torchao's linear_nf4 to run linear forward pass w/quantized weight.
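For context, the emphasized lines in this tutorial swap the frozen base weight for an NF4-quantized one and route the frozen forward pass through torchao. A rough sketch of that dispatch follows, assuming torchao's to_nf4 and linear_nf4 helpers behave as they do in the tutorial; the class and attribute names here are illustrative, not the tutorial's exact code.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchao.dtypes.nf4tensor import linear_nf4, to_nf4


class QLoRALinearSketch(nn.Module):
    def __init__(self, in_dim: int, out_dim: int, rank: int = 8, alpha: float = 16.0, quantize_base: bool = True) -> None:
        super().__init__()
        self.quantize_base = quantize_base
        weight = torch.randn(out_dim, in_dim)  # stand-in for the pretrained weight
        # Keep the frozen base weight in NF4 when quantization is enabled
        self.register_parameter(
            "weight",
            nn.Parameter(to_nf4(weight) if quantize_base else weight, requires_grad=False),
        )
        self.lora_a = nn.Linear(in_dim, rank, bias=False)
        self.lora_b = nn.Linear(rank, out_dim, bias=False)
        self.scaling = alpha / rank

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.quantize_base:
            # linear_nf4 runs the linear forward pass with the quantized weight
            frozen_out = linear_nf4(input=x, weight=self.weight)
        else:
            frozen_out = F.linear(x, self.weight)
        # Trainable low-rank update stays in the compute dtype
        return frozen_out + self.lora_b(self.lora_a(x)) * self.scaling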
11 changes: 6 additions & 5 deletions tests/torchtune/models/llama2/scripts/compare_fused_attention.py
@@ -11,6 +11,7 @@
from torch import nn, Tensor
from torchtune.modules import KVCache, MultiHeadAttention, RotaryPositionalEmbeddings


# Copy-paste of fused attention for comparison
class FusedMultiHeadAttention(nn.Module):
"""Multi-headed grouped query self-attention (GQA) layer introduced
@@ -115,15 +116,15 @@ def __init__(

def forward(
self,
-x: Tensor,
-mask: Optional[Tensor] = None,
+x: torch.Tensor,
+mask: Optional[torch.Tensor] = None,
curr_pos: int = 0,
-) -> Tensor:
+) -> torch.Tensor:
"""
Args:
x (Tensor): input tensor with shape
[batch_size x seq_length x embed_dim]
-mask (Optional[Tensor]): boolean mask, defaults to None.
+mask (Optional[torch.Tensor]): boolean mask, defaults to None.
curr_pos (int): current position in the sequence, defaults to 0.

Returns:
@@ -241,7 +242,7 @@ def map_state_dict(
return mapped_sd


-def _get_mask(inpt: Tensor) -> Tensor:
+def _get_mask(inpt: torch.Tensor) -> torch.Tensor:
seq_len = inpt.shape[1]
mask = torch.full((1, 1, seq_len, seq_len), float("-inf"), device=inpt.device)
mask = torch.triu(mask, diagonal=1).type_as(inpt)
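As a quick illustration of what _get_mask produces: it is the usual additive causal mask, zeros on and below the diagonal and -inf above it, so each position can only attend to itself and earlier positions. The shapes in this sketch are arbitrary.

import torch

x = torch.randn(2, 4, 32)   # [batch_size, seq_len, embed_dim]
mask = _get_mask(x)         # [1, 1, 4, 4]
# Row i is 0.0 for columns 0..i and -inf for columns i+1..; adding it to the
# attention scores before softmax zeroes out attention to future positions.
print(mask[0, 0])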
24 changes: 12 additions & 12 deletions tests/torchtune/modules/test_attention.py
@@ -11,7 +11,7 @@
import torch

from tests.test_utils import assert_expected, fixed_init_model
-from torch import nn, Tensor
+from torch import nn

from torchtune.modules import KVCache, MultiHeadAttention, RotaryPositionalEmbeddings
from torchtune.utils.seed import set_seed
@@ -40,7 +40,7 @@ def input_params(self) -> Tuple[int, int, int]:
return batch_size, seq_len, embed_dim

@pytest.fixture
-def input(self, input_params: Tuple[int, int, int]) -> Tensor:
+def input(self, input_params: Tuple[int, int, int]) -> torch.Tensor:
batch_size, seq_len, embed_dim = input_params
x = torch.randn(batch_size, seq_len, embed_dim)
return x
@@ -58,7 +58,7 @@ def input_max_len_exceeded(
self,
input_params: Tuple[int, int, int],
attn_params_gqa: Tuple[int, int, int, int],
-) -> Tensor:
+) -> torch.Tensor:
batch_size, seq_len, embed_dim = input_params
_, _, _, max_seq_len = attn_params_gqa
seq_len = max_seq_len + 1
@@ -69,7 +69,7 @@ def input_max_bs_exceeded(
self,
input_params: Tuple[int, int, int],
attn_params_gqa: Tuple[int, int, int, int],
-) -> Tensor:
+) -> torch.Tensor:
batch_size, seq_len, embed_dim = input_params
_, _, _, max_seq_len = attn_params_gqa
batch_size += 1
@@ -253,7 +253,7 @@ def mqa_kv_cache(
attn.eval()
return attn

-def test_forward_gqa(self, input: Tensor, gqa: MultiHeadAttention) -> None:
+def test_forward_gqa(self, input: torch.Tensor, gqa: MultiHeadAttention) -> None:
with torch.no_grad():
output = gqa(input)
assert_expected(
@@ -262,7 +262,7 @@ def test_forward_gqa(self, input: Tensor, gqa: MultiHeadAttention) -> None:
assert_expected(output.shape, input.shape)

def test_forward_gqa_kv_cache(
-self, input: Tensor, gqa_kv_cache: MultiHeadAttention, attn_params_gqa
+self, input: torch.Tensor, gqa_kv_cache: MultiHeadAttention, attn_params_gqa
) -> None:

_, _, _, max_seq_len = attn_params_gqa
@@ -279,7 +279,7 @@ def test_forward_gqa_kv_cache(
)
assert_expected(output.shape, input.shape)

-def test_forward_mha(self, input: Tensor, mha: MultiHeadAttention) -> None:
+def test_forward_mha(self, input: torch.Tensor, mha: MultiHeadAttention) -> None:
with torch.no_grad():
output = mha(input)
assert_expected(
@@ -288,7 +288,7 @@ def test_forward_mha(self, input: Tensor, mha: MultiHeadAttention) -> None:
assert_expected(output.shape, input.shape)

def test_forward_mha_kv_cache(
-self, input: Tensor, mha_kv_cache: MultiHeadAttention, attn_params_mha
+self, input: torch.Tensor, mha_kv_cache: MultiHeadAttention, attn_params_mha
) -> None:

_, _, _, max_seq_len = attn_params_mha
@@ -305,7 +305,7 @@ def test_forward_mha_kv_cache(
)
assert_expected(output.shape, input.shape)

-def test_forward_mqa(self, input: Tensor, mqa: MultiHeadAttention) -> None:
+def test_forward_mqa(self, input: torch.Tensor, mqa: MultiHeadAttention) -> None:
with torch.no_grad():
output = mqa(input)
assert_expected(
@@ -314,7 +314,7 @@ def test_forward_mqa(self, input: Tensor, mqa: MultiHeadAttention) -> None:
assert_expected(output.shape, input.shape)

def test_forward_mqa_kv_cache(
-self, input: Tensor, mqa_kv_cache: MultiHeadAttention, attn_params_mqa
+self, input: torch.Tensor, mqa_kv_cache: MultiHeadAttention, attn_params_mqa
) -> None:
_, _, _, max_seq_len = attn_params_mqa
_, seq_len, _ = input.shape
@@ -332,15 +332,15 @@ def test_forward_mqa_kv_cache(

def test_max_seq_len_exceeded(
self,
-input_max_len_exceeded: Tensor,
+input_max_len_exceeded: torch.Tensor,
gqa: MultiHeadAttention,
) -> None:
with pytest.raises(Exception):
_ = gqa(input_max_len_exceeded)

def test_batch_size_exceeded(
self,
-input_max_bs_exceeded: Tensor,
+input_max_bs_exceeded: torch.Tensor,
gqa_kv_cache: MultiHeadAttention,
) -> None:
with pytest.raises(Exception):
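Every change in this test file follows the same recipe: drop Tensor from the torch import and annotate with torch.Tensor instead. A minimal illustration of the convention, using a made-up fixture rather than one of the fixtures above:

import pytest
import torch


@pytest.fixture
def input_tensor() -> torch.Tensor:
    # Annotate with the fully qualified torch.Tensor instead of importing Tensor directly
    return torch.randn(4, 2048, 4096)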
6 changes: 3 additions & 3 deletions tests/torchtune/modules/test_feed_forward.py
@@ -11,7 +11,7 @@
import torch

from tests.test_utils import assert_expected, fixed_init_model
-from torch import nn, Tensor
+from torch import nn

from torchtune.modules import FeedForward
from torchtune.utils.seed import set_seed
@@ -32,7 +32,7 @@ def input_params(self) -> Tuple[int, int]:
return dim, hidden_dim

@pytest.fixture
-def input(self, input_params: Tuple[int, int]) -> Tensor:
+def input(self, input_params: Tuple[int, int]) -> torch.Tensor:
dim, _ = input_params
return torch.randn(1, dim)

@@ -49,7 +49,7 @@ def ffn(self, input_params: Tuple[int, int]) -> FeedForward:
ff.eval()
return ff

-def test_forward(self, input: Tensor, ffn: FeedForward) -> None:
+def test_forward(self, input: torch.Tensor, ffn: FeedForward) -> None:
with torch.no_grad():
x_out = ffn(input)
assert_expected(x_out.mean(), torch.tensor(251.5356), atol=1e-7, rtol=1e-3)
26 changes: 13 additions & 13 deletions tests/torchtune/modules/test_transformer_decoder.py
@@ -12,7 +12,7 @@

from tests.test_utils import assert_expected

-from torch import nn, Tensor
+from torch import nn

from torchtune.models.llama2 import llama2
from torchtune.models.llama2._component_builders import llama2_mlp
@@ -54,7 +54,7 @@ def input_params(self) -> Tuple[int, int, int]:
return batch_size, seq_len, embed_dim

@pytest.fixture
-def input(self, input_params: Tuple[int, int, int]) -> Tensor:
+def input(self, input_params: Tuple[int, int, int]) -> torch.Tensor:
batch_size, seq_len, embed_dim = input_params
return torch.randn(batch_size, seq_len, embed_dim)

@@ -100,7 +100,7 @@ def transformer_layer(
return transformer_layer

def test_forward(
-self, input: Tensor, transformer_layer: TransformerSelfAttentionLayer
+self, input: torch.Tensor, transformer_layer: TransformerSelfAttentionLayer
) -> None:
with torch.no_grad():
output = transformer_layer(input)
@@ -125,7 +125,7 @@ def input_params(self) -> Tuple[int, int, int, int]:
return batch_size, seq_len, encoder_seq_len, embed_dim

@pytest.fixture
-def input(self, input_params: Tuple[int, int, int, int]) -> Tensor:
+def input(self, input_params: Tuple[int, int, int, int]) -> torch.Tensor:
batch_size, seq_len, encoder_seq_len, embed_dim = input_params
rand_x = torch.randn(batch_size, seq_len, embed_dim)
rand_y = torch.randn(batch_size, 128, embed_dim)
@@ -185,7 +185,7 @@ def transformer_layer(

def test_forward(
self,
-input: [Tensor, Tensor, Tensor],
+input: [torch.Tensor, torch.Tensor, torch.Tensor],
transformer_layer: TransformerSelfAttentionLayer,
) -> None:
input_x, input_y, mask = input
@@ -215,7 +215,7 @@ def input_params(self) -> Tuple[int, int, int]:
return batch_size, seq_len, vocab_size

@pytest.fixture
-def input(self, input_params: Tuple[int, int, int]) -> Tensor:
+def input(self, input_params: Tuple[int, int, int]) -> torch.Tensor:
batch_size, seq_len, vocab_size = input_params
return torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len))

@@ -234,7 +234,7 @@ def input_max_len_exceeded(
self,
input_params: Tuple[int, int, int],
decoder_params: Tuple[int, int, int, int, int, int],
-) -> Tensor:
+) -> torch.Tensor:
batch_size, seq_len, vocab_size = input_params
_, _, _, _, max_seq_len, _ = decoder_params
seq_len = max_seq_len + 1
@@ -245,7 +245,7 @@ def input_max_bs_exceeded(
self,
input_params: Tuple[int, int, int],
decoder_params: Tuple[int, int, int, int, int, int],
-) -> Tensor:
+) -> torch.Tensor:
batch_size, seq_len, vocab_size = input_params
_, _, _, _, max_seq_len, _ = decoder_params
batch_size = batch_size + 1
@@ -306,7 +306,7 @@ def decoder_with_kv_cache_enabled(

def test_forward(
self,
-input: Tensor,
+input: torch.Tensor,
input_params: Tuple[int, int, int],
decoder: TransformerDecoder,
) -> None:
@@ -318,15 +318,15 @@ def test_forward(

def test_max_seq_len_exceeded(
self,
-input_max_len_exceeded: Tensor,
+input_max_len_exceeded: torch.Tensor,
decoder: TransformerDecoder,
) -> None:
with pytest.raises(Exception):
output = decoder(input_max_len_exceeded)

def test_kv_cache(
self,
-input: Tensor,
+input: torch.Tensor,
decoder_with_kv_cache_enabled: TransformerDecoder,
decoder: TransformerDecoder,
) -> None:
@@ -340,7 +340,7 @@ def test_kv_cache(

def test_kv_cache_reset_values(
self,
-input: Tensor,
+input: torch.Tensor,
decoder_with_kv_cache_enabled: TransformerDecoder,
) -> None:
_, seq_len = input.shape
@@ -375,7 +375,7 @@ def test_kv_cache_reset_values_fails_when_not_enabled_first(

def test_kv_cache_batch_size_exceeded(
self,
-input_max_bs_exceeded: Tensor,
+input_max_bs_exceeded: torch.Tensor,
decoder_with_kv_cache_enabled: TransformerDecoder,
) -> None:
with pytest.raises(ValueError):
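The kv-cache tests above all follow the same pattern: run the tokens through a decoder with caches enabled, passing explicit positions, and compare against the cache-free output. A rough sketch of that idea with a hypothetical decoder; setup_caches and the input_pos keyword mirror the signatures shown elsewhere in this diff, but this is not the exact test code.

import torch


def check_kv_cache_consistency(decoder, tokens: torch.Tensor) -> None:
    batch_size, seq_len = tokens.shape
    with torch.no_grad():
        expected = decoder(tokens)  # full forward, no cache

    # Enable caches, then run the same tokens at explicit positions
    decoder.setup_caches(batch_size=batch_size, dtype=torch.float32)
    input_pos = torch.arange(seq_len)
    with torch.no_grad():
        cached = decoder(tokens, input_pos=input_pos)

    torch.testing.assert_close(expected, cached)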
10 changes: 5 additions & 5 deletions torchtune/models/clip/_position_embeddings.py
@@ -42,7 +42,7 @@ def __init__(self, embed_dim: int, tile_size: int, patch_size: int) -> None:
def forward(self, x: torch.Tensor, *args: Tuple[Any]) -> torch.Tensor:
"""
Args:
-x (torch.Tensor): Tensor with shape (..., n_tokens, embed_dim)
+x (torch.Tensor): torch.Tensor with shape (..., n_tokens, embed_dim)
*args (Tuple[Any]): Optional args.

Returns:
@@ -103,8 +103,8 @@ def __init__(
def forward(self, x: torch.Tensor, aspect_ratio: torch.Tensor) -> torch.Tensor:
"""
Args:
-x (torch.Tensor): Tensor with shape (bsz * n_imgs, n_tiles, n_tokens, embed_dim).
-aspect_ratio (torch.Tensor): Tensor with shape (bsz * n_imgs, 2),
+x (torch.Tensor): torch.Tensor with shape (bsz * n_imgs, n_tiles, n_tokens, embed_dim).
+aspect_ratio (torch.Tensor): torch.Tensor with shape (bsz * n_imgs, 2),
where aspect_ratio[k] represents the aspect ratio of the k^th image
of the batch before tile-cropping, e.g. aspect_ratio[k] = (2,1).
Returns:
@@ -169,8 +169,8 @@ def __init__(
def forward(self, x: torch.Tensor, aspect_ratio: torch.Tensor) -> torch.Tensor:
"""
args:
-x (torch.Tensor): Tensor with shape (bsz * n_imgs, n_tiles, n_tokens, embed_dim).
-aspect_ratio (torch.Tensor): Tensor with shape (bsz * n_imgs, 2),
+x (torch.Tensor): torch.Tensor with shape (bsz * n_imgs, n_tiles, n_tokens, embed_dim).
+aspect_ratio (torch.Tensor): torch.Tensor with shape (bsz * n_imgs, 2),
representing the aspect ratio of the image before tile-cropping, e.g. (2,1).
returns:
torch.Tensor: The input tensor with added positional embeddings.
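The first of these modules reduces to adding a learned per-token embedding to the input; a stripped-down sketch of that idea follows. The initialization and class name are illustrative, not the CLIP implementation itself.

import torch
import torch.nn as nn


class TokenPositionalEmbeddingSketch(nn.Module):
    def __init__(self, n_tokens: int, embed_dim: int) -> None:
        super().__init__()
        # One learned embedding per token position
        self.positional_embedding = nn.Parameter(torch.randn(n_tokens, embed_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (..., n_tokens, embed_dim); broadcasting adds the embeddings across leading dims
        return x + self.positional_embedding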
15 changes: 7 additions & 8 deletions torchtune/models/gemma/transformer.py
@@ -9,7 +9,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-from torch import Tensor
from torchtune.modules import KVCache

from torchtune.modules.transformer import _get_clones, TransformerSelfAttentionLayer
@@ -101,20 +100,20 @@ def setup_caches(self, batch_size: int, dtype: torch.dtype) -> None:

def forward(
self,
-tokens: Tensor,
+tokens: torch.Tensor,
*,
-mask: Optional[Tensor] = None,
-input_pos: Optional[Tensor] = None,
-) -> Tensor:
+mask: Optional[torch.Tensor] = None,
+input_pos: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
"""
Args:
-tokens (Tensor): input tensor with shape [b x s]
-mask (Optional[Tensor]): Optional boolean tensor which contains the attention mask
+tokens (torch.Tensor): input tensor with shape [b x s]
+mask (Optional[torch.Tensor]): Optional boolean tensor which contains the attention mask
with shape [b x s x s]. This is applied after the query-key multiplication and
before the softmax. A value of True in row i and column j means token i attends
to token j. A value of False means token i does not attend to token j. If no
mask is specified, a causal mask is used by default. Default is None.
-input_pos (Optional[Tensor]): Optional tensor which contains the position ids
+input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids
input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids
of each token. During training, this is used to indicate the positions
of each token relative to its sample when packed, shape [b x s].
During inference, this indicates the position of the current token.
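The docstring's mask convention (boolean, True meaning token i attends to token j, causal when omitted) and the input_pos argument are easy to construct by hand; a small sketch of what a caller might pass, with illustrative names and shapes:

import torch

seq_len = 6
# Causal boolean mask: row i is True for columns 0..i, False for later positions
mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool)).unsqueeze(0)  # [1, s, s]

# Position ids of each token within its (packed) sample during training
input_pos = torch.arange(seq_len).unsqueeze(0)  # [1, s]

# tokens = torch.randint(0, vocab_size, (1, seq_len))
# logits = model(tokens, mask=mask, input_pos=input_pos)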