
Commit 5c87ac5

feat: add mindspore-version CLIP (totally rewrite based on openAI-CLIP) (#740)
1 parent 53131ed commit 5c87ac5

16 files changed: +4712 −0 lines changed

examples/clip/CLIP.png

247 KB

examples/clip/MANIFEST.in

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
include clip/bpe_simple_vocab_16e6.txt.gz

examples/clip/README.md

Lines changed: 227 additions & 0 deletions
@@ -0,0 +1,227 @@
# CLIP

[[Blog]](https://openai.com/blog/clip/) [[Paper]](https://arxiv.org/abs/2103.00020) [[Model Card]](model-card.md) [[Colab]](https://colab.research.google.com/github/openai/clip/blob/master/notebooks/Interacting_with_CLIP.ipynb) [[Source Code]](https://github.com/openai/CLIP)

CLIP (Contrastive Language-Image Pre-Training) is a neural network trained on a variety of (image, text) pairs. It can be instructed in natural language to predict the most relevant text snippet for a given image, without directly optimizing for the task, similarly to the zero-shot capabilities of GPT-2 and GPT-3. CLIP matches the performance of the original ResNet-50 on ImageNet "zero-shot", without using any of the original 1.28M labeled examples, overcoming several major challenges in computer vision.

**Note: The original CLIP code is built with PyTorch, whereas the CLIP in this repo is built with [MindSpore](https://www.mindspore.cn/). An API comparison checklist is on the way.**

## Approach

![CLIP](CLIP.png)

## Usage

First, [install MindSpore 2.0.0](https://www.mindspore.cn/install) (or later), as well as the small additional dependencies:

```bash
$ cd ./examples/clip/
$ pip install -r requirements.txt
```

### Checkpoint Transform

MindSpore cannot load .pt/.pth files directly, so we strongly recommend converting the OpenAI checkpoint to the .ckpt format as follows (PyTorch required):

```bash
$ python ./clip/ckpt_transform.py --pth_path="ViT-B-32"
```

Note: `pth_path` accepts either a model name or the local path of a .pt/.pth file.

### Example Code

```python
import clip
from PIL import Image
from mindspore import Tensor, nn

model, preprocess = clip.load("./ViT-B-32.ckpt", device="Ascend")

image = Tensor(preprocess(Image.open("CLIP.png")))
text = clip.tokenize(["a diagram", "a dog", "a cat"])

image_features = model.encode_image(image)
text_features = model.encode_text(text)

logits_per_image, logits_per_text = model(image, text)
probs = nn.Softmax(axis=-1)(logits_per_image).numpy()

print("Label probs:", probs)  # prints: [[0.9927937 0.00421067 0.00299571]]
```

## API

The CLIP module `clip` provides the following methods:

#### `clip.available_models()`

Returns the names of the available CLIP models.
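
For example, a quick way to see which names can be passed to `clip.load()` by name (the exact list depends on this repo's model registry, so the output is only illustrative):

```python
import clip

# Print the model names that clip.load() accepts by name
print(clip.available_models())
```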

#### `clip.load(name, device, mode, download_root)`

Returns the model and the transform operations needed for the image input, as specified by the argument `name`. It will download the model checkpoint as necessary.

Here's the argument comparison of `clip.load` in OpenAI-CLIP and MindSpore-CLIP (✅ means exactly the same, while ❌ means not supported yet):

| OpenAI-CLIP | MindSpore-CLIP |
| ------------------------------------------------------------ | ------------------------------------------------------------ |
| **name** : str<br />A model name listed by `clip.available_models()`, or the path to a local model checkpoint containing the params_dict. ||
| **device** : Union[str, torch.device]<br />The device to put the loaded model on.<br />default: "cuda" if torch.cuda.is_available() else "cpu" | **device** : str<br />The device to put the loaded model on; must be one of CPU, GPU, Ascend.<br />default: "Ascend" |
| **jit** : bool<br />Whether to load the optimized JIT model or the more hackable non-JIT model (default). ||
| **download_root** : str<br />Path to download the model files to.<br />default: "~/.cache/clip" ||
|| **mode** : int<br />GRAPH_MODE(0) or PYNATIVE_MODE(1).<br />default: 1 |
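
As a minimal usage sketch of the MindSpore-specific `mode` argument, assuming the converted `ViT-B-32.ckpt` from the Checkpoint Transform section:

```python
import clip

# mode=0 selects GRAPH_MODE; mode=1 (the default) selects PYNATIVE_MODE
model, preprocess = clip.load("./ViT-B-32.ckpt", device="Ascend", mode=0)
```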

#### `clip.tokenize(text, context_length, truncate)`

Returns a tensor containing tokenized sequences of the given text input(s), which can be used as input to the model.

Here's the argument comparison of `clip.tokenize` in OpenAI-CLIP and MindSpore-CLIP:

| OpenAI-CLIP | MindSpore-CLIP |
| ------------------------------------------------------------ | -------------- |
| **texts** : Union[str, List[str]]<br />An input string or a list of input strings to tokenize. ||
| **context_length** : int<br />The context length to use; all CLIP models use 77 as the default context length. ||
| **truncate** : bool<br />Whether to truncate the text in case its encoding is longer than the context length.<br />default: False ||
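
A small illustration of these arguments (the shape comment assumes the standard 77-token context length):

```python
import clip

# One row of token ids per input string; inputs whose encoding exceeds
# context_length are cut off because truncate=True
tokens = clip.tokenize(
    ["a diagram", "a dog", "a very long caption about a dog playing in a park"],
    context_length=77,
    truncate=True,
)
print(tokens.shape)  # (3, 77)
```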

#### model

The model returned by `clip.load()` supports the following methods:

#### `model.encode_image(image)`

Given a batch of images (Tensor), returns the image features (Tensor) encoded by the vision portion of the CLIP model.

#### `model.encode_text(text)`

Given a batch of text tokens (Tensor), returns the text features (Tensor) encoded by the language portion of the CLIP model.

#### `model(image, text)`

Given a batch of images (Tensor) and a batch of text tokens (Tensor), returns two Tensors containing the logit scores corresponding to each image and text input. The values are cosine similarities between the corresponding image and text features, times 100.
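
As an illustration of that relationship, here is a small sketch reusing `image_features`, `text_features`, and `logits_per_image` from the example code above (a consistency check, not part of the API):

```python
# Normalize the features and compare cosine similarity * 100 with the logits
img_norm = image_features / image_features.norm(dim=-1, keepdim=True)
txt_norm = text_features / text_features.norm(dim=-1, keepdim=True)
cosine_similarity = img_norm @ txt_norm.T

print(100.0 * cosine_similarity)  # should match logits_per_image
```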

## More Examples

### Zero-Shot Prediction

The code below performs zero-shot prediction using CLIP, as shown in Appendix B of the paper. This example takes an image from the [CIFAR-100 dataset](https://www.cs.toronto.edu/~kriz/cifar.html) (which is also packaged by MindSpore, so we download it through MindSpore in this example) and predicts the most likely labels among the dataset's 100 textual labels.

```python
import clip
from mindspore import ops, nn, Tensor
import mindspore.dataset as ds
from download import download
from PIL import Image

# Load the model
model, preprocess = clip.load("./ViT-B-32.ckpt", device="Ascend")

# Download the dataset
cifar100_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/cifar-100-binary.tar.gz"
download(cifar100_url, "./", kind="tar.gz", replace=True)
cifar100_iter = ds.Cifar100Dataset("cifar-100-binary", usage="test", shuffle=False)
cifar100 = []
for i in cifar100_iter:
    cifar100.append([Image.fromarray(i[0].asnumpy()), int(i[2])])

# Read the 100 fine label names (index -> class name)
index2label = []
with open("./cifar-100-binary/fine_label_names.txt", "r") as f:
    for line in f:
        index2label.append(line.strip("\n"))

# Prepare the inputs
image, class_id = cifar100[3637]
image_input = Tensor(preprocess(image))
text_inputs = ops.cat([clip.tokenize(f"a photo of a {label}") for label in index2label])

# Calculate features
image_features = model.encode_image(image_input)
text_features = model.encode_text(text_inputs)

# Pick the top 5 most similar labels for the image
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = nn.Softmax(axis=-1)(100.0 * image_features @ text_features.T)
values, indices = similarity[0].topk(5)

# Print the result
print("\nTop predictions:\n")
for value, index in zip(values, indices):
    print(f"{index2label[int(index)]:>16s}: {100 * float(value):.2f}%")
```

The output will look like the following (the exact numbers may be slightly different depending on the compute device):

```
Top predictions:

           snake: 65.31%
          turtle: 12.29%
    sweet_pepper: 3.83%
          lizard: 1.88%
       crocodile: 1.75%
```

Note that this example uses the `encode_image()` and `encode_text()` methods that return the encoded features of the given inputs.

### Linear-probe evaluation

The example below uses [scikit-learn](https://scikit-learn.org/) to perform logistic regression on image features.

```python
import clip
from mindspore import ops
import mindspore.dataset as ds
from download import download
import numpy as np
from sklearn.linear_model import LogisticRegression
from tqdm import tqdm

# Load the model
model, preprocess = clip.load("./ViT-B-32.ckpt", device="Ascend")

# Load the dataset
cifar100_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/cifar-100-binary.tar.gz"
download(cifar100_url, "./", kind="tar.gz", replace=True)
cifar100_test_iter = ds.Cifar100Dataset("cifar-100-binary", usage="test", shuffle=False)
cifar100_train_iter = ds.Cifar100Dataset("cifar-100-binary", usage="train", shuffle=False)
cifar100_test_iter = cifar100_test_iter.map(preprocess, input_columns=["image"])
cifar100_test_iter = cifar100_test_iter.batch(100)
cifar100_train_iter = cifar100_train_iter.map(preprocess, input_columns=["image"])
cifar100_train_iter = cifar100_train_iter.batch(100)

def get_features(dataset):
    all_features = []
    all_labels = []
    for images, _, labels in tqdm(dataset):
        features = model.encode_image(images)
        all_features.append(features)
        all_labels.append(labels)

    return ops.cat(all_features).asnumpy(), ops.cat(all_labels).asnumpy()


# Calculate the image features
train_features, train_labels = get_features(cifar100_train_iter)
test_features, test_labels = get_features(cifar100_test_iter)

# Perform logistic regression
classifier = LogisticRegression(random_state=0, C=0.316, max_iter=1000, verbose=1)
classifier.fit(train_features, train_labels)

# Evaluate using the logistic regression classifier
predictions = classifier.predict(test_features)
accuracy = np.mean((test_labels == predictions).astype(float)) * 100.
print(f"Accuracy = {accuracy:.3f}")
```

Note that the `C` value should be determined via a hyperparameter sweep using a validation split.
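
For instance, a minimal sketch of such a sweep, assuming the `train_features`/`train_labels` arrays from the snippet above and holding out part of the training set as a validation split:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

# Hold out 10% of the training data as a validation split
tr_x, val_x, tr_y, val_y = train_test_split(
    train_features, train_labels, test_size=0.1, random_state=0)

best_c, best_acc = None, -1.0
for c in np.logspace(-3, 3, 7):  # coarse sweep over C; refine around the best value
    clf = LogisticRegression(random_state=0, C=c, max_iter=1000)
    clf.fit(tr_x, tr_y)
    acc = clf.score(val_x, val_y)
    if acc > best_acc:
        best_c, best_acc = c, acc

print(f"Best C = {best_c:.3g}, validation accuracy = {100 * best_acc:.2f}%")
```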

## See Also

* [OpenCLIP-MindSpore](https://github.com/mindspore-lab/mindcv/tree/main/examples/openclip/): includes larger and independently trained CLIP models up to ViT-G/14

examples/clip/ckpt_transform.py

Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
import argparse
import os
import sys

import torch

import mindspore as ms

from examples.clip.clip.clip import _MODELS, _download


def parse_args(args):
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pth_path",
        type=str,
        default=None,
        help="Model name or the path of the model's checkpoint file given by OpenAI",
    )
    args = parser.parse_args(args)
    return args


def pytorch_params(pth_file):
    par_dict = torch.load(pth_file, map_location="cpu").state_dict()
    pt_params = []
    for name in par_dict:
        parameter = par_dict[name]
        if "ln_" in name:
            name = name.replace(".weight", ".gamma").replace(".bias", ".beta")
        elif name == "token_embedding.weight":
            name = "token_embedding.embedding_table"
        elif ".bn" in name or ".downsample.1." in name:
            name = name.replace(".weight", ".gamma").replace(".bias", ".beta")
            name = name.replace(".running_mean", ".moving_mean").replace(".running_var", ".moving_variance")
        pt_params.append({"name": name, "data": ms.Tensor(parameter.numpy())})
    return pt_params


def main(args):
    args = parse_args(args)
    if os.path.exists(args.pth_path):
        pt_param = pytorch_params(args.pth_path)
        ms.save_checkpoint(pt_param, args.pth_path.replace(".pt", ".ckpt"))
    elif args.pth_path in _MODELS.keys():
        model_path = _download(_MODELS[args.pth_path], os.path.expanduser("~/"))
        pt_param = pytorch_params(model_path)
        ms.save_checkpoint(pt_param, model_path.replace(".pt", ".ckpt"))
    else:
        raise ValueError(
            f"{args.pth_path} is not a supported checkpoint file or model name. "
            f"Models with available checkpoint file are: {list(_MODELS.keys())}"
        )
    print("Done!")


if __name__ == "__main__":
    main(sys.argv[1:])

examples/clip/clip/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
from .clip import *
(binary file, 1.29 MB, not shown)
