diff --git a/tests/torchtune/models/clip/test_clip_text_encoder.py b/tests/torchtune/models/clip/test_clip_text_encoder.py
new file mode 100644
index 0000000000..6ab388a61e
--- /dev/null
+++ b/tests/torchtune/models/clip/test_clip_text_encoder.py
@@ -0,0 +1,55 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import pytest
+import torch
+
+from torchtune.models.clip._component_builders import clip_text_encoder
+from torchtune.training.seed import set_seed
+
+VOCAB_SIZE = 512
+MAX_SEQ_LEN = 77
+BSZ = 2
+EMBED_DIM = 4
+
+
+@pytest.fixture(autouse=True)
+def random():
+    set_seed(0)
+
+
+class TestClipTextEncoder:
+    @pytest.fixture
+    def model(self):
+        model = clip_text_encoder(
+            vocab_size=VOCAB_SIZE,
+            max_seq_len=MAX_SEQ_LEN,
+            embed_dim=EMBED_DIM,
+            num_heads=2,
+            num_layers=2,
+        )
+
+        for param in model.parameters():
+            param.data.uniform_(0, 1)
+
+        return model
+
+    @pytest.fixture
+    def inputs(self):
+        return torch.randint(0, VOCAB_SIZE, (BSZ, MAX_SEQ_LEN))
+
+    def test_forward(self, model, inputs):
+        actual = model(inputs)
+        expected = torch.tensor(
+            [[0.1915, 1.3982, 0.6298, -0.0966], [0.2276, 1.3785, 0.6309, -0.1066]]
+        )
+        assert actual.shape == (BSZ, EMBED_DIM)
+        torch.testing.assert_close(actual, expected, atol=1e-4, rtol=1e-4)
+
+    def test_backward(self, model, inputs):
+        y = model(inputs)
+        loss = y.mean()
+        loss.backward()
diff --git a/tests/torchtune/models/clip/test_clip_tokenizer.py b/tests/torchtune/models/clip/test_clip_tokenizer.py
new file mode 100644
index 0000000000..977cd4f028
--- /dev/null
+++ b/tests/torchtune/models/clip/test_clip_tokenizer.py
@@ -0,0 +1,58 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+import pytest
+
+from tests.common import ASSETS
+from torchtune.models.clip._model_builders import clip_tokenizer
+
+
+class TestCLIPTokenizer:
+    @pytest.fixture
+    def tokenizer(self):
+        return clip_tokenizer(ASSETS / "tiny_bpe_merges.txt")
+
+    def test_encoding(self, tokenizer):
+        texts = [
+            "a cow jumping over the moon",
+            "a helpful AI assistant",
+        ]
+        correct_tokens = [
+            [
+                2416,
+                320,
+                66,
+                78,
+                342,
+                73,
+                669,
+                79,
+                515,
+                326,
+                1190,
+                337,
+                673,
+                324,
+                76,
+                819,
+                333,
+                2417,
+            ],
+            [2416, 320, 516, 75, 79, 69, 84, 331, 64, 328, 813, 667, 540, 339, 2417],
+        ]
+        for text, correct in zip(texts, correct_tokens):
+            tokens = tokenizer.encode(text)
+            assert tokens == correct
+
+    def test_decoding(self, tokenizer):
+        text = "this is torchtune"
+        decoded_text = "<|startoftext|>this is torchtune <|endoftext|>"
+        assert decoded_text == tokenizer.decode(tokenizer.encode(text))
+
+    def test_call(self, tokenizer):
+        sample = {"text": "hello world"}
+        sample = tokenizer(sample)
+        assert "text" not in sample
+        assert "tokens" in sample
diff --git a/torchtune/models/clip/__init__.py b/torchtune/models/clip/__init__.py
index 5361c8968c..a6ea4ecb6a 100644
--- a/torchtune/models/clip/__init__.py
+++ b/torchtune/models/clip/__init__.py
@@ -4,8 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from ._component_builders import clip_mlp, clip_vision_encoder
-
+from ._component_builders import clip_mlp, clip_text_encoder, clip_vision_encoder
+from ._model_builders import clip_text_vit_large_patch14, clip_tokenizer
 from ._position_embeddings import (
     TiledTokenPositionalEmbedding,
     TilePositionalEmbedding,
@@ -15,7 +15,10 @@
 
 __all__ = [
     "clip_mlp",
+    "clip_text_encoder",
     "clip_vision_encoder",
+    "clip_text_vit_large_patch14",
+    "clip_tokenizer",
     "CLIPImageTransform",
     "TokenPositionalEmbedding",
     "TiledTokenPositionalEmbedding",
diff --git a/torchtune/models/clip/_component_builders.py b/torchtune/models/clip/_component_builders.py
index 4bbe0ab7f7..c8d19aae41 100644
--- a/torchtune/models/clip/_component_builders.py
+++ b/torchtune/models/clip/_component_builders.py
@@ -8,12 +8,13 @@
 from typing import Callable, List, Optional
 
 from torch import nn
+
 from torchtune.models.clip._position_embeddings import (
     TiledTokenPositionalEmbedding,
     TilePositionalEmbedding,
     TokenPositionalEmbedding,
 )
-
+from torchtune.models.clip._text_encoder import CLIPTextEncoder, QuickGELU
 from torchtune.modules import (
     FeedForward,
     Fp32LayerNorm,
@@ -22,11 +23,8 @@
     TransformerSelfAttentionLayer,
     VisionRotaryPositionalEmbeddings,
 )
-
 from torchtune.modules.common_utils import reparametrize_as_dtype_state_dict_post_hook
-
-from torchtune.modules.peft import DoRALinear, LORA_ATTN_MODULES, LoRALinear
-
+from torchtune.modules.peft import LORA_ATTN_MODULES, DoRALinear, LoRALinear
 from torchtune.modules.vision_transformer import CLSProjection, VisionTransformer
 
 
@@ -181,6 +179,65 @@ def clip_vision_encoder(
     )
 
 
+def clip_text_encoder(
+    embed_dim: int,
+    num_heads: int,
+    num_layers: int,
+    vocab_size: int = 49408,
+    max_seq_len: int = 77,
+    norm_eps: float = 1e-5,
+):
+    """
+    Text encoder for CLIP.
+
+    CLIP is a model that encodes text and images into a shared vector space.
+    Blog post: https://openai.com/index/clip/
+    Paper: https://arxiv.org/abs/2103.00020
+
+    Args:
+        embed_dim (int): embedding/model dimension size
+        num_heads (int): number of attention heads
+        num_layers (int): number of transformer layers
+        vocab_size (int): size of the vocabulary, default 49408
+        max_seq_len (int): context size, default 77
+        norm_eps (float): small value added to denominator for numerical stability, default 1e-5
+
+    Returns:
+        CLIPTextEncoder
+    """
+    attn = MultiHeadAttention(
+        embed_dim=embed_dim,
+        num_heads=num_heads,
+        num_kv_heads=num_heads,
+        head_dim=embed_dim // num_heads,
+        q_proj=nn.Linear(embed_dim, embed_dim),
+        k_proj=nn.Linear(embed_dim, embed_dim),
+        v_proj=nn.Linear(embed_dim, embed_dim),
+        output_proj=nn.Linear(embed_dim, embed_dim),
+    )
+    mlp = clip_mlp(
+        in_dim=embed_dim,
+        out_dim=embed_dim,
+        hidden_dim=embed_dim * 4,
+        activation=QuickGELU(),
+    )
+    encoder_layer = TransformerSelfAttentionLayer(
+        attn=attn,
+        mlp=mlp,
+        sa_norm=nn.LayerNorm(embed_dim, eps=norm_eps),
+        mlp_norm=nn.LayerNorm(embed_dim, eps=norm_eps),
+    )
+    final_norm = nn.LayerNorm(embed_dim, eps=norm_eps)
+    return CLIPTextEncoder(
+        layer=encoder_layer,
+        final_norm=final_norm,
+        vocab_size=vocab_size,
+        max_seq_len=max_seq_len,
+        embed_dim=embed_dim,
+        num_layers=num_layers,
+    )
+
+
 def clip_mlp(
     in_dim: int,
     out_dim: int,
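For orientation, a minimal sketch of how the clip_text_encoder component builder above can be exercised on its own. The hyperparameters mirror the CLIP-ViT-L/14 values used by the model builder later in this diff; the random token batch is purely illustrative, and real inputs would come from clip_tokenizer.

import torch

from torchtune.models.clip import clip_text_encoder

# Same dimensions as the CLIP-ViT-L/14 text tower
encoder = clip_text_encoder(
    embed_dim=768,
    num_heads=12,
    num_layers=12,
    vocab_size=49408,
    max_seq_len=77,
)

tokens = torch.randint(0, 49408, (2, 77))  # [batch_size, seq_len]
embeddings = encoder(tokens)
print(embeddings.shape)  # torch.Size([2, 768]) -- one pooled embedding per sequence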
diff --git a/torchtune/models/clip/_convert_weights.py b/torchtune/models/clip/_convert_weights.py
new file mode 100644
index 0000000000..004100e928
--- /dev/null
+++ b/torchtune/models/clip/_convert_weights.py
@@ -0,0 +1,48 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from torchtune.models.convert_weights import get_mapped_key
+
+# state dict key mappings from HF's format to torchtune's format
+_FROM_HF = {
+    "text_model.embeddings.token_embedding.weight": "token_embedding.weight",
+    "text_model.embeddings.position_embedding.weight": "position_embedding",
+    "text_model.encoder.layers.{}.layer_norm1.weight": "layers.{}.sa_norm.weight",
+    "text_model.encoder.layers.{}.layer_norm1.bias": "layers.{}.sa_norm.bias",
+    "text_model.encoder.layers.{}.layer_norm2.weight": "layers.{}.mlp_norm.weight",
+    "text_model.encoder.layers.{}.layer_norm2.bias": "layers.{}.mlp_norm.bias",
+    "text_model.encoder.layers.{}.mlp.fc1.weight": "layers.{}.mlp.w1.weight",
+    "text_model.encoder.layers.{}.mlp.fc1.bias": "layers.{}.mlp.w1.bias",
+    "text_model.encoder.layers.{}.mlp.fc2.weight": "layers.{}.mlp.w2.weight",
+    "text_model.encoder.layers.{}.mlp.fc2.bias": "layers.{}.mlp.w2.bias",
+    "text_model.encoder.layers.{}.self_attn.q_proj.weight": "layers.{}.attn.q_proj.weight",
+    "text_model.encoder.layers.{}.self_attn.q_proj.bias": "layers.{}.attn.q_proj.bias",
+    "text_model.encoder.layers.{}.self_attn.k_proj.weight": "layers.{}.attn.k_proj.weight",
+    "text_model.encoder.layers.{}.self_attn.k_proj.bias": "layers.{}.attn.k_proj.bias",
+    "text_model.encoder.layers.{}.self_attn.v_proj.weight": "layers.{}.attn.v_proj.weight",
+    "text_model.encoder.layers.{}.self_attn.v_proj.bias": "layers.{}.attn.v_proj.bias",
+    "text_model.encoder.layers.{}.self_attn.out_proj.bias": "layers.{}.attn.output_proj.bias",
+    "text_model.encoder.layers.{}.self_attn.out_proj.weight": "layers.{}.attn.output_proj.weight",
+    "text_model.final_layer_norm.weight": "final_norm.weight",
+    "text_model.final_layer_norm.bias": "final_norm.bias",
+}
+
+_IGNORE = {
+    "logit_scale",
+    "text_model.embeddings.position_ids",
+    "text_projection.weight",
+    "visual_projection.weight",
+}
+
+
+def clip_text_hf_to_tune(state_dict):
+    converted_state_dict = {}
+    for key, value in state_dict.items():
+        if key.startswith("vision_model.") or key in _IGNORE:
+            continue
+        new_key = get_mapped_key(key, _FROM_HF)
+        converted_state_dict[new_key] = value
+    return converted_state_dict
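A rough illustration of what clip_text_hf_to_tune does with a Hugging Face CLIP state dict (the tensors are stand-ins; real values come from the checkpoint files loaded by the checkpointer change below):

import torch

from torchtune.models.clip._convert_weights import clip_text_hf_to_tune

hf_state_dict = {
    "text_model.encoder.layers.0.self_attn.q_proj.weight": torch.zeros(768, 768),
    "logit_scale": torch.tensor(4.6),                     # listed in _IGNORE, dropped
    "visual_projection.weight": torch.zeros(768, 1024),   # listed in _IGNORE, dropped
}
converted = clip_text_hf_to_tune(hf_state_dict)
print(list(converted.keys()))  # ['layers.0.attn.q_proj.weight']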
diff --git a/torchtune/models/clip/_model_builders.py b/torchtune/models/clip/_model_builders.py
index 61c01d1c51..d640466def 100644
--- a/torchtune/models/clip/_model_builders.py
+++ b/torchtune/models/clip/_model_builders.py
@@ -1,4 +1,54 @@
-from torchtune.models.clip._transforms import CLIPImageTransform
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from torchtune.models.clip._component_builders import clip_text_encoder
+from torchtune.models.clip._text_encoder import CLIPTextEncoder
+from torchtune.models.clip._tokenizer import CLIPTokenizer
+from torchtune.models.clip._transform import CLIPImageTransform
+
+
+def clip_tokenizer(
+    path: str,
+    max_seq_len: int = 77,
+    truncate: bool = True,
+) -> CLIPTokenizer:
+    """
+    Builder for the CLIP text tokenizer.
+
+    Args:
+        path (str): Path to the CLIP merges file
+        max_seq_len (int): Context length. Default: 77
+        truncate (bool): Truncate the token sequence if it exceeds max_seq_len (otherwise raises AssertionError). Default: True
+
+    Returns:
+        CLIPTokenizer: Instantiation of the CLIP text tokenizer
+    """
+    return CLIPTokenizer(path, max_seq_len=max_seq_len, truncate=truncate)
+
+
+def clip_text_vit_large_patch14() -> CLIPTextEncoder:
+    """
+    Builder for the CLIP text encoder for CLIP-ViT-L/14.
+
+    CLIP is a model that encodes text and images into a shared vector space.
+    Blog post: https://openai.com/index/clip/
+    Paper: https://arxiv.org/abs/2103.00020
+
+    Returns:
+        CLIPTextEncoder: Instantiation of the CLIP text encoder
+    """
+    return clip_text_encoder(
+        embed_dim=768,
+        num_heads=12,
+        num_layers=12,
+        vocab_size=49408,
+        max_seq_len=77,
+        norm_eps=1e-5,
+    )
+
 
 def clip_vit_224_transform():
     image_transform = CLIPImageTransform(
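A hedged end-to-end sketch of the two builders above working together; the merges-file path is a placeholder for a CLIP merges.txt file (for example the one shipped with the Hugging Face openai/clip-vit-large-patch14 checkpoint):

import torch

from torchtune.models.clip import clip_text_vit_large_patch14, clip_tokenizer

tokenizer = clip_tokenizer("/path/to/merges.txt")  # placeholder path
model = clip_text_vit_large_patch14()

tokens = tokenizer.encode("a photo of a cat")      # [sot, ..., eot], at most 77 ids
embedding = model(torch.tensor([tokens]))          # shape [1, 768]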
diff --git a/torchtune/models/clip/_text_encoder.py b/torchtune/models/clip/_text_encoder.py
new file mode 100644
index 0000000000..1bfb756664
--- /dev/null
+++ b/torchtune/models/clip/_text_encoder.py
@@ -0,0 +1,111 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import copy
+from typing import Optional
+
+import torch
+from torch import nn, Tensor
+
+from torchtune.modules.attention_utils import _MaskType
+
+
+class CLIPTextEncoder(nn.Module):
+    """
+    Text encoder for CLIP.
+
+    CLIP is a model that encodes text and images into a shared vector space.
+    Blog post: https://openai.com/index/clip/
+    Paper: https://arxiv.org/abs/2103.00020
+
+    Args:
+        layer (nn.Module): A single encoder layer.
+        final_norm (nn.Module): Callable that applies normalization to the output of the encoder
+        vocab_size (int): size of the vocabulary, default 49408
+        max_seq_len (int): context size, default 77
+        embed_dim (int): embedding/model dimension size, default 768
+        num_layers (int): number of transformer layers, default 12
+        eot_token (int): the id of the end-of-text token (for selecting the final output)
+    """
+
+    def __init__(
+        self,
+        *,
+        layer: nn.Module,
+        final_norm: nn.Module,
+        vocab_size: int = 49408,
+        max_seq_len: int = 77,
+        embed_dim: int = 768,
+        num_layers: int = 12,
+        eot_token: int = 49407,
+    ):
+        super().__init__()
+        self.layers = nn.ModuleList([copy.deepcopy(layer) for i in range(num_layers)])
+        self.final_norm = final_norm
+        self.max_seq_len = max_seq_len
+        self.eot_token = eot_token
+
+        self.token_embedding = nn.Embedding(vocab_size, embed_dim)
+        self.position_embedding = nn.Parameter(torch.empty(max_seq_len, embed_dim))
+
+    def forward(
+        self,
+        tokens: Tensor,
+        *,
+        mask: Optional[_MaskType] = None,
+    ) -> Tensor:
+        """
+        Args:
+            tokens (Tensor): input tensor with shape ``[b x s]``
+            mask (Optional[_MaskType]): Used to mask the scores after the query-key multiplication
+                and before the softmax.
+                Default is None.
+
+        Returns:
+            Tensor: output tensor with shape [b x d]
+
+        Raises:
+            ValueError: if seq_len of tokens is bigger than max_seq_len
+
+        Shape notation:
+            - b: batch size
+            - s: token sequence length
+            - d: token embed dim
+        """
+        # Input validation
+        bsz, seq_len = tokens.shape
+        if seq_len > self.max_seq_len:
+            raise ValueError(
+                f"seq_len ({seq_len}) of input tensor should not exceed "
+                f"max_seq_len ({self.max_seq_len})"
+            )
+
+        # Input embedding [b, s] -> [b, s, d]
+        x = self.token_embedding(tokens) + self.position_embedding
+
+        # Encoder [b, s, d] -> [b, s, d]
+        for layer in self.layers:
+            x = layer(
+                x,
+                mask=mask,
+            )
+        x = self.final_norm(x)
+
+        # Select the output of the EOT token for each encoding in the batch
+        # [b, s, d] -> [b, d]
+        eos_token_positions = (tokens == self.eot_token).int().argmax(dim=-1)
+        x = x.take_along_dim(eos_token_positions.view(-1, 1, 1), dim=1).squeeze(dim=1)
+
+        return x
+
+
+class QuickGELU(nn.Module):
+    """
+    Fast approximation of GELU.
+    """
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return x * torch.sigmoid(1.702 * x)
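QuickGELU is kept here because the original CLIP weights were trained with this sigmoid-based approximation rather than exact GELU; the two differ slightly, as this small comparison (illustrative only) shows:

import torch
import torch.nn.functional as F

from torchtune.models.clip._text_encoder import QuickGELU

x = torch.linspace(-3, 3, steps=7)
print(QuickGELU()(x))  # x * sigmoid(1.702 * x)
print(F.gelu(x))       # exact GELU, close but not identical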
+ """ + text = _clean_text(text).lower() + + tokens = [self.sot_token] + for token in re.findall(self.pat, text): + token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) + tokens.extend( + self.encoder[bpe_token] for bpe_token in self._bpe(token).split(" ") + ) + if len(tokens) >= self.max_seq_len: + break + tokens.append(self.eot_token) + + if len(tokens) > self.max_seq_len: + assert self.truncate, ( + "Tokenized text is larger than the maximum sequence length but " + "truncate is set to False." + ) + tokens = tokens[: self.max_seq_len] + tokens[-1] = self.eot_token + + return tokens + + def decode(self, tokens: List[int]) -> str: + """ + Given a list of token ids, return the decoded text, optionally including special tokens. + + Args: + tokens (List[int]): The list of token ids to decode. + + Returns: + str: The decoded text. + """ + text = "".join([self.decoder[token] for token in tokens]) + return ( + bytearray([self.byte_decoder[c] for c in text]) + .decode("utf-8", errors="replace") + .replace(WORD_BOUNDARY, " ") + ) + + def __call__( + self, sample: Mapping[str, Any], inference: bool = False + ) -> Mapping[str, Any]: + """ + Tokenize the "text" field in the sample. + + Args: + sample (Mapping[str, Any]): A sample with a "text" field containing a string to tokenize + inference (bool): Unused by this tokenizer + + Returns: + Mapping[str, Any]: The sample with added "tokens" field and the "messages" field removed. + """ + text = sample.pop("text") + sample["tokens"] = self.encode(text) + return sample + + def _bpe(self, token: str) -> str: + """ + Performs byte-pair encoding on a single token. + """ + if token in self.cache: + return self.cache[token] + + if len(token) < 2: + return token + WORD_BOUNDARY + + # create the initial "word" (seq of "symbols" i.e. characters and merged subwords) + # by converting the token to tuple of characters and add to the last character + word = tuple(token[:-1]) + (token[-1] + WORD_BOUNDARY,) + + # get all pairs of adjacent characters + pairs = _get_pairs(word) + + # merge symbol pairs until there are no possible merges left + while True: + # find the pair with the lowest rank (highest priority to merge) + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) + # end if there are no pairs to merge + if bigram not in self.bpe_ranks: + break + + # create the next "word" by merging any adjacent symbols that match the bigram + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + # find next potentially mergeable position and copy over any skipped characters + # if no more merge positions found, copy remaining characters and finish + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except ValueError: + new_word.extend(word[i:]) + break + + # check if we can perform a merge + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + word = tuple(new_word) + + # end if the new "word" is fully merged + if len(word) == 1: + break + + # get all pairs of adjacent symbols in the new word + pairs = _get_pairs(word) + + word = " ".join(word) + self.cache[token] = word + return word + + +def _bytes_to_unicode() -> Dict[int, str]: + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. 
+ """ + bs = ( + list(range(ord("!"), ord("~") + 1)) + + list(range(ord("¡"), ord("¬") + 1)) + + list(range(ord("®"), ord("ÿ") + 1)) + ) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def _get_pairs(word: Tuple[str, ...]) -> Set[Tuple[str, str]]: + """ + Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def _clean_text(text: str) -> str: + """ + Minimal version of CLIP's text cleaning via the `ftfy` package. + """ + return text.replace("’", "'") + + +def _load_merges(path: str) -> List[Tuple[str, str]]: + merges = [] + with open(path, encoding="utf-8") as f: + for i, line in enumerate(f): + line = line.strip() + if (i == 0 and line.startswith("#version:")) or not line: + continue + merges.append(tuple(line.split())) + return merges diff --git a/torchtune/training/checkpointing/_checkpointer.py b/torchtune/training/checkpointing/_checkpointer.py index ec79f0a4ba..fcb4bd131e 100644 --- a/torchtune/training/checkpointing/_checkpointer.py +++ b/torchtune/training/checkpointing/_checkpointer.py @@ -7,15 +7,15 @@ import gc import json import os - from pathlib import Path from typing import Any, Dict, List, Optional, Protocol, Union import torch from safetensors.torch import save_file -from torchtune import training +from torchtune import training from torchtune.models import convert_weights +from torchtune.models.clip._convert_weights import clip_text_hf_to_tune from torchtune.models.phi3._convert_weights import phi3_hf_to_tune, phi3_tune_to_hf from torchtune.models.qwen2._convert_weights import qwen2_hf_to_tune, qwen2_tune_to_hf from torchtune.rlhf.utils import reward_hf_to_tune, reward_tune_to_hf @@ -488,6 +488,10 @@ def load_checkpoint(self) -> Dict[str, Any]: "supported_aspect_ratios", None ), ) + elif self._model_type == ModelType.CLIP_TEXT: + converted_state_dict[training.MODEL_KEY] = clip_text_hf_to_tune( + merged_state_dict, + ) elif self._model_type == ModelType.GEMMA2: from torchtune.models.gemma2._convert_weights import gemma2_hf_to_tune diff --git a/torchtune/training/checkpointing/_utils.py b/torchtune/training/checkpointing/_utils.py index 2fa7265194..e6a7b1afa1 100644 --- a/torchtune/training/checkpointing/_utils.py +++ b/torchtune/training/checkpointing/_utils.py @@ -56,6 +56,7 @@ class ModelType(Enum): to a single class for reward modelling. See :func:`~torchtune.models.mistral.mistral_reward_7b` or :func:`~torchtune.models.llama2.llama2_reward_7b` QWEN2 (str): Qwen2 family of models. See :func:`~torchtune.models.qwen2.qwen2` + CLIP_TEXT (str): CLIP text encoder. See :func:`~torchtune.models.clip.clip_text_encoder_large` Example: >>> # Usage in a checkpointer class @@ -75,6 +76,7 @@ class ModelType(Enum): PHI3_MINI: str = "phi3_mini" REWARD: str = "reward" QWEN2: str = "qwen2" + CLIP_TEXT: str = "clip_text" class FormattedCheckpointFiles: