 from transformers import CLIPTokenizer
 from iree import runtime as ireert
 import torch
+from PIL import Image
 
 parser = argparse.ArgumentParser()
 
@@ -52,49 +53,125 @@ def run_clip(
 ):
     runner = vmfbRunner(device, vmfb_path, external_weight_path)
 
-    tokenizer = CLIPTokenizer.from_pretrained(
-        hf_model_name,
-        subfolder="tokenizer",
-        token=hf_auth_token,
-    )
-    text_input = tokenizer(
-        prompt,
-        padding="max_length",
-        max_length=tokenizer.model_max_length,
-        truncation=True,
-        return_tensors="pt",
-    )
+    if "google/t5" in hf_model_name:
+        from transformers import T5Tokenizer, T5Model
+
+        tokenizer = T5Tokenizer.from_pretrained(hf_model_name)
+        text_input = tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=tokenizer.model_max_length,
+            truncation=True,
+            return_tensors="pt",
+        )
+    # TODO: Integrate with HFTransformerBuilder
+    else:
+        if "openai" in hf_model_name:
+            from transformers import CLIPProcessor
+            import requests
+
+            url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+            image = Image.open(requests.get(url, stream=True).raw)
+            tokenizer = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
+            text_input = tokenizer(
+                text=prompt,
+                images=image,
+                truncation=True,
+                padding=True,
+                return_tensors="pt",
+            )
+        else:
+            hf_subfolder = "tokenizer"
+
+            tokenizer = CLIPTokenizer.from_pretrained(
+                hf_model_name,
+                subfolder=hf_subfolder,
+                token=hf_auth_token,
+            )
+
+            text_input = tokenizer(
+                prompt,
+                padding="max_length",
+                max_length=tokenizer.model_max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
     example_input = text_input.input_ids
     inp = [ireert.asdevicearray(runner.config.device, example_input)]
 
+    if "google/t5" in hf_model_name:
+        inp += [ireert.asdevicearray(runner.config.device, example_input)]
     results = runner.ctx.modules.compiled_clip["main"](*inp)
     return results
 
 
 def run_torch_clip(hf_model_name, hf_auth_token, prompt):
+    if "google/t5" in hf_model_name:
+        from transformers import T5Tokenizer, T5Model
+
+        tokenizer = T5Tokenizer.from_pretrained(hf_model_name)
+        model = T5Model.from_pretrained(hf_model_name)
+        text_input = tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=tokenizer.model_max_length,
+            truncation=True,
+            return_tensors="pt",
+        )
     # TODO: Integrate with HFTransformerBuilder
-    from transformers import CLIPTextModel
+    else:
+        if hf_model_name == "openai/clip-vit-large-patch14":
+            from transformers import CLIPProcessor
+            import requests
 
-    model = CLIPTextModel.from_pretrained(
-        hf_model_name,
-        subfolder="text_encoder",
-        token=hf_auth_token,
-    )
-    tokenizer = CLIPTokenizer.from_pretrained(
-        hf_model_name,
-        subfolder="tokenizer",
-        token=hf_auth_token,
-    )
-    text_input = tokenizer(
-        prompt,
-        padding="max_length",
-        max_length=tokenizer.model_max_length,
-        truncation=True,
-        return_tensors="pt",
-    )
+            url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+            image = Image.open(requests.get(url, stream=True).raw)
+
+            tokenizer = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
+            hf_subfolder = ""  # CLIPProcessor does not have a subfolder
+            from transformers import CLIPTextModel
+
+            model = CLIPTextModel.from_pretrained(
+                hf_model_name,
+                subfolder=hf_subfolder,
+                token=hf_auth_token,
+            )
+            text_input = tokenizer(
+                text=prompt,
+                images=image,
+                truncation=True,
+                padding=True,
+                return_tensors="pt",
+            )
+        else:
+            hf_subfolder = "text_encoder"
+
+            tokenizer = CLIPTokenizer.from_pretrained(
+                hf_model_name,
+                subfolder="tokenizer",
+                token=hf_auth_token,
+            )
+
+            from transformers import CLIPTextModel
+
+            model = CLIPTextModel.from_pretrained(
+                hf_model_name,
+                subfolder=hf_subfolder,
+                token=hf_auth_token,
+            )
+            text_input = tokenizer(
+                prompt,
+                padding="max_length",
+                max_length=tokenizer.model_max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
     example_input = text_input.input_ids
 
-    results = model.forward(example_input)[0]
+    if "google/t5" in hf_model_name:
+        results = model.forward(example_input, decoder_input_ids=example_input)[0]
+    else:
+        results = model.forward(example_input)[0]
     np_torch_output = results.detach().cpu().numpy()
     return np_torch_output
 
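Note on the T5 branch: T5Model is an encoder-decoder, so its forward() needs decoder_input_ids in addition to input_ids. That is why run_torch_clip passes the tokenized prompt twice and why run_clip feeds the compiled module a second device array with the same ids. A minimal torch-side sketch of that behavior, assuming a google/t5-v1_1-small checkpoint chosen only for illustration (any model name containing "google/t5" takes this branch):

import torch
from transformers import T5Model, T5Tokenizer

checkpoint = "google/t5-v1_1-small"  # assumed checkpoint, for illustration only
tokenizer = T5Tokenizer.from_pretrained(checkpoint)
model = T5Model.from_pretrained(checkpoint)

ids = tokenizer(
    "a photograph of an astronaut",
    padding="max_length",
    max_length=tokenizer.model_max_length,
    truncation=True,
    return_tensors="pt",
).input_ids

with torch.no_grad():
    # Encoder-decoder forward: the same ids serve as both encoder and decoder input,
    # mirroring the duplicated device array passed to the compiled module above.
    hidden = model(input_ids=ids, decoder_input_ids=ids)[0]
print(hidden.shape)  # (1, tokenizer.model_max_length, d_model)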
|
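For a quick numerical sanity check of the two paths against each other, something along these lines should work; the keyword names mirror the locals used inside run_clip, while the device string, artifact paths, and the handling of the IREE return value are assumptions rather than part of this diff:

import numpy as np

prompt = "a photograph of an astronaut"
model_id = "openai/clip-vit-large-patch14"

torch_out = run_torch_clip(model_id, None, prompt)
turbine_out = run_clip(
    prompt=prompt,
    hf_model_name=model_id,
    hf_auth_token=None,
    device="local-task",                       # assumed IREE driver name
    vmfb_path="clip.vmfb",                     # assumed compiled artifact
    external_weight_path="clip.safetensors",   # assumed external weights file
)

# The compiled module may return a single DeviceArray or a tuple of them;
# to_host() copies the (first) result back to numpy for comparison.
if isinstance(turbine_out, (tuple, list)):
    iree_np = turbine_out[0].to_host()
else:
    iree_np = turbine_out.to_host()
print(np.allclose(torch_out, iree_np, atol=1e-4))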