CLIP Text Encoder #1969
@@ -26,6 +26,7 @@ dependencies = [
     "tqdm",
     "omegaconf",
     "psutil",
+    "ftfy",

     # Multimodal
     "Pillow>=9.4.0",
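The new ftfy dependency backs the tokenizer's text cleaning (exercised by test_text_cleaning below). As a minimal sketch of what ftfy itself does, independent of the tokenizer's full cleaning pipeline (which this diff does not show):

import ftfy

# ftfy.fix_text repairs mojibake and normalizes unicode before the text is tokenized.
print(ftfy.fix_text("(ง'⌣')ง"))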
New file: tests for the CLIP tokenizer
@@ -0,0 +1,44 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import pytest

from torchtune.models.clip._model_builders import clip_tokenizer


class TestCLIPTokenizer:
    @pytest.fixture
    def tokenizer(self):
        return clip_tokenizer()

    def test_tokenization(self, tokenizer):
        texts = [
            "a cow jumping over the moon",
            "a helpful AI assistant",
        ]
        correct_tokens = [
            [49406, 320, 9706, 11476, 962, 518, 3293, 49407],
            [49406, 320, 12695, 2215, 6799, 49407],
        ]
        for token_seq in correct_tokens:
            _pad_token_sequence(token_seq)
        tokens_tensor = tokenizer(texts)
        assert tokens_tensor.tolist() == correct_tokens

    def test_text_cleaning(self, tokenizer):
        text = "(ง'⌣')ง"
        correct_tokens = [49406, 263, 33382, 6, 17848, 96, 19175, 33382, 49407]
        tokens = tokenizer.encode(text)
        assert tokens == correct_tokens

    def test_decoding(self, tokenizer):
        text = "this is torchtune"
        decoded_text = "<|startoftext|>this is torchtune <|endoftext|>"
        assert decoded_text == tokenizer.decode(tokenizer.encode(text))


def _pad_token_sequence(tokens, max_seq_len=77, pad_token=49407):
    while len(tokens) < max_seq_len:
        tokens.append(pad_token)
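For context on the builder the tests import, a minimal usage sketch based only on what the tests above exercise (the call style and padded shape are inferred from test_tokenization and _pad_token_sequence, not from documentation):

from torchtune.models.clip._model_builders import clip_tokenizer

tokenizer = clip_tokenizer()

# Single string -> list of token ids, bracketed by <|startoftext|> (49406) and <|endoftext|> (49407)
ids = tokenizer.encode("a cow jumping over the moon")

# Batch of strings -> tensor padded to the 77-token CLIP context, i.e. shape [2, 77] here
batch = tokenizer(["a cow jumping over the moon", "a helpful AI assistant"])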
New file: HF-to-torchtune weight conversion for the CLIP text encoder
@@ -0,0 +1,48 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtune.models.convert_weights import get_mapped_key

# state dict key mappings from HF's format to torchtune's format
_FROM_HF = {
    "text_model.embeddings.token_embedding.weight": "token_embedding.weight",
    "text_model.embeddings.position_embedding.weight": "position_embedding",
    "text_model.encoder.layers.{}.layer_norm1.weight": "encoder.{}.sa_norm.weight",
    "text_model.encoder.layers.{}.layer_norm1.bias": "encoder.{}.sa_norm.bias",
    "text_model.encoder.layers.{}.layer_norm2.weight": "encoder.{}.mlp_norm.weight",
    "text_model.encoder.layers.{}.layer_norm2.bias": "encoder.{}.mlp_norm.bias",
    "text_model.encoder.layers.{}.mlp.fc1.weight": "encoder.{}.mlp.w1.weight",
    "text_model.encoder.layers.{}.mlp.fc1.bias": "encoder.{}.mlp.w1.bias",
    "text_model.encoder.layers.{}.mlp.fc2.weight": "encoder.{}.mlp.w2.weight",
    "text_model.encoder.layers.{}.mlp.fc2.bias": "encoder.{}.mlp.w2.bias",
    "text_model.encoder.layers.{}.self_attn.q_proj.weight": "encoder.{}.attn.q_proj.weight",
    "text_model.encoder.layers.{}.self_attn.q_proj.bias": "encoder.{}.attn.q_proj.bias",
    "text_model.encoder.layers.{}.self_attn.k_proj.weight": "encoder.{}.attn.k_proj.weight",
    "text_model.encoder.layers.{}.self_attn.k_proj.bias": "encoder.{}.attn.k_proj.bias",
    "text_model.encoder.layers.{}.self_attn.v_proj.weight": "encoder.{}.attn.v_proj.weight",
    "text_model.encoder.layers.{}.self_attn.v_proj.bias": "encoder.{}.attn.v_proj.bias",
    "text_model.encoder.layers.{}.self_attn.out_proj.bias": "encoder.{}.attn.output_proj.bias",
    "text_model.encoder.layers.{}.self_attn.out_proj.weight": "encoder.{}.attn.output_proj.weight",
    "text_model.final_layer_norm.weight": "final_norm.weight",
    "text_model.final_layer_norm.bias": "final_norm.bias",
}

_IGNORE = {
    "logit_scale",
    "text_model.embeddings.position_ids",
    "text_projection.weight",
    "visual_projection.weight",
}


def clip_text_hf_to_tune(state_dict):
    converted_state_dict = {}
    for key, value in state_dict.items():
        if key.startswith("vision_model.") or key in _IGNORE:
            continue
        new_key = get_mapped_key(key, _FROM_HF)
        converted_state_dict[new_key] = value
    return converted_state_dict
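The templated keys above work by abstracting the layer index: a numbered HF key is reduced to its "{}" form, looked up in _FROM_HF, and the index is substituted into the torchtune template. A hypothetical sketch of that lookup (_map_key_sketch is illustration only, not torchtune's get_mapped_key):

import re

def _map_key_sketch(key: str, mapping: dict) -> str:
    # Replace the layer index with "{}" to find the template, then format it back in.
    abstract_key = re.sub(r"\.(\d+)\.", ".{}.", key)
    layer_nums = re.findall(r"\.(\d+)\.", key)
    template = mapping[abstract_key]
    return template.format(*layer_nums) if layer_nums else template

assert (
    _map_key_sketch("text_model.encoder.layers.3.self_attn.q_proj.weight", _FROM_HF)
    == "encoder.3.attn.q_proj.weight"
)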
New file: the CLIPTextEncoder module
@@ -0,0 +1,109 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import nn, Tensor

from torchtune.modules import (
    FeedForward,
    MultiHeadAttention,
    TransformerSelfAttentionLayer,
)
from torchtune.modules.activations import QuickGELU


class CLIPTextEncoder(nn.Module):
    """
    Text encoder for CLIP. A standard transformer encoder whose MLP layers use
    the QuickGELU activation, a faster but less precise approximation of GELU
    carried over from the original CLIP implementation.

[Review comment] Mind briefly describing the architecture? Probably a normal transformer, but uses a different MLP activation?
[Author reply] It's just a normal transformer. The MLP activation is still GELU, just a faster and less precise version of it (though not actually faster these days; this is a relic of the ancient year 2021).
    Args:
        vocab_size (int): size of the vocabulary, default 49408
        max_seq_len (int): context size, default 77
        embed_dim (int): embedding/model dimension size, default 768
        num_heads (int): number of attention heads, default 12
        num_layers (int): number of transformer layers, default 12
        norm_eps (float): small value added to denominator for numerical stability, default 1e-5
    """

    def __init__(
        self,
        vocab_size: int = 49408,
        max_seq_len: int = 77,
        embed_dim: int = 768,
        num_heads: int = 12,
        num_layers: int = 12,
        norm_eps: float = 1e-5,
    ):
        super().__init__()
        self.max_seq_len = max_seq_len

        self.token_embedding = nn.Embedding(vocab_size, embed_dim)
        self.position_embedding = nn.Parameter(torch.empty(max_seq_len, embed_dim))

        self.encoder = nn.Sequential(
            *[
                TransformerSelfAttentionLayer(
                    attn=MultiHeadAttention(
                        embed_dim=embed_dim,
                        num_heads=num_heads,
                        num_kv_heads=num_heads,
                        head_dim=embed_dim // num_heads,
                        q_proj=nn.Linear(embed_dim, embed_dim),
                        k_proj=nn.Linear(embed_dim, embed_dim),
                        v_proj=nn.Linear(embed_dim, embed_dim),
                        output_proj=nn.Linear(embed_dim, embed_dim),
                    ),
                    mlp=FeedForward(
                        gate_proj=nn.Linear(embed_dim, embed_dim * 4),
                        down_proj=nn.Linear(embed_dim * 4, embed_dim),
                        activation=QuickGELU(),
                    ),
                    sa_norm=nn.LayerNorm(embed_dim, eps=norm_eps),
                    mlp_norm=nn.LayerNorm(embed_dim, eps=norm_eps),
                )
                for _ in range(num_layers)
            ]
        )

        self.final_norm = nn.LayerNorm(embed_dim, eps=norm_eps)

    def forward(self, tokens: Tensor) -> Tensor:
        """
        Args:
            tokens (Tensor): input tensor with shape ``[b x s]``

        Returns:
            Tensor: output tensor with shape ``[b x d]``

        Raises:
            ValueError: if seq_len of tokens is bigger than max_seq_len

        Shape notation:
            - b: batch size
            - s: token sequence length
            - d: token embed dim
        """
        # Input validation
        bsz, seq_len = tokens.shape
        if seq_len > self.max_seq_len:
            raise ValueError(
                f"seq_len ({seq_len}) of input tensor should not exceed "
                f"max_seq_len ({self.max_seq_len})"
            )

        # Input embedding [b, s] -> [b, s, d]
        x = self.token_embedding(tokens) + self.position_embedding

        # Encoder [b, s, d] -> [b, s, d]
        x = self.encoder(x)
        x = self.final_norm(x)

        # Select the output of the EOS token for each encoding in the batch.
        # The EOS token has the largest id in the vocabulary, so argmax over the
        # raw token ids returns the position of its first occurrence per sequence.
        # [b, s, d] -> [b, d]
        eos_token_positions = tokens.argmax(dim=-1)
        x = x[torch.arange(bsz, device=x.device), eos_token_positions]

        return x
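Putting the three new pieces together, a minimal end-to-end sketch. Assumptions are flagged inline: the import paths for CLIPTextEncoder and clip_text_hf_to_tune are not shown in this diff, the HF checkpoint name is only an example (ViT-L/14's text tower happens to match the defaults above: 768-dim, 12 layers, 12 heads), and exact key compatibility can vary with the transformers version.

import torch
from transformers import CLIPModel  # assumption: HF transformers is installed

from torchtune.models.clip._model_builders import clip_tokenizer         # shown in the tests above
from torchtune.models.clip._text_encoder import CLIPTextEncoder          # assumed module path
from torchtune.models.clip._convert_weights import clip_text_hf_to_tune  # assumed module path

# 1. Grab an HF CLIP checkpoint and remap its text-tower keys to torchtune names.
hf_state_dict = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").state_dict()
model = CLIPTextEncoder()
model.load_state_dict(clip_text_hf_to_tune(hf_state_dict))

# 2. Tokenize. The tokenizer pads every sequence to max_seq_len=77, which is
#    what the fixed-size position embedding table expects.
tokenizer = clip_tokenizer()
tokens = tokenizer(["a cow jumping over the moon"])  # [1, 77]

# 3. Encode. The output is the final-norm hidden state at the EOS position,
#    i.e. one 768-dim embedding per input text. (QuickGELU, used in the MLPs,
#    is the approximation x * sigmoid(1.702 * x).)
with torch.no_grad():
    text_embedding = model(tokens)  # [1, 768]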