1 change: 1 addition & 0 deletions pyproject.toml
@@ -26,6 +26,7 @@ dependencies = [
"tqdm",
"omegaconf",
"psutil",
"ftfy",

# Multimodal
"Pillow>=9.4.0",
44 changes: 44 additions & 0 deletions tests/torchtune/models/clip/test_clip_tokenizer.py
@@ -0,0 +1,44 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import pytest

from torchtune.models.clip._model_builders import clip_tokenizer


class TestCLIPTokenizer:
@pytest.fixture
def tokenizer(self):
return clip_tokenizer()

def test_tokenization(self, tokenizer):
texts = [
"a cow jumping over the moon",
"a helpful AI assistant",
]
correct_tokens = [
[49406, 320, 9706, 11476, 962, 518, 3293, 49407],
[49406, 320, 12695, 2215, 6799, 49407],
]
for token_seq in correct_tokens:
_pad_token_sequence(token_seq)
tokens_tensor = tokenizer(texts)
assert tokens_tensor.tolist() == correct_tokens

def test_text_cleaning(self, tokenizer):
text = "(ง'⌣')ง"
correct_tokens = [49406, 263, 33382, 6, 17848, 96, 19175, 33382, 49407]
tokens = tokenizer.encode(text)
assert tokens == correct_tokens

def test_decoding(self, tokenizer):
text = "this is torchtune"
decoded_text = "<|startoftext|>this is torchtune <|endoftext|>"
assert decoded_text == tokenizer.decode(tokenizer.encode(text))


def _pad_token_sequence(tokens, max_seq_len=77, pad_token=49407):
while len(tokens) < max_seq_len:
tokens.append(pad_token)
48 changes: 48 additions & 0 deletions torchtune/models/clip/_convert_weights.py
@@ -0,0 +1,48 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtune.models.convert_weights import get_mapped_key

# state dict key mappings from HF's format to torchtune's format
_FROM_HF = {
"text_model.embeddings.token_embedding.weight": "token_embedding.weight",
"text_model.embeddings.position_embedding.weight": "position_embedding",
"text_model.encoder.layers.{}.layer_norm1.weight": "encoder.{}.sa_norm.weight",
"text_model.encoder.layers.{}.layer_norm1.bias": "encoder.{}.sa_norm.bias",
"text_model.encoder.layers.{}.layer_norm2.weight": "encoder.{}.mlp_norm.weight",
"text_model.encoder.layers.{}.layer_norm2.bias": "encoder.{}.mlp_norm.bias",
"text_model.encoder.layers.{}.mlp.fc1.weight": "encoder.{}.mlp.w1.weight",
"text_model.encoder.layers.{}.mlp.fc1.bias": "encoder.{}.mlp.w1.bias",
"text_model.encoder.layers.{}.mlp.fc2.weight": "encoder.{}.mlp.w2.weight",
"text_model.encoder.layers.{}.mlp.fc2.bias": "encoder.{}.mlp.w2.bias",
"text_model.encoder.layers.{}.self_attn.q_proj.weight": "encoder.{}.attn.q_proj.weight",
"text_model.encoder.layers.{}.self_attn.q_proj.bias": "encoder.{}.attn.q_proj.bias",
"text_model.encoder.layers.{}.self_attn.k_proj.weight": "encoder.{}.attn.k_proj.weight",
"text_model.encoder.layers.{}.self_attn.k_proj.bias": "encoder.{}.attn.k_proj.bias",
"text_model.encoder.layers.{}.self_attn.v_proj.weight": "encoder.{}.attn.v_proj.weight",
"text_model.encoder.layers.{}.self_attn.v_proj.bias": "encoder.{}.attn.v_proj.bias",
"text_model.encoder.layers.{}.self_attn.out_proj.bias": "encoder.{}.attn.output_proj.bias",
"text_model.encoder.layers.{}.self_attn.out_proj.weight": "encoder.{}.attn.output_proj.weight",
"text_model.final_layer_norm.weight": "final_norm.weight",
"text_model.final_layer_norm.bias": "final_norm.bias",
}

_IGNORE = {
"logit_scale",
"text_model.embeddings.position_ids",
"text_projection.weight",
"visual_projection.weight",
}


def clip_text_hf_to_tune(state_dict):
converted_state_dict = {}
for key, value in state_dict.items():
if key.startswith("vision_model.") or key in _IGNORE:
continue
new_key = get_mapped_key(key, _FROM_HF)
converted_state_dict[new_key] = value
return converted_state_dict
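
As a usage sketch (not part of this diff), the converter could be applied to a Hugging Face checkpoint roughly as follows, assuming the transformers package and the openai/clip-vit-large-patch14 checkpoint:

from transformers import CLIPModel

from torchtune.models.clip._convert_weights import clip_text_hf_to_tune
from torchtune.models.clip._model_builders import clip_text_encoder_large

# Convert HF's text_model.* keys to torchtune's naming; vision and projection keys are dropped
hf_state_dict = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").state_dict()
converted = clip_text_hf_to_tune(hf_state_dict)

encoder = clip_text_encoder_large()
encoder.load_state_dict(converted)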
61 changes: 60 additions & 1 deletion torchtune/models/clip/_model_builders.py
@@ -1,4 +1,63 @@
from torchtune.models.clip._transforms import CLIPImageTransform
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from pathlib import Path

from torchtune.models.clip._text_encoder import CLIPTextEncoder
from torchtune.models.clip._tokenizer import CLIPTokenizer
from torchtune.models.clip._transform import CLIPImageTransform
from torchtune.utils._download import TORCHTUNE_LOCAL_CACHE_FOLDER, download_file

CLIP_VOCAB_URL = "https://github.com/openai/CLIP/raw/refs/heads/main/clip/bpe_simple_vocab_16e6.txt.gz"


def clip_tokenizer(
vocab_path: Path = TORCHTUNE_LOCAL_CACHE_FOLDER / "clip_vocab.txt.gz",
download_if_missing: bool = True,
max_seq_len: int = 77,
truncate: bool = True,
) -> CLIPTokenizer:
"""
Builder for the CLIP text tokenizer.
Args:
vocab_path (pathlib.Path): Path to the CLIP vocab file
Default: '~/.cache/torchtune/clip_vocab.txt.gz'
download_if_missing (bool): Download the vocab file if it's not found
Default: True
max_seq_len (int): Context length
Default: 77
truncate (bool): Truncate the token sequence if it exceeds max_seq_len (otherwise raises AssertionError)
Default: True
Returns:
CLIPTokenizer: Instantiation of the CLIP text tokenizer
"""
if not vocab_path.exists():
assert download_if_missing, f"Missing CLIP tokenizer vocab: {vocab_path}"
download_file(CLIP_VOCAB_URL, vocab_path)

return CLIPTokenizer(vocab_path, max_seq_len=max_seq_len, truncate=truncate)
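
A minimal usage sketch, mirroring the tokenizer unit tests earlier in this diff (the batch call pads each sequence to max_seq_len with the end-of-text id):

from torchtune.models.clip._model_builders import clip_tokenizer

tokenizer = clip_tokenizer()  # downloads the BPE vocab on first use if it is missing

# Single string -> list of token ids wrapped in <|startoftext|> / <|endoftext|>
tokens = tokenizer.encode("a cow jumping over the moon")

# Batch of strings -> [batch_size, 77] tensor of token ids
batch = tokenizer(["a cow jumping over the moon", "a helpful AI assistant"])

# Round-trip back to text (special tokens are included in the decoded string)
text = tokenizer.decode(tokens)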


def clip_text_encoder_large() -> CLIPTextEncoder:
"""
Builder for the CLIP text encoder for CLIP-ViT-L/14.
Returns:
CLIPTextEncoder: Instantiation of the CLIP text encoder
"""
return CLIPTextEncoder(
vocab_size=49408,
max_seq_len=77,
embed_dim=768,
num_heads=12,
num_layers=12,
norm_eps=1e-5,
)


def clip_vit_224_transform():
image_transform = CLIPImageTransform(
109 changes: 109 additions & 0 deletions torchtune/models/clip/_text_encoder.py
@@ -0,0 +1,109 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import nn, Tensor

from torchtune.modules import (
FeedForward,
MultiHeadAttention,
TransformerSelfAttentionLayer,
)
from torchtune.modules.activations import QuickGELU


class CLIPTextEncoder(nn.Module):
"""
Text encoder for CLIP.
Review comment (Collaborator): Mind briefly describing the architecture? Probably a normal transformer, but uses a different MLP activation?

Reply (@calvinpelletier, Contributor Author, Nov 20, 2024): It's just a normal transformer. The MLP activation is still GELU, just a faster and less precise approximation of it (though not actually faster these days; this is a relic of the ancient year 2021).

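For reference, the faster, less precise GELU mentioned above is the sigmoid approximation from the original CLIP codebase, commonly written as QuickGELU(x) = x * sigmoid(1.702 * x); torchtune's QuickGELU module is assumed to match this definition.
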
Args:
vocab_size (int): size of the vocabulary, default 49408
max_seq_len (int): context size, default 77
embed_dim (int): embedding/model dimension size, default 768
num_heads (int): number of attention heads, default 12
num_layers (int): number of transformer layers, default 12
norm_eps (float): small value added to denominator for numerical stability, default 1e-5
"""

def __init__(
self,
vocab_size: int = 49408,
max_seq_len: int = 77,
embed_dim: int = 768,
num_heads: int = 12,
num_layers: int = 12,
norm_eps: float = 1e-5,
):
super().__init__()
self.max_seq_len = max_seq_len

self.token_embedding = nn.Embedding(vocab_size, embed_dim)
self.position_embedding = nn.Parameter(torch.empty(max_seq_len, embed_dim))

self.encoder = nn.Sequential(
*[
TransformerSelfAttentionLayer(
attn=MultiHeadAttention(
embed_dim=embed_dim,
num_heads=num_heads,
num_kv_heads=num_heads,
head_dim=embed_dim // num_heads,
q_proj=nn.Linear(embed_dim, embed_dim),
k_proj=nn.Linear(embed_dim, embed_dim),
v_proj=nn.Linear(embed_dim, embed_dim),
output_proj=nn.Linear(embed_dim, embed_dim),
),
mlp=FeedForward(
gate_proj=nn.Linear(embed_dim, embed_dim * 4),
down_proj=nn.Linear(embed_dim * 4, embed_dim),
activation=QuickGELU(),
),
sa_norm=nn.LayerNorm(embed_dim, eps=norm_eps),
mlp_norm=nn.LayerNorm(embed_dim, eps=norm_eps),
)
for _ in range(num_layers)
]
)

self.final_norm = nn.LayerNorm(embed_dim, eps=norm_eps)

def forward(self, tokens: Tensor) -> Tensor:
"""
Args:
tokens (Tensor): input tensor with shape ``[b x s]``
Returns:
Tensor: output tensor with shape ``[b x d]``
Raises:
ValueError: if seq_len of tokens exceeds max_seq_len
Shape notation:
- b: batch size
- s: token sequence length
- d: token embed dim
"""
# Input validation
bsz, seq_len = tokens.shape
if seq_len > self.max_seq_len:
raise ValueError(
f"seq_len ({seq_len}) of input tensor should be smaller "
f"than max_seq_len ({self.max_seq_len})"
)

# Input embedding [b, s] -> [b, s, d]
x = self.token_embedding(tokens) + self.position_embedding

# Encoder [b, s, d] -> [b, s, d]
x = self.encoder(x)
x = self.final_norm(x)

# Select the output of the EOS token for each encoding in the batch
# [b, s, d] -> [b, d]
eos_token_positions = tokens.argmax(dim=-1)
x = x[torch.arange(bsz, device=x.device), eos_token_positions]

return x
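
Putting the pieces together, a minimal end-to-end sketch (illustrative only; the encoder weights are untrained until a converted checkpoint is loaded):

import torch

from torchtune.models.clip._model_builders import clip_text_encoder_large, clip_tokenizer

tokenizer = clip_tokenizer()
encoder = clip_text_encoder_large()  # in practice, load converted HF weights first (see _convert_weights.py)

# [batch_size, 77] token ids -> [batch_size, 768] embeddings taken at each sequence's EOS token
tokens = tokenizer(["a cow jumping over the moon", "a helpful AI assistant"])
with torch.no_grad():
    embeddings = encoder(tokens)
print(embeddings.shape)  # torch.Size([2, 768])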