ThinkCLIP: Vision-Language Contrastive Learning

A lightweight CLIP implementation that combines a Vision Transformer (ViT) image encoder with a GPT-style text encoder, trained with contrastive learning to align visual and language representations.

Model Overview

🖼 Vision Encoder - ViT

  • Patch embedding
  • Multi-layer Transformer
  • Flash attention support (is_flash = True)
  • The [CLS] token represents global image semantics (see the sketch below)
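
As a rough illustration (not the repository's actual module), the patch-embedding stage under the default configuration below could look like this:

import torch
import torch.nn as nn

class PatchEmbed(nn.Module):
    """Illustrative sketch: image -> patch tokens + [CLS], matching the defaults (224 / 16 / 384)."""
    def __init__(self, img_size=224, patch_size=16, channel=3, img_dim=384):
        super().__init__()
        self.proj = nn.Conv2d(channel, img_dim, kernel_size=patch_size, stride=patch_size)
        num_patches = (img_size // patch_size) ** 2              # 14 * 14 = 196
        self.cls_token = nn.Parameter(torch.zeros(1, 1, img_dim))
        # learnable positional embedding; an assumption here (the README only states RoPE for the text side)
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, img_dim))

    def forward(self, x):                                        # x: (B, 3, 224, 224)
        x = self.proj(x).flatten(2).transpose(1, 2)              # (B, 196, 384)
        cls = self.cls_token.expand(x.size(0), -1, -1)           # (B, 1, 384)
        x = torch.cat([cls, x], dim=1) + self.pos_embed          # (B, 197, 384)
        return x  # the Transformer blocks run on this; x[:, 0] is the global image feature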

💬 Text Encoder - GPT

  • Token embedding + rotary positional embeddings
  • Multi-layer Transformer
  • The [EOS] token representation is used as the text embedding (see the pooling sketch below)
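
One way the [EOS] pooling could be implemented is sketched below, assuming '<|im_end|>' (appended to every caption in the evaluation snippets further down) acts as the EOS token; the helper is illustrative rather than the repository's code:

import torch

def pool_eos(hidden_states, input_ids, eos_token_id):
    """Pick the hidden state at the last EOS position of each sequence.

    hidden_states: (B, T, text_dim) output of the text Transformer
    input_ids:     (B, T) token ids
    """
    positions = torch.arange(input_ids.size(1), device=input_ids.device)
    eos_pos = ((input_ids == eos_token_id) * positions).amax(dim=1)   # last EOS index per row
    batch_idx = torch.arange(input_ids.size(0), device=input_ids.device)
    return hidden_states[batch_idx, eos_pos]                          # (B, text_dim)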

🔗 Alignment Head

  • Linear projection to a shared embedding space (align_dim = 512)
  • Feature normalization and learnable temperature scaling
  • Contrastive loss objective:

L = 1/2 * [CE(logit_img, y) + CE(logit_text, y)]
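
A minimal sketch of this objective, assuming L2-normalized projections and a learnable temperature; matched image/text pairs share the same index within a batch:

import torch
import torch.nn.functional as F

def clip_loss(img_feat, txt_feat, logit_scale):
    """img_feat, txt_feat: (B, align_dim) projection-head outputs; logit_scale: learnable scalar."""
    img_feat = F.normalize(img_feat, dim=-1)
    txt_feat = F.normalize(txt_feat, dim=-1)
    logit_img = logit_scale * img_feat @ txt_feat.t()            # (B, B) image-to-text similarities
    logit_text = logit_img.t()                                   # text-to-image view of the same matrix
    y = torch.arange(img_feat.size(0), device=img_feat.device)   # positives sit on the diagonal
    return 0.5 * (F.cross_entropy(logit_img, y) + F.cross_entropy(logit_text, y))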

⚙️ Configuration

from dataclasses import dataclass
from typing import Optional

@dataclass
class Config:
    # common
    eps: float = 1e-5
    is_flash: bool = True
    align_dim: int = 512
    # img encoder
    img_size: int = 224
    patch_size: int = 16
    channel: int = 3
    img_dim: int = 384
    img_n_heads: int = 6
    img_n_layers: int = 12
    num_classes: Optional[int] = None
    img_dropout_rate: float = 0.1
    # text encoder
    vocab_size: int = 6400
    max_seq_len: int = 64
    text_dim: int = 768
    text_n_heads: int = 12
    text_n_kv_heads: int = 3
    text_n_layers: int = 12
    text_dropout_rate: float = 0.1
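
A few quantities implied by these defaults (a quick sanity check; reading text_n_kv_heads as grouped-query attention is an assumption):

cfg = Config()

num_patches = (cfg.img_size // cfg.patch_size) ** 2      # (224 // 16) ** 2 = 196 patch tokens, plus [CLS]
img_head_dim = cfg.img_dim // cfg.img_n_heads            # 384 // 6 = 64 per attention head
text_head_dim = cfg.text_dim // cfg.text_n_heads         # 768 // 12 = 64 per attention head
kv_groups = cfg.text_n_heads // cfg.text_n_kv_heads      # 12 // 3 = 4 query heads per KV head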

Quick Start

Requirements

pip install -r requirements.txt

Train

Dataset: LLaVA-Pretrain

python 1.train_pretrain.py

Eval

Model checkpoint: ThinkCLIP

python 2.eval_clip.py

Loss Curve

(training loss curve image)

Evaluation

(test image: a red envelope)

text = [
    'Happy New Year' + '<|im_end|>',
    'China' + '<|im_end|>',
    'red' + '<|im_end|>',
    'envelope' + '<|im_end|>',
    'red envelope' + '<|im_end|>'
]

pred:
    Happy New Year: 1.71%
    China:          0.12%
    red:            9.02%
    envelope:       43.02%
    red envelope:   46.13%
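
The percentages are standard zero-shot scoring: embed the image and each candidate caption, then softmax the temperature-scaled cosine similarities over the captions. A minimal sketch on precomputed embeddings (names are illustrative):

import torch

@torch.no_grad()
def zero_shot_probs(img_emb, txt_embs, logit_scale):
    """img_emb: (align_dim,) image embedding; txt_embs: (N, align_dim) caption embeddings."""
    img = img_emb / img_emb.norm()
    txt = txt_embs / txt_embs.norm(dim=-1, keepdim=True)
    logits = logit_scale * (txt @ img)           # (N,) similarity per caption
    return logits.softmax(dim=-1)                # one probability per caption, sums to 1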

(test image)

text = [
    'pants' + '<|im_end|>',
    'plane' + '<|im_end|>',
    'car' + '<|im_end|>',
    'blue shirt' + '<|im_end|>',
    'shoe' + '<|im_end|>',
    'Light blue pants' + '<|im_end|>',
    'Dark blue pants' + '<|im_end|>',
]

pred:
    pants:              1.18%
    plane:              0.00%
    car:                0.00%
    blue shirt:         1.42%
    shoe:               0.00%
    Light blue pants:   69.51%
    Dark blue pants:    27.88%

(test image)

text = [
    'girl playing mahjong' + '<|im_end|>',
    'boy playing mahjong' + '<|im_end|>',
    'girl' + '<|im_end|>',
    'mahjong' + '<|im_end|>',
    'green mahjong' + '<|im_end|>',
    'playing mahjong' + '<|im_end|>',
]

pred:
    girl playing mahjong:  80.13%
    boy playing mahjong:   19.68%
    girl:                  0.10%
    mahjong:               0.04%
    green mahjong:         0.00%
    playing mahjong:       0.06%
