From f40d879c2ed6fbcf749fb55e58157bce58435af2 Mon Sep 17 00:00:00 2001 From: Calvin Pelletier Date: Thu, 7 Nov 2024 13:16:20 -0800 Subject: [PATCH 01/16] CLIP tokenizer and text encoder --- pyproject.toml | 1 + .../models/clip/test_clip_tokenizer.py | 193 ++++++++++++++++ torchtune/models/clip/_convert_weights.py | 49 ++++ torchtune/models/clip/_model_builders.py | 57 ++++- torchtune/models/clip/_text_encoder.py | 109 +++++++++ torchtune/models/clip/_tokenizer.py | 210 ++++++++++++++++++ torchtune/modules/activations.py | 17 ++ .../training/checkpointing/_checkpointer.py | 5 + torchtune/training/checkpointing/_utils.py | 2 + torchtune/utils/_download.py | 49 ++++ 10 files changed, 691 insertions(+), 1 deletion(-) create mode 100644 tests/torchtune/models/clip/test_clip_tokenizer.py create mode 100644 torchtune/models/clip/_convert_weights.py create mode 100644 torchtune/models/clip/_text_encoder.py create mode 100644 torchtune/models/clip/_tokenizer.py create mode 100644 torchtune/modules/activations.py create mode 100644 torchtune/utils/_download.py diff --git a/pyproject.toml b/pyproject.toml index c2920ff4d3..45afd3c9cb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ dependencies = [ "tqdm", "omegaconf", "psutil", + "ftfy", # Multimodal "Pillow>=9.4.0", diff --git a/tests/torchtune/models/clip/test_clip_tokenizer.py b/tests/torchtune/models/clip/test_clip_tokenizer.py new file mode 100644 index 0000000000..463828d1cb --- /dev/null +++ b/tests/torchtune/models/clip/test_clip_tokenizer.py @@ -0,0 +1,193 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import pytest + +from torchtune.models.clip._model_builders import clip_tokenizer + + +class TestCLIPTokenizer: + @pytest.fixture + def tokenizer(self): + return clip_tokenizer() + + def test_tokenization(self, tokenizer): + texts = [ + "a cow jumping over the moon", + "a helpful AI assistant", + ] + correct_tokens = [ + [ + 49406, + 320, + 9706, + 11476, + 962, + 518, + 3293, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + ], + [ + 49406, + 320, + 12695, + 2215, + 6799, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + 49407, + ], + ] + tokens_tensor = tokenizer(texts) + assert tokens_tensor.tolist() == 
correct_tokens + + def test_text_cleaning(self, tokenizer): + text = "(ง'⌣')ง" + correct_tokens = [49406, 263, 33382, 6, 17848, 96, 19175, 33382, 49407] + tokens = tokenizer.encode(text) + assert tokens == correct_tokens + + def test_decoding(self, tokenizer): + text = "this is torchtune" + decoded_text = "<|startoftext|>this is torchtune <|endoftext|>" + assert decoded_text == tokenizer.decode(tokenizer.encode(text)) diff --git a/torchtune/models/clip/_convert_weights.py b/torchtune/models/clip/_convert_weights.py new file mode 100644 index 0000000000..99235549c2 --- /dev/null +++ b/torchtune/models/clip/_convert_weights.py @@ -0,0 +1,49 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from torchtune.models.convert_weights import get_mapped_key + +# state dict key mappings from HF's format to torchtune's format +_FROM_HF = { + "text_model.embeddings.token_embedding.weight": "token_embedding.weight", + "text_model.embeddings.position_embedding.weight": "position_embedding", + "text_model.encoder.layers.{}.layer_norm1.weight": "encoder.{}.sa_norm.weight", + "text_model.encoder.layers.{}.layer_norm1.bias": "encoder.{}.sa_norm.bias", + "text_model.encoder.layers.{}.layer_norm2.weight": "encoder.{}.mlp_norm.weight", + "text_model.encoder.layers.{}.layer_norm2.bias": "encoder.{}.mlp_norm.bias", + "text_model.encoder.layers.{}.mlp.fc1.weight": "encoder.{}.mlp.w1.weight", + "text_model.encoder.layers.{}.mlp.fc1.bias": "encoder.{}.mlp.w1.bias", + "text_model.encoder.layers.{}.mlp.fc2.weight": "encoder.{}.mlp.w2.weight", + "text_model.encoder.layers.{}.mlp.fc2.bias": "encoder.{}.mlp.w2.bias", + "text_model.encoder.layers.{}.self_attn.q_proj.weight": "encoder.{}.attn.q_proj.weight", + "text_model.encoder.layers.{}.self_attn.q_proj.bias": "encoder.{}.attn.q_proj.bias", + "text_model.encoder.layers.{}.self_attn.k_proj.weight": "encoder.{}.attn.k_proj.weight", + "text_model.encoder.layers.{}.self_attn.k_proj.bias": "encoder.{}.attn.k_proj.bias", + "text_model.encoder.layers.{}.self_attn.v_proj.weight": "encoder.{}.attn.v_proj.weight", + "text_model.encoder.layers.{}.self_attn.v_proj.bias": "encoder.{}.attn.v_proj.bias", + "text_model.encoder.layers.{}.self_attn.out_proj.bias": "encoder.{}.attn.output_proj.bias", + "text_model.encoder.layers.{}.self_attn.out_proj.weight": "encoder.{}.attn.output_proj.weight", + "text_model.final_layer_norm.weight": "final_norm.weight", + "text_model.final_layer_norm.bias": "final_norm.bias", +} + +_IGNORE = { + "logit_scale", + "text_model.embeddings.position_ids", + "text_projection.weight", + "visual_projection.weight", +} + + +def clip_text_hf_to_tune(state_dict): + converted_state_dict = {} + for key, value in state_dict.items(): + # print(key) + if key.startswith("vision_model.") or key in _IGNORE: + continue + new_key = get_mapped_key(key, _FROM_HF) + converted_state_dict[new_key] = value + return converted_state_dict diff --git a/torchtune/models/clip/_model_builders.py b/torchtune/models/clip/_model_builders.py index 61c01d1c51..7d7aee1891 100644 --- a/torchtune/models/clip/_model_builders.py +++ b/torchtune/models/clip/_model_builders.py @@ -1,4 +1,59 @@ -from torchtune.models.clip._transforms import CLIPImageTransform +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from torchtune.models.clip._transform import CLIPImageTransform +from pathlib import Path + +from torchtune.utils._download import download_file, TORCHTUNE_LOCAL_CACHE_FOLDER +from torchtune.models.clip._tokenizer import CLIPTokenizer +from torchtune.models.clip._text_encoder import CLIPTextEncoder + + +CLIP_VOCAB_URL = 'https://github.com/openai/CLIP/raw/refs/heads/main/clip/bpe_simple_vocab_16e6.txt.gz' + + +def clip_tokenizer(vocab_path: Path = TORCHTUNE_LOCAL_CACHE_FOLDER / 'clip_vocab.txt.gz', download_if_missing: bool = True, max_seq_len: int = 77, truncate: bool = True) -> CLIPTokenizer: + """ + Builder for the CLIP text tokenizer. + + Args: + vocab_path (pathlib.Path): Path to the CLIP vocab file + Default: '~/.cache/torchtune/clip_vocab.txt.gz' + download_if_missing (bool): Download the vocab file if it's not found + Default: True + max_seq_len (bool): Context length + Default: 77 + truncate (bool): Truncate the token sequence if it exceeds max_seq_len (otherwise raises AssertionError) + Default: True + + Returns: + CLIPTokenizer: Instantiation of the CLIP text tokenizer + """ + if not vocab_path.exists(): + assert download_if_missing, f'Missing CLIP tokenizer vocab: {vocab_path}' + download_file(CLIP_VOCAB_URL, vocab_path) + + return CLIPTokenizer(vocab_path, max_seq_len=max_seq_len, truncate=truncate) + + +def clip_text_encoder_large() -> CLIPTextEncoder: + """ + Builder for the CLIP text encoder for CLIP-ViT-L/14. + + Returns: + CLIPTextEncoder: Instantiation of the CLIP text encoder + """ + return CLIPTextEncoder( + vocab_size=49408, + max_seq_len=77, + embed_dim=768, + num_heads=12, + num_layers=12, + norm_eps=1e-5, + ) + def clip_vit_224_transform(): image_transform = CLIPImageTransform( diff --git a/torchtune/models/clip/_text_encoder.py b/torchtune/models/clip/_text_encoder.py new file mode 100644 index 0000000000..b08ab89a5d --- /dev/null +++ b/torchtune/models/clip/_text_encoder.py @@ -0,0 +1,109 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import nn, Tensor + +from torchtune.modules import ( + FeedForward, + MultiHeadAttention, + TransformerSelfAttentionLayer, +) +from torchtune.modules.activations import QuickGELU + + +class CLIPTextEncoder(nn.Module): + """ + Text encoder for CLIP. 
+ + Args: + vocab_size (int): size of the vocabulary, default 49408 + max_seq_len (int): context size, default 77 + embed_dim (int): embedding/model dimension size, default 768 + num_heads (int): number of attention heads, default 12 + num_layers (int): number of transformer layers, default 12 + norm_eps (float): small value added to denominator for numerical stability, default 1e-5 + """ + + def __init__( + self, + vocab_size: int = 49408, + max_seq_len: int = 77, + embed_dim: int = 768, + num_heads: int = 12, + num_layers: int = 12, + norm_eps: float = 1e-5, + ): + super().__init__() + self.max_seq_len = max_seq_len + + self.token_embedding = nn.Embedding(vocab_size, embed_dim) + self.position_embedding = nn.Parameter(torch.empty(max_seq_len, embed_dim)) + + self.encoder = nn.Sequential( + *[ + TransformerSelfAttentionLayer( + attn=MultiHeadAttention( + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_heads, + head_dim=embed_dim // num_heads, + q_proj=nn.Linear(embed_dim, embed_dim), + k_proj=nn.Linear(embed_dim, embed_dim), + v_proj=nn.Linear(embed_dim, embed_dim), + output_proj=nn.Linear(embed_dim, embed_dim), + ), + mlp=FeedForward( + gate_proj=nn.Linear(embed_dim, embed_dim * 4), + down_proj=nn.Linear(embed_dim * 4, embed_dim), + activation=QuickGELU(), + ), + sa_norm=nn.LayerNorm(embed_dim, eps=norm_eps), + mlp_norm=nn.LayerNorm(embed_dim, eps=norm_eps), + ) + for _ in range(num_layers) + ] + ) + + self.final_norm = nn.LayerNorm(embed_dim, eps=norm_eps) + + def forward(self, tokens: Tensor) -> Tensor: + """ + Args: + tokens (Tensor): input tensor with shape ``[b x s]`` + + Returns: + Tensor: output tensor with shape [b x d] + + Raises: + ValueError: if seq_len of tokens is bigger than max_seq_len + + Shape notation: + - b: batch size + - s: token sequence length + - d: token embed dim + """ + # Input validation + bsz, seq_len = tokens.shape + if seq_len > self.max_seq_len: + raise ValueError( + f"seq_len ({seq_len}) of input tensor should be smaller " + f"than max_seq_len ({self.max_seq_len})" + ) + + # Input embedding [b, s] -> [b, s, d] + x = self.token_embedding(tokens) + self.position_embedding + + # Encoder [b, s, d] -> [b, s, d] + x = self.encoder(x) + x = self.final_norm(x) + + # Select the output of the EOS token for each encoding in the batch + # [b, s, d] -> [b, d] + eos_token_positions = tokens.argmax(dim=-1) + x = x[torch.arange(bsz, device=x.device), eos_token_positions] + + return x diff --git a/torchtune/models/clip/_tokenizer.py b/torchtune/models/clip/_tokenizer.py new file mode 100644 index 0000000000..320eff1808 --- /dev/null +++ b/torchtune/models/clip/_tokenizer.py @@ -0,0 +1,210 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import gzip +import html +from pathlib import Path +from typing import List + +import ftfy +import regex as re +import torch +from torchtune.modules.tokenizers._utils import BaseTokenizer + + +class CLIPTokenizer(BaseTokenizer): + """ + Text tokenizer for CLIP. 
+ + Based on the official implementation here: + https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py + + Args: + vocab_path (Path): the path to the CLIP vocab file + max_seq_len (int): the context length (all CLIP models use 77) + truncate (bool): whether to truncate the text when longer than max_seq_len + """ + + def __init__(self, vocab_path: Path, max_seq_len: int = 77, truncate: bool = True): + self.max_seq_len = max_seq_len + self.truncate = truncate + + self.byte_encoder = _bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + + with gzip.open(vocab_path) as f: + merges = f.read().decode("utf-8").split("\n") + merges = [tuple(merge.split()) for merge in merges[1:48895]] + + vocab = list(self.byte_encoder.values()) + vocab.extend([v + "" for v in vocab]) + vocab.extend(["".join(merge) for merge in merges]) + vocab.extend(["<|startoftext|>", "<|endoftext|>"]) + + self.encoder = {word: i for i, word in enumerate(vocab)} + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = {merge: i for i, merge in enumerate(merges)} + + self.pat = re.compile( + r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", + re.IGNORECASE, + ) + + self.sot_token = self.encoder["<|startoftext|>"] + self.eot_token = self.encoder["<|endoftext|>"] + self.pad_token = self.eot_token + + self.cache = { + "<|startoftext|>": "<|startoftext|>", + "<|endoftext|>": "<|endoftext|>", + } + + def __call__(self, texts: List[str]) -> torch.Tensor: + """ + Returns a Tensor with the tokenized representation of given input strings + + Args: + texts (List[str]): list of input strings to tokenize + + Returns: + torch.Tensor: int tensor with shape [len(texts), max_seq_len] + """ + assert isinstance(texts, list) + result = torch.full( + (len(texts), self.max_seq_len), self.pad_token, dtype=torch.int + ) + for i, text in enumerate(texts): + tokens = self.encode(text) + result[i, : len(tokens)] = torch.tensor(tokens) + return result + + def encode(self, text: str) -> List[int]: + """ + Given a string, return the encoded list of token ids. + + Args: + text (str): The text to encode. + + Returns: + List[int]: The encoded list of token ids. + """ + text = _clean_text(text).lower() + + tokens = [self.sot_token] + for token in re.findall(self.pat, text): + token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) + tokens.extend( + self.encoder[bpe_token] for bpe_token in self._bpe(token).split(" ") + ) + if len(tokens) >= self.max_seq_len: + break + tokens.append(self.eot_token) + + if len(tokens) > self.max_seq_len: + assert self.truncate, ( + "Tokenized text is larger than the maximum sequence length but " + "truncate is set to False." + ) + tokens = tokens[: self.max_seq_len] + tokens[-1] = self.eot_token + + return tokens + + def decode(self, tokens: List[int]) -> str: + """ + Given a list of token ids, return the decoded text, optionally including special tokens. + + Args: + tokens (List[int]): The list of token ids to decode. + + Returns: + str: The decoded text. 
+ """ + text = "".join([self.decoder[token] for token in tokens]) + return ( + bytearray([self.byte_decoder[c] for c in text]) + .decode("utf-8", errors="replace") + .replace("", " ") + ) + + def _bpe(self, token): + if token in self.cache: + return self.cache[token] + + word = tuple(token[:-1]) + (token[-1] + "",) + pairs = _get_pairs(word) + + if not pairs: + return token + "" + + while True: + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except ValueError: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + word = tuple(new_word) + if len(word) == 1: + break + pairs = _get_pairs(word) + + word = " ".join(word) + self.cache[token] = word + return word + + +def _bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + """ + bs = ( + list(range(ord("!"), ord("~") + 1)) + + list(range(ord("¡"), ord("¬") + 1)) + + list(range(ord("®"), ord("ÿ") + 1)) + ) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def _get_pairs(word): + """ + Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def _clean_text(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)).strip() + return re.sub(r"\s+", " ", text).strip() diff --git a/torchtune/modules/activations.py b/torchtune/modules/activations.py new file mode 100644 index 0000000000..2b43908df0 --- /dev/null +++ b/torchtune/modules/activations.py @@ -0,0 +1,17 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import nn + + +class QuickGELU(nn.Module): + """ + Fast approximation of GELU. 
+ """ + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x * torch.sigmoid(1.702 * x) diff --git a/torchtune/training/checkpointing/_checkpointer.py b/torchtune/training/checkpointing/_checkpointer.py index 19141d626a..238b9d0718 100644 --- a/torchtune/training/checkpointing/_checkpointer.py +++ b/torchtune/training/checkpointing/_checkpointer.py @@ -16,6 +16,7 @@ from torchtune import training from torchtune.models import convert_weights +from torchtune.models.clip._convert_weights import clip_text_hf_to_tune from torchtune.models.phi3._convert_weights import phi3_hf_to_tune, phi3_tune_to_hf from torchtune.models.qwen2._convert_weights import qwen2_hf_to_tune, qwen2_tune_to_hf from torchtune.rlhf.utils import reward_hf_to_tune, reward_tune_to_hf @@ -488,6 +489,10 @@ def load_checkpoint(self) -> Dict[str, Any]: "supported_aspect_ratios", None ), ) + elif self._model_type == ModelType.CLIP_TEXT: + converted_state_dict[training.MODEL_KEY] = clip_text_hf_to_tune( + merged_state_dict, + ) else: converted_state_dict[training.MODEL_KEY] = convert_weights.hf_to_tune( merged_state_dict, diff --git a/torchtune/training/checkpointing/_utils.py b/torchtune/training/checkpointing/_utils.py index 2d353b007c..0515cba668 100644 --- a/torchtune/training/checkpointing/_utils.py +++ b/torchtune/training/checkpointing/_utils.py @@ -55,6 +55,7 @@ class ModelType(Enum): to a single class for reward modelling. See :func:`~torchtune.models.mistral.mistral_reward_7b` or :func:`~torchtune.models.llama2.llama2_reward_7b` QWEN2 (str): Qwen2 family of models. See :func:`~torchtune.models.qwen2.qwen2` + CLIP_TEXT (str): CLIP text encoder. See :func:`~torchtune.models.clip.clip_text_encoder_large` Example: >>> # Usage in a checkpointer class @@ -73,6 +74,7 @@ class ModelType(Enum): PHI3_MINI: str = "phi3_mini" REWARD: str = "reward" QWEN2: str = "qwen2" + CLIP_TEXT: str = "clip_text" class FormattedCheckpointFiles: diff --git a/torchtune/utils/_download.py b/torchtune/utils/_download.py new file mode 100644 index 0000000000..8cc5ef0344 --- /dev/null +++ b/torchtune/utils/_download.py @@ -0,0 +1,49 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from pathlib import Path + +import requests +from tqdm import tqdm + + +TORCHTUNE_LOCAL_CACHE_FOLDER = Path("~/.cache/torchtune").expanduser() + + +def download_file(url: str, save_path: Path, chunk_size: int = 8192): + """ + Download a file with progress bar. 
+ + Args: + url (str): URL of the file to download + save_path (Path): Path where the file should be saved + chunk_size (int): Size of chunks to download at a time + + Raises: + requests.RequestException: if download fails + """ + if save_path.parent is not None and not save_path.parent.exists(): + save_path.parent.mkdir(parents=True) + + try: + response = requests.get(url, stream=True) + response.raise_for_status() + + total_size = int(response.headers.get("content-length", 0)) + + with open(save_path, "wb") as f, tqdm( + desc=save_path.name, + total=total_size, + unit="iB", + unit_scale=True, + unit_divisor=1024, + ) as progress_bar: + for data in response.iter_content(chunk_size=chunk_size): + size = f.write(data) + progress_bar.update(size) + + except requests.RequestException as e: + print(f"Error downloading file: {e}") + raise From 0a070af0c7961cd2cd39000222bf4f9f70e00ba8 Mon Sep 17 00:00:00 2001 From: Calvin Pelletier Date: Thu, 7 Nov 2024 14:48:09 -0800 Subject: [PATCH 02/16] formatting --- .../models/clip/test_clip_tokenizer.py | 167 +----------------- torchtune/models/clip/_convert_weights.py | 1 - torchtune/models/clip/_model_builders.py | 20 ++- torchtune/models/clip/_tokenizer.py | 1 + torchtune/utils/_download.py | 1 - 5 files changed, 22 insertions(+), 168 deletions(-) diff --git a/tests/torchtune/models/clip/test_clip_tokenizer.py b/tests/torchtune/models/clip/test_clip_tokenizer.py index 463828d1cb..9e210193a3 100644 --- a/tests/torchtune/models/clip/test_clip_tokenizer.py +++ b/tests/torchtune/models/clip/test_clip_tokenizer.py @@ -19,165 +19,11 @@ def test_tokenization(self, tokenizer): "a helpful AI assistant", ] correct_tokens = [ - [ - 49406, - 320, - 9706, - 11476, - 962, - 518, - 3293, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - ], - [ - 49406, - 320, - 12695, - 2215, - 6799, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - 49407, - ], + [49406, 320, 9706, 11476, 962, 518, 3293, 49407], + [49406, 320, 12695, 2215, 6799, 49407], ] + for token_seq in correct_tokens: + _pad_token_sequence(token_seq) tokens_tensor = tokenizer(texts) assert tokens_tensor.tolist() == correct_tokens @@ -191,3 +37,8 @@ def test_decoding(self, tokenizer): text = "this is torchtune" decoded_text = "<|startoftext|>this is torchtune <|endoftext|>" assert decoded_text == tokenizer.decode(tokenizer.encode(text)) + + +def 
_pad_token_sequence(tokens, max_seq_len=77, pad_token=49407): + while len(tokens) < max_seq_len: + tokens.append(pad_token) diff --git a/torchtune/models/clip/_convert_weights.py b/torchtune/models/clip/_convert_weights.py index 99235549c2..d1d17a11b5 100644 --- a/torchtune/models/clip/_convert_weights.py +++ b/torchtune/models/clip/_convert_weights.py @@ -41,7 +41,6 @@ def clip_text_hf_to_tune(state_dict): converted_state_dict = {} for key, value in state_dict.items(): - # print(key) if key.startswith("vision_model.") or key in _IGNORE: continue new_key = get_mapped_key(key, _FROM_HF) diff --git a/torchtune/models/clip/_model_builders.py b/torchtune/models/clip/_model_builders.py index 7d7aee1891..5296356ebb 100644 --- a/torchtune/models/clip/_model_builders.py +++ b/torchtune/models/clip/_model_builders.py @@ -3,18 +3,22 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from torchtune.models.clip._transform import CLIPImageTransform from pathlib import Path -from torchtune.utils._download import download_file, TORCHTUNE_LOCAL_CACHE_FOLDER -from torchtune.models.clip._tokenizer import CLIPTokenizer from torchtune.models.clip._text_encoder import CLIPTextEncoder +from torchtune.models.clip._tokenizer import CLIPTokenizer +from torchtune.models.clip._transform import CLIPImageTransform +from torchtune.utils._download import TORCHTUNE_LOCAL_CACHE_FOLDER, download_file - -CLIP_VOCAB_URL = 'https://github.com/openai/CLIP/raw/refs/heads/main/clip/bpe_simple_vocab_16e6.txt.gz' +CLIP_VOCAB_URL = "https://github.com/openai/CLIP/raw/refs/heads/main/clip/bpe_simple_vocab_16e6.txt.gz" -def clip_tokenizer(vocab_path: Path = TORCHTUNE_LOCAL_CACHE_FOLDER / 'clip_vocab.txt.gz', download_if_missing: bool = True, max_seq_len: int = 77, truncate: bool = True) -> CLIPTokenizer: +def clip_tokenizer( + vocab_path: Path = TORCHTUNE_LOCAL_CACHE_FOLDER / "clip_vocab.txt.gz", + download_if_missing: bool = True, + max_seq_len: int = 77, + truncate: bool = True, +) -> CLIPTokenizer: """ Builder for the CLIP text tokenizer. 
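As a rough usage sketch (not part of the diff), the builders introduced in this patch can be combined as follows; the names follow the patch-01 code above and are revised in later commits of this series:

    # Illustrative sketch only, based on the builders defined in this patch.
    import torch
    from torchtune.models.clip._model_builders import (
        clip_text_encoder_large,
        clip_tokenizer,
    )

    # Downloads the OpenAI BPE vocab to ~/.cache/torchtune on first use
    # (download_if_missing=True by default).
    tokenizer = clip_tokenizer()

    # Batch of prompts -> int tensor of shape [batch, max_seq_len=77],
    # padded with the end-of-text token.
    tokens = tokenizer(["a cow jumping over the moon", "a helpful AI assistant"])

    # Freshly constructed CLIP-ViT-L/14 text tower; pretrained weights are
    # expected to be loaded separately via the checkpointer (ModelType.CLIP_TEXT).
    encoder = clip_text_encoder_large()
    with torch.no_grad():
        embeddings = encoder(tokens)  # float tensor of shape [batch, 768]
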
@@ -32,9 +36,9 @@ def clip_tokenizer(vocab_path: Path = TORCHTUNE_LOCAL_CACHE_FOLDER / 'clip_vocab CLIPTokenizer: Instantiation of the CLIP text tokenizer """ if not vocab_path.exists(): - assert download_if_missing, f'Missing CLIP tokenizer vocab: {vocab_path}' + assert download_if_missing, f"Missing CLIP tokenizer vocab: {vocab_path}" download_file(CLIP_VOCAB_URL, vocab_path) - + return CLIPTokenizer(vocab_path, max_seq_len=max_seq_len, truncate=truncate) diff --git a/torchtune/models/clip/_tokenizer.py b/torchtune/models/clip/_tokenizer.py index 320eff1808..a2035bc228 100644 --- a/torchtune/models/clip/_tokenizer.py +++ b/torchtune/models/clip/_tokenizer.py @@ -11,6 +11,7 @@ import ftfy import regex as re import torch + from torchtune.modules.tokenizers._utils import BaseTokenizer diff --git a/torchtune/utils/_download.py b/torchtune/utils/_download.py index 8cc5ef0344..738e911ad0 100644 --- a/torchtune/utils/_download.py +++ b/torchtune/utils/_download.py @@ -8,7 +8,6 @@ import requests from tqdm import tqdm - TORCHTUNE_LOCAL_CACHE_FOLDER = Path("~/.cache/torchtune").expanduser() From 833446357fb01bc34e1d60c5ffb7d1dca7bcb8f1 Mon Sep 17 00:00:00 2001 From: Calvin Pelletier Date: Mon, 11 Nov 2024 13:54:01 -0800 Subject: [PATCH 03/16] switching to hf vocab file --- .../models/clip/test_clip_tokenizer.py | 41 +++++++++++----- torchtune/models/clip/_model_builders.py | 19 ++------ torchtune/models/clip/_text_encoder.py | 2 +- torchtune/models/clip/_tokenizer.py | 19 +++++--- torchtune/utils/_download.py | 48 ------------------- 5 files changed, 46 insertions(+), 83 deletions(-) delete mode 100644 torchtune/utils/_download.py diff --git a/tests/torchtune/models/clip/test_clip_tokenizer.py b/tests/torchtune/models/clip/test_clip_tokenizer.py index 9e210193a3..d07f9ef5b9 100644 --- a/tests/torchtune/models/clip/test_clip_tokenizer.py +++ b/tests/torchtune/models/clip/test_clip_tokenizer.py @@ -5,13 +5,14 @@ # LICENSE file in the root directory of this source tree. 
import pytest +from tests.common import ASSETS from torchtune.models.clip._model_builders import clip_tokenizer class TestCLIPTokenizer: @pytest.fixture def tokenizer(self): - return clip_tokenizer() + return clip_tokenizer(ASSETS / "tiny_bpe_merges.txt") def test_tokenization(self, tokenizer): texts = [ @@ -19,26 +20,42 @@ def test_tokenization(self, tokenizer): "a helpful AI assistant", ] correct_tokens = [ - [49406, 320, 9706, 11476, 962, 518, 3293, 49407], - [49406, 320, 12695, 2215, 6799, 49407], + _pad( + [ + 2416, + 320, + 66, + 78, + 342, + 73, + 669, + 79, + 515, + 326, + 1190, + 337, + 673, + 324, + 76, + 819, + 333, + 2417, + ] + ), + _pad( + [2416, 320, 516, 75, 79, 69, 84, 331, 64, 328, 813, 667, 540, 339, 2417] + ), ] - for token_seq in correct_tokens: - _pad_token_sequence(token_seq) tokens_tensor = tokenizer(texts) assert tokens_tensor.tolist() == correct_tokens - def test_text_cleaning(self, tokenizer): - text = "(ง'⌣')ง" - correct_tokens = [49406, 263, 33382, 6, 17848, 96, 19175, 33382, 49407] - tokens = tokenizer.encode(text) - assert tokens == correct_tokens - def test_decoding(self, tokenizer): text = "this is torchtune" decoded_text = "<|startoftext|>this is torchtune <|endoftext|>" assert decoded_text == tokenizer.decode(tokenizer.encode(text)) -def _pad_token_sequence(tokens, max_seq_len=77, pad_token=49407): +def _pad(tokens, max_seq_len=77, pad_token=2417): while len(tokens) < max_seq_len: tokens.append(pad_token) + return tokens diff --git a/torchtune/models/clip/_model_builders.py b/torchtune/models/clip/_model_builders.py index 5296356ebb..7949f65829 100644 --- a/torchtune/models/clip/_model_builders.py +++ b/torchtune/models/clip/_model_builders.py @@ -3,19 +3,15 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from pathlib import Path +from os import PathLike from torchtune.models.clip._text_encoder import CLIPTextEncoder from torchtune.models.clip._tokenizer import CLIPTokenizer from torchtune.models.clip._transform import CLIPImageTransform -from torchtune.utils._download import TORCHTUNE_LOCAL_CACHE_FOLDER, download_file - -CLIP_VOCAB_URL = "https://github.com/openai/CLIP/raw/refs/heads/main/clip/bpe_simple_vocab_16e6.txt.gz" def clip_tokenizer( - vocab_path: Path = TORCHTUNE_LOCAL_CACHE_FOLDER / "clip_vocab.txt.gz", - download_if_missing: bool = True, + merges_path: PathLike, max_seq_len: int = 77, truncate: bool = True, ) -> CLIPTokenizer: @@ -23,10 +19,7 @@ def clip_tokenizer( Builder for the CLIP text tokenizer. 
Args: - vocab_path (pathlib.Path): Path to the CLIP vocab file - Default: '~/.cache/torchtune/clip_vocab.txt.gz' - download_if_missing (bool): Download the vocab file if it's not found - Default: True + merges_path (PathLike): Path to the CLIP merges file max_seq_len (bool): Context length Default: 77 truncate (bool): Truncate the token sequence if it exceeds max_seq_len (otherwise raises AssertionError) @@ -35,11 +28,7 @@ def clip_tokenizer( Returns: CLIPTokenizer: Instantiation of the CLIP text tokenizer """ - if not vocab_path.exists(): - assert download_if_missing, f"Missing CLIP tokenizer vocab: {vocab_path}" - download_file(CLIP_VOCAB_URL, vocab_path) - - return CLIPTokenizer(vocab_path, max_seq_len=max_seq_len, truncate=truncate) + return CLIPTokenizer(merges_path, max_seq_len=max_seq_len, truncate=truncate) def clip_text_encoder_large() -> CLIPTextEncoder: diff --git a/torchtune/models/clip/_text_encoder.py b/torchtune/models/clip/_text_encoder.py index b08ab89a5d..b8fb102349 100644 --- a/torchtune/models/clip/_text_encoder.py +++ b/torchtune/models/clip/_text_encoder.py @@ -104,6 +104,6 @@ def forward(self, tokens: Tensor) -> Tensor: # Select the output of the EOS token for each encoding in the batch # [b, s, d] -> [b, d] eos_token_positions = tokens.argmax(dim=-1) - x = x[torch.arange(bsz, device=x.device), eos_token_positions] + x = x.take_along_dim(eos_token_positions.view(-1, 1, 1), dim=1).squeeze(dim=1) return x diff --git a/torchtune/models/clip/_tokenizer.py b/torchtune/models/clip/_tokenizer.py index a2035bc228..ff80ee61b0 100644 --- a/torchtune/models/clip/_tokenizer.py +++ b/torchtune/models/clip/_tokenizer.py @@ -3,9 +3,8 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import gzip import html -from pathlib import Path +from os import PathLike from typing import List import ftfy @@ -23,21 +22,27 @@ class CLIPTokenizer(BaseTokenizer): https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py Args: - vocab_path (Path): the path to the CLIP vocab file + merges_path (PathLike): the path to the CLIP merges file max_seq_len (int): the context length (all CLIP models use 77) truncate (bool): whether to truncate the text when longer than max_seq_len """ - def __init__(self, vocab_path: Path, max_seq_len: int = 77, truncate: bool = True): + def __init__( + self, merges_path: PathLike, max_seq_len: int = 77, truncate: bool = True + ): self.max_seq_len = max_seq_len self.truncate = truncate self.byte_encoder = _bytes_to_unicode() self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} - with gzip.open(vocab_path) as f: - merges = f.read().decode("utf-8").split("\n") - merges = [tuple(merge.split()) for merge in merges[1:48895]] + merges = [] + with open(merges_path, encoding="utf-8") as f: + for i, line in enumerate(f): + line = line.strip() + if (i == 0 and line.startswith("#version:")) or not line: + continue + merges.append(tuple(line.split())) vocab = list(self.byte_encoder.values()) vocab.extend([v + "" for v in vocab]) diff --git a/torchtune/utils/_download.py b/torchtune/utils/_download.py deleted file mode 100644 index 738e911ad0..0000000000 --- a/torchtune/utils/_download.py +++ /dev/null @@ -1,48 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
-from pathlib import Path - -import requests -from tqdm import tqdm - -TORCHTUNE_LOCAL_CACHE_FOLDER = Path("~/.cache/torchtune").expanduser() - - -def download_file(url: str, save_path: Path, chunk_size: int = 8192): - """ - Download a file with progress bar. - - Args: - url (str): URL of the file to download - save_path (Path): Path where the file should be saved - chunk_size (int): Size of chunks to download at a time - - Raises: - requests.RequestException: if download fails - """ - if save_path.parent is not None and not save_path.parent.exists(): - save_path.parent.mkdir(parents=True) - - try: - response = requests.get(url, stream=True) - response.raise_for_status() - - total_size = int(response.headers.get("content-length", 0)) - - with open(save_path, "wb") as f, tqdm( - desc=save_path.name, - total=total_size, - unit="iB", - unit_scale=True, - unit_divisor=1024, - ) as progress_bar: - for data in response.iter_content(chunk_size=chunk_size): - size = f.write(data) - progress_bar.update(size) - - except requests.RequestException as e: - print(f"Error downloading file: {e}") - raise From d43f5b05df2d4a0879f7c5094c961037702dcdc8 Mon Sep 17 00:00:00 2001 From: Calvin Pelletier Date: Mon, 11 Nov 2024 14:02:08 -0800 Subject: [PATCH 04/16] remove dependency on ftfy --- pyproject.toml | 1 - torchtune/models/clip/_tokenizer.py | 10 +--------- 2 files changed, 1 insertion(+), 10 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 45afd3c9cb..c2920ff4d3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,6 @@ dependencies = [ "tqdm", "omegaconf", "psutil", - "ftfy", # Multimodal "Pillow>=9.4.0", diff --git a/torchtune/models/clip/_tokenizer.py b/torchtune/models/clip/_tokenizer.py index ff80ee61b0..5d1f239466 100644 --- a/torchtune/models/clip/_tokenizer.py +++ b/torchtune/models/clip/_tokenizer.py @@ -3,11 +3,9 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import html from os import PathLike from typing import List -import ftfy import regex as re import torch @@ -96,7 +94,7 @@ def encode(self, text: str) -> List[int]: Returns: List[int]: The encoded list of token ids. """ - text = _clean_text(text).lower() + text = text.lower() tokens = [self.sot_token] for token in re.findall(self.pat, text): @@ -208,9 +206,3 @@ def _get_pairs(word): pairs.add((prev_char, char)) prev_char = char return pairs - - -def _clean_text(text): - text = ftfy.fix_text(text) - text = html.unescape(html.unescape(text)).strip() - return re.sub(r"\s+", " ", text).strip() From a3c90ace19ae0b44ae89ea6ee5897da85eecd266 Mon Sep 17 00:00:00 2001 From: Calvin Pelletier Date: Mon, 11 Nov 2024 15:05:49 -0800 Subject: [PATCH 05/16] clip text encoder unit test --- .../models/clip/test_clip_text_encoder.py | 55 +++++++++++++++++++ torchtune/models/clip/_model_builders.py | 2 +- 2 files changed, 56 insertions(+), 1 deletion(-) create mode 100644 tests/torchtune/models/clip/test_clip_text_encoder.py diff --git a/tests/torchtune/models/clip/test_clip_text_encoder.py b/tests/torchtune/models/clip/test_clip_text_encoder.py new file mode 100644 index 0000000000..48bccc3bf7 --- /dev/null +++ b/tests/torchtune/models/clip/test_clip_text_encoder.py @@ -0,0 +1,55 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import pytest +import torch + +from torchtune.models.clip._text_encoder import CLIPTextEncoder +from torchtune.training.seed import set_seed + +VOCAB_SIZE = 512 +MAX_SEQ_LEN = 77 +BSZ = 2 +EMBED_DIM = 4 + + +@pytest.fixture(autouse=True) +def random(): + set_seed(0) + + +class TestClipTextEncoder: + @pytest.fixture + def model(self): + model = CLIPTextEncoder( + vocab_size=VOCAB_SIZE, + max_seq_len=MAX_SEQ_LEN, + embed_dim=EMBED_DIM, + num_heads=2, + num_layers=2, + ) + + for param in model.parameters(): + param.data.uniform_(0, 1) + + return model + + @pytest.fixture + def inputs(self): + return torch.randint(0, VOCAB_SIZE, (BSZ, MAX_SEQ_LEN)) + + def test_forward(self, model, inputs): + actual = model(inputs) + expected = torch.tensor( + [[0.1850, 0.8112, 1.4308, 0.0021], [0.1903, 0.8700, 1.3892, -0.6564]] + ) + assert actual.shape == (BSZ, EMBED_DIM) + torch.testing.assert_close(actual, expected, atol=1e-4, rtol=1e-4) + + def test_backward(self, model, inputs): + y = model(inputs) + loss = y.mean() + loss.backward() diff --git a/torchtune/models/clip/_model_builders.py b/torchtune/models/clip/_model_builders.py index 7949f65829..18aa331d4b 100644 --- a/torchtune/models/clip/_model_builders.py +++ b/torchtune/models/clip/_model_builders.py @@ -31,7 +31,7 @@ def clip_tokenizer( return CLIPTokenizer(merges_path, max_seq_len=max_seq_len, truncate=truncate) -def clip_text_encoder_large() -> CLIPTextEncoder: +def clip_text_vit_large_patch14() -> CLIPTextEncoder: """ Builder for the CLIP text encoder for CLIP-ViT-L/14. From 1dbe939de688f6e844dac30d0e9158d8d4b6539c Mon Sep 17 00:00:00 2001 From: Calvin Pelletier Date: Mon, 11 Nov 2024 16:37:54 -0800 Subject: [PATCH 06/16] address comments --- .../models/clip/test_clip_text_encoder.py | 6 +- torchtune/models/clip/_component_builders.py | 64 +++++++++++++-- torchtune/models/clip/_convert_weights.py | 32 ++++---- torchtune/models/clip/_model_builders.py | 3 +- torchtune/models/clip/_text_encoder.py | 81 ++++++++++--------- 5 files changed, 122 insertions(+), 64 deletions(-) diff --git a/tests/torchtune/models/clip/test_clip_text_encoder.py b/tests/torchtune/models/clip/test_clip_text_encoder.py index 48bccc3bf7..2b50ae8718 100644 --- a/tests/torchtune/models/clip/test_clip_text_encoder.py +++ b/tests/torchtune/models/clip/test_clip_text_encoder.py @@ -7,7 +7,7 @@ import pytest import torch -from torchtune.models.clip._text_encoder import CLIPTextEncoder +from torchtune.models.clip._component_builders import clip_text_encoder from torchtune.training.seed import set_seed VOCAB_SIZE = 512 @@ -24,7 +24,7 @@ def random(): class TestClipTextEncoder: @pytest.fixture def model(self): - model = CLIPTextEncoder( + model = clip_text_encoder( vocab_size=VOCAB_SIZE, max_seq_len=MAX_SEQ_LEN, embed_dim=EMBED_DIM, @@ -44,7 +44,7 @@ def inputs(self): def test_forward(self, model, inputs): actual = model(inputs) expected = torch.tensor( - [[0.1850, 0.8112, 1.4308, 0.0021], [0.1903, 0.8700, 1.3892, -0.6564]] + [[0.2195, 1.3941, 0.6295, -0.1026], [0.2418, 1.4928, 0.6177, -0.0863]] ) assert actual.shape == (BSZ, EMBED_DIM) torch.testing.assert_close(actual, expected, atol=1e-4, rtol=1e-4) diff --git a/torchtune/models/clip/_component_builders.py b/torchtune/models/clip/_component_builders.py index 772d1e32df..22af8e2665 100644 --- a/torchtune/models/clip/_component_builders.py +++ b/torchtune/models/clip/_component_builders.py @@ -9,12 +9,13 @@ import torch from torch import nn + from torchtune.models.clip._position_embeddings import ( 
TiledTokenPositionalEmbedding, TilePositionalEmbedding, TokenPositionalEmbedding, ) - +from torchtune.models.clip._text_encoder import CLIPTextEncoder from torchtune.modules import ( FeedForward, Fp32LayerNorm, @@ -22,11 +23,9 @@ MultiHeadAttention, TransformerSelfAttentionLayer, ) - +from torchtune.modules.activations import QuickGELU from torchtune.modules.common_utils import reparametrize_as_dtype_state_dict_post_hook - -from torchtune.modules.peft import DoRALinear, LORA_ATTN_MODULES, LoRALinear - +from torchtune.modules.peft import LORA_ATTN_MODULES, DoRALinear, LoRALinear from torchtune.modules.vision_transformer import CLSProjection, VisionTransformer @@ -157,6 +156,61 @@ def clip_vision_encoder( ) +def clip_text_encoder( + vocab_size: int = 49408, + max_seq_len: int = 77, + embed_dim: int = 768, + num_heads: int = 12, + num_layers: int = 12, + norm_eps: float = 1e-5, +): + """ + Text encoder for CLIP. + + Args: + vocab_size (int): size of the vocabulary, default 49408 + max_seq_len (int): context size, default 77 + embed_dim (int): embedding/model dimension size, default 768 + num_heads (int): number of attention heads, default 12 + num_layers (int): number of transformer layers, default 12 + norm_eps (float): small value added to denominator for numerical stability, default 1e-5 + + Returns: + CLIPTextEncoder + """ + attn = MultiHeadAttention( + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_heads, + head_dim=embed_dim // num_heads, + q_proj=nn.Linear(embed_dim, embed_dim), + k_proj=nn.Linear(embed_dim, embed_dim), + v_proj=nn.Linear(embed_dim, embed_dim), + output_proj=nn.Linear(embed_dim, embed_dim), + ) + mlp = clip_mlp( + in_dim=embed_dim, + out_dim=embed_dim, + hidden_dim=embed_dim * 4, + activation=QuickGELU(), + ) + encoder_layer = TransformerSelfAttentionLayer( + attn=attn, + mlp=mlp, + sa_norm=nn.LayerNorm(embed_dim, eps=norm_eps), + mlp_norm=nn.LayerNorm(embed_dim, eps=norm_eps), + ) + final_norm = nn.LayerNorm(embed_dim, eps=norm_eps) + return CLIPTextEncoder( + layers=encoder_layer, + final_norm=final_norm, + vocab_size=vocab_size, + max_seq_len=max_seq_len, + embed_dim=embed_dim, + num_layers=num_layers, + ) + + def clip_mlp( in_dim: int, out_dim: int, diff --git a/torchtune/models/clip/_convert_weights.py b/torchtune/models/clip/_convert_weights.py index d1d17a11b5..004100e928 100644 --- a/torchtune/models/clip/_convert_weights.py +++ b/torchtune/models/clip/_convert_weights.py @@ -10,22 +10,22 @@ _FROM_HF = { "text_model.embeddings.token_embedding.weight": "token_embedding.weight", "text_model.embeddings.position_embedding.weight": "position_embedding", - "text_model.encoder.layers.{}.layer_norm1.weight": "encoder.{}.sa_norm.weight", - "text_model.encoder.layers.{}.layer_norm1.bias": "encoder.{}.sa_norm.bias", - "text_model.encoder.layers.{}.layer_norm2.weight": "encoder.{}.mlp_norm.weight", - "text_model.encoder.layers.{}.layer_norm2.bias": "encoder.{}.mlp_norm.bias", - "text_model.encoder.layers.{}.mlp.fc1.weight": "encoder.{}.mlp.w1.weight", - "text_model.encoder.layers.{}.mlp.fc1.bias": "encoder.{}.mlp.w1.bias", - "text_model.encoder.layers.{}.mlp.fc2.weight": "encoder.{}.mlp.w2.weight", - "text_model.encoder.layers.{}.mlp.fc2.bias": "encoder.{}.mlp.w2.bias", - "text_model.encoder.layers.{}.self_attn.q_proj.weight": "encoder.{}.attn.q_proj.weight", - "text_model.encoder.layers.{}.self_attn.q_proj.bias": "encoder.{}.attn.q_proj.bias", - "text_model.encoder.layers.{}.self_attn.k_proj.weight": "encoder.{}.attn.k_proj.weight", - 
"text_model.encoder.layers.{}.self_attn.k_proj.bias": "encoder.{}.attn.k_proj.bias", - "text_model.encoder.layers.{}.self_attn.v_proj.weight": "encoder.{}.attn.v_proj.weight", - "text_model.encoder.layers.{}.self_attn.v_proj.bias": "encoder.{}.attn.v_proj.bias", - "text_model.encoder.layers.{}.self_attn.out_proj.bias": "encoder.{}.attn.output_proj.bias", - "text_model.encoder.layers.{}.self_attn.out_proj.weight": "encoder.{}.attn.output_proj.weight", + "text_model.encoder.layers.{}.layer_norm1.weight": "layers.{}.sa_norm.weight", + "text_model.encoder.layers.{}.layer_norm1.bias": "layers.{}.sa_norm.bias", + "text_model.encoder.layers.{}.layer_norm2.weight": "layers.{}.mlp_norm.weight", + "text_model.encoder.layers.{}.layer_norm2.bias": "layers.{}.mlp_norm.bias", + "text_model.encoder.layers.{}.mlp.fc1.weight": "layers.{}.mlp.w1.weight", + "text_model.encoder.layers.{}.mlp.fc1.bias": "layers.{}.mlp.w1.bias", + "text_model.encoder.layers.{}.mlp.fc2.weight": "layers.{}.mlp.w2.weight", + "text_model.encoder.layers.{}.mlp.fc2.bias": "layers.{}.mlp.w2.bias", + "text_model.encoder.layers.{}.self_attn.q_proj.weight": "layers.{}.attn.q_proj.weight", + "text_model.encoder.layers.{}.self_attn.q_proj.bias": "layers.{}.attn.q_proj.bias", + "text_model.encoder.layers.{}.self_attn.k_proj.weight": "layers.{}.attn.k_proj.weight", + "text_model.encoder.layers.{}.self_attn.k_proj.bias": "layers.{}.attn.k_proj.bias", + "text_model.encoder.layers.{}.self_attn.v_proj.weight": "layers.{}.attn.v_proj.weight", + "text_model.encoder.layers.{}.self_attn.v_proj.bias": "layers.{}.attn.v_proj.bias", + "text_model.encoder.layers.{}.self_attn.out_proj.bias": "layers.{}.attn.output_proj.bias", + "text_model.encoder.layers.{}.self_attn.out_proj.weight": "layers.{}.attn.output_proj.weight", "text_model.final_layer_norm.weight": "final_norm.weight", "text_model.final_layer_norm.bias": "final_norm.bias", } diff --git a/torchtune/models/clip/_model_builders.py b/torchtune/models/clip/_model_builders.py index 18aa331d4b..36a9b10071 100644 --- a/torchtune/models/clip/_model_builders.py +++ b/torchtune/models/clip/_model_builders.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. from os import PathLike +from torchtune.models.clip._component_builders import clip_text_encoder from torchtune.models.clip._text_encoder import CLIPTextEncoder from torchtune.models.clip._tokenizer import CLIPTokenizer from torchtune.models.clip._transform import CLIPImageTransform @@ -38,7 +39,7 @@ def clip_text_vit_large_patch14() -> CLIPTextEncoder: Returns: CLIPTextEncoder: Instantiation of the CLIP text encoder """ - return CLIPTextEncoder( + return clip_text_encoder( vocab_size=49408, max_seq_len=77, embed_dim=768, diff --git a/torchtune/models/clip/_text_encoder.py b/torchtune/models/clip/_text_encoder.py index b8fb102349..686d7cefa3 100644 --- a/torchtune/models/clip/_text_encoder.py +++ b/torchtune/models/clip/_text_encoder.py @@ -4,15 +4,13 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import copy +from typing import List, Optional, Union + import torch from torch import nn, Tensor -from torchtune.modules import ( - FeedForward, - MultiHeadAttention, - TransformerSelfAttentionLayer, -) -from torchtune.modules.activations import QuickGELU +from torchtune.modules.attention_utils import _MaskType class CLIPTextEncoder(nn.Module): @@ -20,60 +18,60 @@ class CLIPTextEncoder(nn.Module): Text encoder for CLIP. 
Args: + layers (Union[nn.Module, List[nn.Module], nn.ModuleList]): A single encoder layer, an + nn.ModuleList of layers, or a list of layers. + final_norm (nn.Module): Callable that applies normalization to the output of the encoder vocab_size (int): size of the vocabulary, default 49408 max_seq_len (int): context size, default 77 embed_dim (int): embedding/model dimension size, default 768 - num_heads (int): number of attention heads, default 12 num_layers (int): number of transformer layers, default 12 - norm_eps (float): small value added to denominator for numerical stability, default 1e-5 + + Raises: + AssertionError: num_layers is set and layer is a list + AssertionError: num_layers is not set and layer is an nn.Module """ def __init__( self, + *, + layers: Union[nn.Module, List[nn.Module], nn.ModuleList], + final_norm: nn.Module, vocab_size: int = 49408, max_seq_len: int = 77, embed_dim: int = 768, - num_heads: int = 12, num_layers: int = 12, - norm_eps: float = 1e-5, ): super().__init__() + if isinstance(layers, nn.ModuleList): + pass + elif isinstance(layers, list): + layers = nn.ModuleList(layers) + else: + if not isinstance(layers, nn.Module): + raise AssertionError("num_layers is defined, layers must be a module") + if num_layers is None: + raise AssertionError("num_layers is not defined, layers must be a list") + layers = nn.ModuleList([copy.deepcopy(layers) for i in range(num_layers)]) + + self.layers = layers + self.final_norm = final_norm self.max_seq_len = max_seq_len self.token_embedding = nn.Embedding(vocab_size, embed_dim) self.position_embedding = nn.Parameter(torch.empty(max_seq_len, embed_dim)) - self.encoder = nn.Sequential( - *[ - TransformerSelfAttentionLayer( - attn=MultiHeadAttention( - embed_dim=embed_dim, - num_heads=num_heads, - num_kv_heads=num_heads, - head_dim=embed_dim // num_heads, - q_proj=nn.Linear(embed_dim, embed_dim), - k_proj=nn.Linear(embed_dim, embed_dim), - v_proj=nn.Linear(embed_dim, embed_dim), - output_proj=nn.Linear(embed_dim, embed_dim), - ), - mlp=FeedForward( - gate_proj=nn.Linear(embed_dim, embed_dim * 4), - down_proj=nn.Linear(embed_dim * 4, embed_dim), - activation=QuickGELU(), - ), - sa_norm=nn.LayerNorm(embed_dim, eps=norm_eps), - mlp_norm=nn.LayerNorm(embed_dim, eps=norm_eps), - ) - for _ in range(num_layers) - ] - ) - - self.final_norm = nn.LayerNorm(embed_dim, eps=norm_eps) - - def forward(self, tokens: Tensor) -> Tensor: + def forward( + self, + tokens: Tensor, + *, + mask: Optional[_MaskType] = None, + ) -> Tensor: """ Args: tokens (Tensor): input tensor with shape ``[b x s]`` + mask (Optional[_MaskType]): Used to mask the scores after the query-key multiplication + and before the softmax. + Default is None. 
Returns: Tensor: output tensor with shape [b x d] @@ -98,11 +96,16 @@ def forward(self, tokens: Tensor) -> Tensor: x = self.token_embedding(tokens) + self.position_embedding # Encoder [b, s, d] -> [b, s, d] - x = self.encoder(x) + for layer in self.layers: + x = layer( + x, + mask=mask, + ) x = self.final_norm(x) # Select the output of the EOS token for each encoding in the batch # [b, s, d] -> [b, d] + # TODO: handle the case when the EOS token is not the highest token ID eos_token_positions = tokens.argmax(dim=-1) x = x.take_along_dim(eos_token_positions.view(-1, 1, 1), dim=1).squeeze(dim=1) From e6b3d1951411b28fb69166123007500a0ced5a81 Mon Sep 17 00:00:00 2001 From: Calvin Pelletier Date: Mon, 11 Nov 2024 16:42:40 -0800 Subject: [PATCH 07/16] move __call__ --- torchtune/models/clip/_tokenizer.py | 38 ++++++++++++++--------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/torchtune/models/clip/_tokenizer.py b/torchtune/models/clip/_tokenizer.py index 5d1f239466..077b779577 100644 --- a/torchtune/models/clip/_tokenizer.py +++ b/torchtune/models/clip/_tokenizer.py @@ -65,25 +65,6 @@ def __init__( "<|endoftext|>": "<|endoftext|>", } - def __call__(self, texts: List[str]) -> torch.Tensor: - """ - Returns a Tensor with the tokenized representation of given input strings - - Args: - texts (List[str]): list of input strings to tokenize - - Returns: - torch.Tensor: int tensor with shape [len(texts), max_seq_len] - """ - assert isinstance(texts, list) - result = torch.full( - (len(texts), self.max_seq_len), self.pad_token, dtype=torch.int - ) - for i, text in enumerate(texts): - tokens = self.encode(text) - result[i, : len(tokens)] = torch.tensor(tokens) - return result - def encode(self, text: str) -> List[int]: """ Given a string, return the encoded list of token ids. @@ -133,6 +114,25 @@ def decode(self, tokens: List[int]) -> str: .replace("", " ") ) + def __call__(self, texts: List[str]) -> torch.Tensor: + """ + Returns a Tensor with the tokenized representation of given input strings + + Args: + texts (List[str]): list of input strings to tokenize + + Returns: + torch.Tensor: int tensor with shape [len(texts), max_seq_len] + """ + assert isinstance(texts, list) + result = torch.full( + (len(texts), self.max_seq_len), self.pad_token, dtype=torch.int + ) + for i, text in enumerate(texts): + tokens = self.encode(text) + result[i, : len(tokens)] = torch.tensor(tokens) + return result + def _bpe(self, token): if token in self.cache: return self.cache[token] From c4e700b6349be98620b762594163e9e078c7469b Mon Sep 17 00:00:00 2001 From: Calvin Pelletier Date: Thu, 14 Nov 2024 11:17:10 -0800 Subject: [PATCH 08/16] addressing comments --- torchtune/models/clip/__init__.py | 7 ++++-- torchtune/models/clip/_component_builders.py | 12 ++++++++-- torchtune/models/clip/_model_builders.py | 6 ++--- torchtune/models/clip/_text_encoder.py | 24 ++++---------------- torchtune/models/clip/_tokenizer.py | 8 +++---- torchtune/modules/activations.py | 17 -------------- 6 files changed, 25 insertions(+), 49 deletions(-) delete mode 100644 torchtune/modules/activations.py diff --git a/torchtune/models/clip/__init__.py b/torchtune/models/clip/__init__.py index 5361c8968c..a6ea4ecb6a 100644 --- a/torchtune/models/clip/__init__.py +++ b/torchtune/models/clip/__init__.py @@ -4,8 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-from ._component_builders import clip_mlp, clip_vision_encoder - +from ._component_builders import clip_mlp, clip_text_encoder, clip_vision_encoder +from ._model_builders import clip_text_vit_large_patch14, clip_tokenizer from ._position_embeddings import ( TiledTokenPositionalEmbedding, TilePositionalEmbedding, @@ -15,7 +15,10 @@ __all__ = [ "clip_mlp", + "clip_text_encoder", "clip_vision_encoder", + "clip_text_vit_large_patch14", + "clip_tokenizer", "CLIPImageTransform", "TokenPositionalEmbedding", "TiledTokenPositionalEmbedding", diff --git a/torchtune/models/clip/_component_builders.py b/torchtune/models/clip/_component_builders.py index 22af8e2665..fbe868bc39 100644 --- a/torchtune/models/clip/_component_builders.py +++ b/torchtune/models/clip/_component_builders.py @@ -23,7 +23,6 @@ MultiHeadAttention, TransformerSelfAttentionLayer, ) -from torchtune.modules.activations import QuickGELU from torchtune.modules.common_utils import reparametrize_as_dtype_state_dict_post_hook from torchtune.modules.peft import LORA_ATTN_MODULES, DoRALinear, LoRALinear from torchtune.modules.vision_transformer import CLSProjection, VisionTransformer @@ -202,7 +201,7 @@ def clip_text_encoder( ) final_norm = nn.LayerNorm(embed_dim, eps=norm_eps) return CLIPTextEncoder( - layers=encoder_layer, + layer=encoder_layer, final_norm=final_norm, vocab_size=vocab_size, max_seq_len=max_seq_len, @@ -577,3 +576,12 @@ def lora_clip_mlp( return FeedForward( gate_proj=gate_proj, down_proj=down_proj, up_proj=None, activation=activation ) + + +class QuickGELU(nn.Module): + """ + Fast approximation of GELU. + """ + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x * torch.sigmoid(1.702 * x) diff --git a/torchtune/models/clip/_model_builders.py b/torchtune/models/clip/_model_builders.py index 36a9b10071..6d72fb3921 100644 --- a/torchtune/models/clip/_model_builders.py +++ b/torchtune/models/clip/_model_builders.py @@ -12,7 +12,7 @@ def clip_tokenizer( - merges_path: PathLike, + path: PathLike, max_seq_len: int = 77, truncate: bool = True, ) -> CLIPTokenizer: @@ -20,7 +20,7 @@ def clip_tokenizer( Builder for the CLIP text tokenizer. Args: - merges_path (PathLike): Path to the CLIP merges file + path (PathLike): Path to the CLIP merges file max_seq_len (bool): Context length Default: 77 truncate (bool): Truncate the token sequence if it exceeds max_seq_len (otherwise raises AssertionError) @@ -29,7 +29,7 @@ def clip_tokenizer( Returns: CLIPTokenizer: Instantiation of the CLIP text tokenizer """ - return CLIPTokenizer(merges_path, max_seq_len=max_seq_len, truncate=truncate) + return CLIPTokenizer(path, max_seq_len=max_seq_len, truncate=truncate) def clip_text_vit_large_patch14() -> CLIPTextEncoder: diff --git a/torchtune/models/clip/_text_encoder.py b/torchtune/models/clip/_text_encoder.py index 686d7cefa3..12766ef367 100644 --- a/torchtune/models/clip/_text_encoder.py +++ b/torchtune/models/clip/_text_encoder.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. import copy -from typing import List, Optional, Union +from typing import Optional import torch from torch import nn, Tensor @@ -18,23 +18,18 @@ class CLIPTextEncoder(nn.Module): Text encoder for CLIP. Args: - layers (Union[nn.Module, List[nn.Module], nn.ModuleList]): A single encoder layer, an - nn.ModuleList of layers, or a list of layers. + layer (nn.Module): A single encoder layer. 
final_norm (nn.Module): Callable that applies normalization to the output of the encoder vocab_size (int): size of the vocabulary, default 49408 max_seq_len (int): context size, default 77 embed_dim (int): embedding/model dimension size, default 768 num_layers (int): number of transformer layers, default 12 - - Raises: - AssertionError: num_layers is set and layer is a list - AssertionError: num_layers is not set and layer is an nn.Module """ def __init__( self, *, - layers: Union[nn.Module, List[nn.Module], nn.ModuleList], + layer: nn.Module, final_norm: nn.Module, vocab_size: int = 49408, max_seq_len: int = 77, @@ -42,18 +37,7 @@ def __init__( num_layers: int = 12, ): super().__init__() - if isinstance(layers, nn.ModuleList): - pass - elif isinstance(layers, list): - layers = nn.ModuleList(layers) - else: - if not isinstance(layers, nn.Module): - raise AssertionError("num_layers is defined, layers must be a module") - if num_layers is None: - raise AssertionError("num_layers is not defined, layers must be a list") - layers = nn.ModuleList([copy.deepcopy(layers) for i in range(num_layers)]) - - self.layers = layers + self.layers = nn.ModuleList([copy.deepcopy(layer) for i in range(num_layers)]) self.final_norm = final_norm self.max_seq_len = max_seq_len diff --git a/torchtune/models/clip/_tokenizer.py b/torchtune/models/clip/_tokenizer.py index 077b779577..98e10ef241 100644 --- a/torchtune/models/clip/_tokenizer.py +++ b/torchtune/models/clip/_tokenizer.py @@ -20,14 +20,12 @@ class CLIPTokenizer(BaseTokenizer): https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py Args: - merges_path (PathLike): the path to the CLIP merges file + path (PathLike): the path to the CLIP merges file max_seq_len (int): the context length (all CLIP models use 77) truncate (bool): whether to truncate the text when longer than max_seq_len """ - def __init__( - self, merges_path: PathLike, max_seq_len: int = 77, truncate: bool = True - ): + def __init__(self, path: PathLike, max_seq_len: int = 77, truncate: bool = True): self.max_seq_len = max_seq_len self.truncate = truncate @@ -35,7 +33,7 @@ def __init__( self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} merges = [] - with open(merges_path, encoding="utf-8") as f: + with open(path, encoding="utf-8") as f: for i, line in enumerate(f): line = line.strip() if (i == 0 and line.startswith("#version:")) or not line: diff --git a/torchtune/modules/activations.py b/torchtune/modules/activations.py deleted file mode 100644 index 2b43908df0..0000000000 --- a/torchtune/modules/activations.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import torch -from torch import nn - - -class QuickGELU(nn.Module): - """ - Fast approximation of GELU. 
- """ - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return x * torch.sigmoid(1.702 * x) From 5aa7c9fec561d66a1de1bb74f130c462ddbdb5b8 Mon Sep 17 00:00:00 2001 From: Calvin Pelletier Date: Thu, 14 Nov 2024 12:36:55 -0800 Subject: [PATCH 09/16] moving quickgelu --- torchtune/models/clip/_component_builders.py | 11 +---------- torchtune/models/clip/_text_encoder.py | 9 +++++++++ 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/torchtune/models/clip/_component_builders.py b/torchtune/models/clip/_component_builders.py index caad754419..4a1aa2bc89 100644 --- a/torchtune/models/clip/_component_builders.py +++ b/torchtune/models/clip/_component_builders.py @@ -15,7 +15,7 @@ TilePositionalEmbedding, TokenPositionalEmbedding, ) -from torchtune.models.clip._text_encoder import CLIPTextEncoder +from torchtune.models.clip._text_encoder import CLIPTextEncoder, QuickGELU from torchtune.modules import ( FeedForward, Fp32LayerNorm, @@ -597,12 +597,3 @@ def lora_clip_mlp( return FeedForward( gate_proj=gate_proj, down_proj=down_proj, up_proj=None, activation=activation ) - - -class QuickGELU(nn.Module): - """ - Fast approximation of GELU. - """ - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return x * torch.sigmoid(1.702 * x) diff --git a/torchtune/models/clip/_text_encoder.py b/torchtune/models/clip/_text_encoder.py index 12766ef367..c9c2ccfbba 100644 --- a/torchtune/models/clip/_text_encoder.py +++ b/torchtune/models/clip/_text_encoder.py @@ -94,3 +94,12 @@ def forward( x = x.take_along_dim(eos_token_positions.view(-1, 1, 1), dim=1).squeeze(dim=1) return x + + +class QuickGELU(nn.Module): + """ + Fast approximation of GELU. + """ + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x * torch.sigmoid(1.702 * x) From 69a5a167cf19a5ac3f3666fb07193dcb5ee7b2e1 Mon Sep 17 00:00:00 2001 From: Calvin Pelletier Date: Mon, 18 Nov 2024 17:48:22 -0800 Subject: [PATCH 10/16] addressing comments and making tokenizer more efficient --- torchtune/models/clip/_model_builders.py | 6 +- torchtune/models/clip/_tokenizer.py | 71 ++++++++++++++++++------ 2 files changed, 55 insertions(+), 22 deletions(-) diff --git a/torchtune/models/clip/_model_builders.py b/torchtune/models/clip/_model_builders.py index 6d72fb3921..d218844266 100644 --- a/torchtune/models/clip/_model_builders.py +++ b/torchtune/models/clip/_model_builders.py @@ -3,8 +3,6 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from os import PathLike - from torchtune.models.clip._component_builders import clip_text_encoder from torchtune.models.clip._text_encoder import CLIPTextEncoder from torchtune.models.clip._tokenizer import CLIPTokenizer @@ -12,7 +10,7 @@ def clip_tokenizer( - path: PathLike, + path: str, max_seq_len: int = 77, truncate: bool = True, ) -> CLIPTokenizer: @@ -20,7 +18,7 @@ def clip_tokenizer( Builder for the CLIP text tokenizer. 
Args: - path (PathLike): Path to the CLIP merges file + path (str): Path to the CLIP merges file max_seq_len (bool): Context length Default: 77 truncate (bool): Truncate the token sequence if it exceeds max_seq_len (otherwise raises AssertionError) diff --git a/torchtune/models/clip/_tokenizer.py b/torchtune/models/clip/_tokenizer.py index 98e10ef241..ef3d42101d 100644 --- a/torchtune/models/clip/_tokenizer.py +++ b/torchtune/models/clip/_tokenizer.py @@ -3,7 +3,6 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from os import PathLike from typing import List import regex as re @@ -11,6 +10,8 @@ from torchtune.modules.tokenizers._utils import BaseTokenizer +WORD_BOUNDARY = "" + class CLIPTokenizer(BaseTokenizer): """ @@ -20,28 +21,22 @@ class CLIPTokenizer(BaseTokenizer): https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py Args: - path (PathLike): the path to the CLIP merges file + path (str): the path to the CLIP merges file max_seq_len (int): the context length (all CLIP models use 77) truncate (bool): whether to truncate the text when longer than max_seq_len """ - def __init__(self, path: PathLike, max_seq_len: int = 77, truncate: bool = True): + def __init__(self, path: str, max_seq_len: int = 77, truncate: bool = True): self.max_seq_len = max_seq_len self.truncate = truncate self.byte_encoder = _bytes_to_unicode() self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} - merges = [] - with open(path, encoding="utf-8") as f: - for i, line in enumerate(f): - line = line.strip() - if (i == 0 and line.startswith("#version:")) or not line: - continue - merges.append(tuple(line.split())) + merges = _load_merges(path) vocab = list(self.byte_encoder.values()) - vocab.extend([v + "" for v in vocab]) + vocab.extend([v + WORD_BOUNDARY for v in vocab]) vocab.extend(["".join(merge) for merge in merges]) vocab.extend(["<|startoftext|>", "<|endoftext|>"]) @@ -73,7 +68,7 @@ def encode(self, text: str) -> List[int]: Returns: List[int]: The encoded list of token ids. """ - text = text.lower() + text = _clean_text(text).lower() tokens = [self.sot_token] for token in re.findall(self.pat, text): @@ -109,7 +104,7 @@ def decode(self, tokens: List[int]) -> str: return ( bytearray([self.byte_decoder[c] for c in text]) .decode("utf-8", errors="replace") - .replace("", " ") + .replace(WORD_BOUNDARY, " ") ) def __call__(self, texts: List[str]) -> torch.Tensor: @@ -131,24 +126,44 @@ def __call__(self, texts: List[str]) -> torch.Tensor: result[i, : len(tokens)] = torch.tensor(tokens) return result - def _bpe(self, token): + def _bpe(self, token: str) -> str: + """ + Performs byte-pair encoding on a single token. + + Args: + token (str): The input token to encode + + Returns: + str: The encoded token with merge rules applied + """ if token in self.cache: return self.cache[token] - word = tuple(token[:-1]) + (token[-1] + "",) - pairs = _get_pairs(word) + if len(token) < 2: + return token + WORD_BOUNDARY - if not pairs: - return token + "" + # create the initial "word" (seq of "symbols" i.e. 
characters and merged subwords) + # by converting the token to tuple of characters and add to the last character + word = tuple(token[:-1]) + (token[-1] + WORD_BOUNDARY,) + # get all pairs of adjacent characters + pairs = _get_pairs(word) + + # merge symbol pairs until there are no possible merges left while True: + # find the pair with the lowest rank (highest priority to merge) bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) + # end if there are no pairs to merge if bigram not in self.bpe_ranks: break + + # create the next "word" by merging any adjacent symbols that match the bigram first, second = bigram new_word = [] i = 0 while i < len(word): + # find next potentially mergeable position and copy over any skipped characters + # if no more merge positions found, copy remaining characters and finish try: j = word.index(first, i) new_word.extend(word[i:j]) @@ -157,6 +172,7 @@ def _bpe(self, token): new_word.extend(word[i:]) break + # check if we can perform a merge if word[i] == first and i < len(word) - 1 and word[i + 1] == second: new_word.append(first + second) i += 2 @@ -164,8 +180,12 @@ def _bpe(self, token): new_word.append(word[i]) i += 1 word = tuple(new_word) + + # end if the new "word" is fully merged if len(word) == 1: break + + # get all pairs of adjacent symbols in the new word pairs = _get_pairs(word) word = " ".join(word) @@ -204,3 +224,18 @@ def _get_pairs(word): pairs.add((prev_char, char)) prev_char = char return pairs + + +def _clean_text(text): + return text.replace("’", "'") + + +def _load_merges(path): + merges = [] + with open(path, encoding="utf-8") as f: + for i, line in enumerate(f): + line = line.strip() + if (i == 0 and line.startswith("#version:")) or not line: + continue + merges.append(tuple(line.split())) + return merges From 5fe86ae50ae0b20867d49a6f168e12a2773afbaf Mon Sep 17 00:00:00 2001 From: Calvin Pelletier Date: Mon, 18 Nov 2024 19:12:17 -0800 Subject: [PATCH 11/16] type hints --- torchtune/models/clip/_tokenizer.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/torchtune/models/clip/_tokenizer.py b/torchtune/models/clip/_tokenizer.py index ef3d42101d..73614dc7f4 100644 --- a/torchtune/models/clip/_tokenizer.py +++ b/torchtune/models/clip/_tokenizer.py @@ -3,7 +3,7 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import List +from typing import Dict, List, Set, Tuple import regex as re import torch @@ -129,12 +129,6 @@ def __call__(self, texts: List[str]) -> torch.Tensor: def _bpe(self, token: str) -> str: """ Performs byte-pair encoding on a single token. - - Args: - token (str): The input token to encode - - Returns: - str: The encoded token with merge rules applied """ if token in self.cache: return self.cache[token] @@ -193,7 +187,7 @@ def _bpe(self, token: str) -> str: return word -def _bytes_to_unicode(): +def _bytes_to_unicode() -> Dict[int, str]: """ Returns list of utf-8 byte and a corresponding list of unicode strings. """ @@ -213,7 +207,7 @@ def _bytes_to_unicode(): return dict(zip(bs, cs)) -def _get_pairs(word): +def _get_pairs(word: Tuple[str, ...]) -> Set[Tuple[str, str]]: """ Return set of symbol pairs in a word. Word is represented as tuple of symbols (symbols being variable-length strings). 
@@ -226,11 +220,14 @@ def _get_pairs(word): return pairs -def _clean_text(text): +def _clean_text(text: str) -> str: + """ + Minimal version of CLIP's text cleaning via the `ftfy` package. + """ return text.replace("’", "'") -def _load_merges(path): +def _load_merges(path: str) -> List[Tuple[str, str]]: merges = [] with open(path, encoding="utf-8") as f: for i, line in enumerate(f): From 4c6ef70e2ee561f6f129b200c6384c215b0d6673 Mon Sep 17 00:00:00 2001 From: Calvin Pelletier Date: Tue, 19 Nov 2024 19:51:25 -0800 Subject: [PATCH 12/16] tokenizer __call__ --- .../models/clip/test_clip_tokenizer.py | 63 +++++++++---------- torchtune/models/clip/_tokenizer.py | 25 ++++---- 2 files changed, 41 insertions(+), 47 deletions(-) diff --git a/tests/torchtune/models/clip/test_clip_tokenizer.py b/tests/torchtune/models/clip/test_clip_tokenizer.py index d07f9ef5b9..977cd4f028 100644 --- a/tests/torchtune/models/clip/test_clip_tokenizer.py +++ b/tests/torchtune/models/clip/test_clip_tokenizer.py @@ -14,48 +14,45 @@ class TestCLIPTokenizer: def tokenizer(self): return clip_tokenizer(ASSETS / "tiny_bpe_merges.txt") - def test_tokenization(self, tokenizer): + def test_encoding(self, tokenizer): texts = [ "a cow jumping over the moon", "a helpful AI assistant", ] correct_tokens = [ - _pad( - [ - 2416, - 320, - 66, - 78, - 342, - 73, - 669, - 79, - 515, - 326, - 1190, - 337, - 673, - 324, - 76, - 819, - 333, - 2417, - ] - ), - _pad( - [2416, 320, 516, 75, 79, 69, 84, 331, 64, 328, 813, 667, 540, 339, 2417] - ), + [ + 2416, + 320, + 66, + 78, + 342, + 73, + 669, + 79, + 515, + 326, + 1190, + 337, + 673, + 324, + 76, + 819, + 333, + 2417, + ], + [2416, 320, 516, 75, 79, 69, 84, 331, 64, 328, 813, 667, 540, 339, 2417], ] - tokens_tensor = tokenizer(texts) - assert tokens_tensor.tolist() == correct_tokens + for text, correct in zip(texts, correct_tokens): + tokens = tokenizer.encode(text) + assert tokens == correct def test_decoding(self, tokenizer): text = "this is torchtune" decoded_text = "<|startoftext|>this is torchtune <|endoftext|>" assert decoded_text == tokenizer.decode(tokenizer.encode(text)) - -def _pad(tokens, max_seq_len=77, pad_token=2417): - while len(tokens) < max_seq_len: - tokens.append(pad_token) - return tokens + def test_call(self, tokenizer): + sample = {"text": "hello world"} + sample = tokenizer(sample) + assert "text" not in sample + assert "tokens" in sample diff --git a/torchtune/models/clip/_tokenizer.py b/torchtune/models/clip/_tokenizer.py index 73614dc7f4..69fed32c72 100644 --- a/torchtune/models/clip/_tokenizer.py +++ b/torchtune/models/clip/_tokenizer.py @@ -3,10 +3,9 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Dict, List, Set, Tuple +from typing import Any, Dict, List, Mapping, Set, Tuple import regex as re -import torch from torchtune.modules.tokenizers._utils import BaseTokenizer @@ -107,24 +106,22 @@ def decode(self, tokens: List[int]) -> str: .replace(WORD_BOUNDARY, " ") ) - def __call__(self, texts: List[str]) -> torch.Tensor: + def __call__( + self, sample: Mapping[str, Any], inference: bool = False + ) -> Mapping[str, Any]: """ - Returns a Tensor with the tokenized representation of given input strings + Tokenize the "text" field in the sample. 
Args: - texts (List[str]): list of input strings to tokenize + sample (Mapping[str, Any]): A sample with a "text" field containing a string to tokenize + inference (bool): Unused by this tokenizer Returns: - torch.Tensor: int tensor with shape [len(texts), max_seq_len] + Mapping[str, Any]: The sample with added "tokens" field and the "messages" field removed. """ - assert isinstance(texts, list) - result = torch.full( - (len(texts), self.max_seq_len), self.pad_token, dtype=torch.int - ) - for i, text in enumerate(texts): - tokens = self.encode(text) - result[i, : len(tokens)] = torch.tensor(tokens) - return result + text = sample.pop("text") + sample["tokens"] = self.encode(text) + return sample def _bpe(self, token: str) -> str: """ From 3baea1ce3e48d11ffab5ad7ed7cef216bee8e85d Mon Sep 17 00:00:00 2001 From: Calvin Pelletier Date: Wed, 20 Nov 2024 10:43:55 -0800 Subject: [PATCH 13/16] addressing comments --- torchtune/models/clip/_component_builders.py | 12 ++++++------ torchtune/models/clip/_model_builders.py | 9 +++++---- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/torchtune/models/clip/_component_builders.py b/torchtune/models/clip/_component_builders.py index 337c4d43a9..ba5f0226a9 100644 --- a/torchtune/models/clip/_component_builders.py +++ b/torchtune/models/clip/_component_builders.py @@ -180,22 +180,22 @@ def clip_vision_encoder( def clip_text_encoder( + embed_dim: int, + num_heads: int, + num_layers: int, vocab_size: int = 49408, max_seq_len: int = 77, - embed_dim: int = 768, - num_heads: int = 12, - num_layers: int = 12, norm_eps: float = 1e-5, ): """ Text encoder for CLIP. Args: + embed_dim (int): embedding/model dimension size + num_heads (int): number of attention heads + num_layers (int): number of transformer layers vocab_size (int): size of the vocabulary, default 49408 max_seq_len (int): context size, default 77 - embed_dim (int): embedding/model dimension size, default 768 - num_heads (int): number of attention heads, default 12 - num_layers (int): number of transformer layers, default 12 norm_eps (float): small value added to denominator for numerical stability, default 1e-5 Returns: diff --git a/torchtune/models/clip/_model_builders.py b/torchtune/models/clip/_model_builders.py index d218844266..8767b078e3 100644 --- a/torchtune/models/clip/_model_builders.py +++ b/torchtune/models/clip/_model_builders.py @@ -19,8 +19,7 @@ def clip_tokenizer( Args: path (str): Path to the CLIP merges file - max_seq_len (bool): Context length - Default: 77 + max_seq_len (bool): Context length. Default: 77 truncate (bool): Truncate the token sequence if it exceeds max_seq_len (otherwise raises AssertionError) Default: True @@ -34,15 +33,17 @@ def clip_text_vit_large_patch14() -> CLIPTextEncoder: """ Builder for the CLIP text encoder for CLIP-ViT-L/14. 
+ https://arxiv.org/abs/2103.00020 + Returns: CLIPTextEncoder: Instantiation of the CLIP text encoder """ return clip_text_encoder( - vocab_size=49408, - max_seq_len=77, embed_dim=768, num_heads=12, num_layers=12, + vocab_size=49408, + max_seq_len=77, norm_eps=1e-5, ) From 10c1b0d54432d28315f3d9e7235b6ee86dbaad41 Mon Sep 17 00:00:00 2001 From: Calvin Pelletier Date: Wed, 20 Nov 2024 13:49:11 -0800 Subject: [PATCH 14/16] configurable eot token --- torchtune/models/clip/_text_encoder.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/torchtune/models/clip/_text_encoder.py b/torchtune/models/clip/_text_encoder.py index c9c2ccfbba..ee270c3779 100644 --- a/torchtune/models/clip/_text_encoder.py +++ b/torchtune/models/clip/_text_encoder.py @@ -24,6 +24,7 @@ class CLIPTextEncoder(nn.Module): max_seq_len (int): context size, default 77 embed_dim (int): embedding/model dimension size, default 768 num_layers (int): number of transformer layers, default 12 + eot_token (int): the id of the end-of-text token (for selecting the final output) """ def __init__( @@ -35,11 +36,13 @@ def __init__( max_seq_len: int = 77, embed_dim: int = 768, num_layers: int = 12, + eot_token: int = 49407, ): super().__init__() self.layers = nn.ModuleList([copy.deepcopy(layer) for i in range(num_layers)]) self.final_norm = final_norm self.max_seq_len = max_seq_len + self.eot_token = eot_token self.token_embedding = nn.Embedding(vocab_size, embed_dim) self.position_embedding = nn.Parameter(torch.empty(max_seq_len, embed_dim)) @@ -87,10 +90,9 @@ def forward( ) x = self.final_norm(x) - # Select the output of the EOS token for each encoding in the batch + # Select the output of the EOT token for each encoding in the batch # [b, s, d] -> [b, d] - # TODO: handle the case when the EOS token is not the highest token ID - eos_token_positions = tokens.argmax(dim=-1) + eos_token_positions = (tokens == self.eot_token).int().argmax(dim=-1) x = x.take_along_dim(eos_token_positions.view(-1, 1, 1), dim=1).squeeze(dim=1) return x From ec75caeb43dc783894e0677177b211a2e4b71308 Mon Sep 17 00:00:00 2001 From: Calvin Pelletier Date: Wed, 20 Nov 2024 13:56:26 -0800 Subject: [PATCH 15/16] docstring --- torchtune/models/clip/_component_builders.py | 4 ++++ torchtune/models/clip/_model_builders.py | 4 +++- torchtune/models/clip/_text_encoder.py | 4 ++++ 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/torchtune/models/clip/_component_builders.py b/torchtune/models/clip/_component_builders.py index ba5f0226a9..c8d19aae41 100644 --- a/torchtune/models/clip/_component_builders.py +++ b/torchtune/models/clip/_component_builders.py @@ -190,6 +190,10 @@ def clip_text_encoder( """ Text encoder for CLIP. + CLIP is a model that encodes text and images into a shared vector space. + Blog post: https://openai.com/index/clip/ + Paper: https://arxiv.org/abs/2103.00020 + Args: embed_dim (int): embedding/model dimension size num_heads (int): number of attention heads diff --git a/torchtune/models/clip/_model_builders.py b/torchtune/models/clip/_model_builders.py index 8767b078e3..d640466def 100644 --- a/torchtune/models/clip/_model_builders.py +++ b/torchtune/models/clip/_model_builders.py @@ -33,7 +33,9 @@ def clip_text_vit_large_patch14() -> CLIPTextEncoder: """ Builder for the CLIP text encoder for CLIP-ViT-L/14. - https://arxiv.org/abs/2103.00020 + CLIP is a model that encodes text and images into a shared vector space. 
+ Blog post: https://openai.com/index/clip/ + Paper: https://arxiv.org/abs/2103.00020 Returns: CLIPTextEncoder: Instantiation of the CLIP text encoder diff --git a/torchtune/models/clip/_text_encoder.py b/torchtune/models/clip/_text_encoder.py index ee270c3779..1bfb756664 100644 --- a/torchtune/models/clip/_text_encoder.py +++ b/torchtune/models/clip/_text_encoder.py @@ -17,6 +17,10 @@ class CLIPTextEncoder(nn.Module): """ Text encoder for CLIP. + CLIP is a model that encodes text and images into a shared vector space. + Blog post: https://openai.com/index/clip/ + Paper: https://arxiv.org/abs/2103.00020 + Args: layer (nn.Module): A single encoder layer. final_norm (nn.Module): Callable that applies normalization to the output of the encoder From c21569088b3da324460a79f14cddef2471dd17ed Mon Sep 17 00:00:00 2001 From: Calvin Pelletier Date: Wed, 20 Nov 2024 15:09:55 -0800 Subject: [PATCH 16/16] fix unit test --- tests/torchtune/models/clip/test_clip_text_encoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/torchtune/models/clip/test_clip_text_encoder.py b/tests/torchtune/models/clip/test_clip_text_encoder.py index 2b50ae8718..6ab388a61e 100644 --- a/tests/torchtune/models/clip/test_clip_text_encoder.py +++ b/tests/torchtune/models/clip/test_clip_text_encoder.py @@ -44,7 +44,7 @@ def inputs(self): def test_forward(self, model, inputs): actual = model(inputs) expected = torch.tensor( - [[0.2195, 1.3941, 0.6295, -0.1026], [0.2418, 1.4928, 0.6177, -0.0863]] + [[0.1915, 1.3982, 0.6298, -0.0966], [0.2276, 1.3785, 0.6309, -0.1066]] ) assert actual.shape == (BSZ, EMBED_DIM) torch.testing.assert_close(actual, expected, atol=1e-4, rtol=1e-4)
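
A rough end-to-end usage sketch for the pieces added in this series. This is not taken from the patches themselves: the merges-file path is a placeholder, and right-padding with the end-of-text id mirrors the earlier version of the tokenizer tests rather than anything the final patches prescribe.

    import torch

    from torchtune.models.clip import clip_text_vit_large_patch14, clip_tokenizer

    # Placeholder path -- the real file is the OpenAI CLIP BPE merges list.
    tokenizer = clip_tokenizer("/path/to/clip_merges.txt")

    # Randomly initialized; pretrained weights are loaded separately (checkpointer).
    encoder = clip_text_vit_large_patch14()

    # encode() returns a plain Python list bracketed by the start/end-of-text ids.
    tokens = tokenizer.encode("a cow jumping over the moon")

    # The dataset-style __call__ from PATCH 12 does the same thing on a sample dict.
    sample = tokenizer({"text": "a cow jumping over the moon"})
    assert sample["tokens"] == tokens

    # The encoder expects a [batch, seq_len] integer tensor; right-padding with the
    # end-of-text id (49407) is an assumption carried over from the earlier tests.
    padded = tokens + [49407] * (tokenizer.max_seq_len - len(tokens))
    batch = torch.tensor([padded], dtype=torch.long)

    with torch.no_grad():
        text_embedding = encoder(batch)  # [1, 768], taken at the first end-of-text position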
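
For reference, a small illustration (made-up token ids, not from the patches) of why PATCH 14 replaces the raw argmax over token ids with a comparison against a configurable EOT id:

    import torch

    eot = 49407
    tokens = torch.tensor([[49406, 320, 9706, eot, eot, eot]])  # right-padded with the EOT id

    # Before: relies on 49407 being the largest id in the vocabulary, so argmax over
    # the raw ids happens to land on the first EOT (argmax returns the first maximum).
    old_positions = tokens.argmax(dim=-1)

    # After: match the configurable EOT id directly, which also works for vocabularies
    # where the EOT id is not the largest one (e.g. the tiny test vocab).
    new_positions = (tokens == eot).int().argmax(dim=-1)

    assert torch.equal(old_positions, new_positions)  # both tensor([3]) in this example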
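
And a short numerical check (again not from the patches) showing that QuickGELU, x * sigmoid(1.702 * x), only approximates the exact GELU, which is presumably why it stays a dedicated module instead of reusing nn.GELU:

    import torch
    from torch import nn

    x = torch.linspace(-3, 3, steps=7)
    quick = x * torch.sigmoid(1.702 * x)  # QuickGELU as defined in the patch
    exact = nn.functional.gelu(x)

    print((quick - exact).abs().max())  # small but nonzero (roughly 2e-2)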