
Commit f2c025c

Clip_vit_large14 and t5 models (#535)
1 parent 6e3adb3 commit f2c025c

File tree

3 files changed: +341 −145 lines changed

models/turbine_models/custom_models/sd_inference/clip.py

Lines changed: 65 additions & 29 deletions
@@ -5,17 +5,13 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 import os
-import sys
+import re
 
-from iree import runtime as ireert
-import iree.compiler as ireec
 from iree.compiler.ir import Context
-import numpy as np
 from shark_turbine.aot import *
 from turbine_models.custom_models.sd_inference import utils
 import torch
-import torch._dynamo as dynamo
-from transformers import CLIPTextModel, CLIPTokenizer
+from transformers import CLIPTextModel, CLIPTokenizer, CLIPProcessor
 from turbine_models.turbine_tank import turbine_tank
 
 import argparse
@@ -60,37 +56,77 @@ def export_clip_model(
     max_alloc=None,
     upload_ir=False,
 ):
-    # Load the tokenizer and text encoder to tokenize and encode the text.
-    tokenizer = CLIPTokenizer.from_pretrained(
-        hf_model_name,
-        subfolder="tokenizer",
-        token=hf_auth_token,
-    )
+    input_len = 77
+    if "google/t5" in hf_model_name:
+        from transformers import T5Tokenizer, T5Model
 
-    text_encoder_model = CLIPTextModel.from_pretrained(
-        hf_model_name,
-        subfolder="text_encoder",
-        token=hf_auth_token,
-    )
+        tokenizer = T5Tokenizer.from_pretrained(hf_model_name)
+        text_encoder_model = T5Model.from_pretrained(hf_model_name)
+        input_len = 512
+
+    else:
+        # TODO: Add better filtering mechanism for things that require CLIPProcessor
+        if "openai" in hf_model_name:
+            tokenizer = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
+            hf_subfolder = ""  # CLIPProcessor does not have a subfolder
+            input_len = 10
+        else:
+            # Load the tokenizer and text encoder to tokenize and encode the text.
+            tokenizer = CLIPTokenizer.from_pretrained(
+                hf_model_name,
+                subfolder="tokenizer",
+                token=hf_auth_token,
+            )
+            hf_subfolder = "text_encoder"
+
+        text_encoder_model = CLIPTextModel.from_pretrained(
+            hf_model_name,
+            subfolder=hf_subfolder,
+            token=hf_auth_token,
+        )
 
     mapper = {}
     utils.save_external_weights(
         mapper, text_encoder_model, external_weights, external_weight_path
     )
 
-    class CompiledClip(CompiledModule):
-        if external_weights:
-            params = export_parameters(
-                text_encoder_model,
-                external=True,
-                external_scope="",
-                name_mapper=mapper.get,
-            )
-        else:
-            params = export_parameters(text_encoder_model)
+    if "google/t5" in hf_model_name:
+
+        class CompiledClip(CompiledModule):
+            if external_weights:
+                params = export_parameters(
+                    text_encoder_model,
+                    external=True,
+                    external_scope="",
+                    name_mapper=mapper.get,
+                )
+            else:
+                params = export_parameters(text_encoder_model)
+
+            def main(
+                self,
+                inp=AbstractTensor(1, input_len, dtype=torch.int64),
+                decoder_input_ids=AbstractTensor(1, input_len, dtype=torch.int64),
+            ):
+                return jittable(text_encoder_model.forward)(
+                    input_ids=inp, decoder_input_ids=decoder_input_ids
+                )
+
+    else:
+
+        class CompiledClip(CompiledModule):
+            if external_weights:
+                params = export_parameters(
+                    text_encoder_model,
+                    external=True,
+                    external_scope="",
+                    name_mapper=mapper.get,
+                )
+            else:
+                params = export_parameters(text_encoder_model)
 
-        def main(self, inp=AbstractTensor(1, 77, dtype=torch.int64)):
-            return jittable(text_encoder_model.forward)(inp)
+            def main(self, inp=AbstractTensor(1, input_len, dtype=torch.int64)):
+                return jittable(text_encoder_model.forward)(input_ids=inp)
 
     import_to = "INPUT" if compile_to == "linalg" else "IMPORT"
     inst = CompiledClip(context=Context(), import_to=import_to)

models/turbine_models/custom_models/sd_inference/clip_runner.py

Lines changed: 108 additions & 31 deletions
@@ -3,6 +3,7 @@
 from transformers import CLIPTokenizer
 from iree import runtime as ireert
 import torch
+from PIL import Image
 
 parser = argparse.ArgumentParser()
 
@@ -52,49 +53,125 @@ def run_clip(
 ):
     runner = vmfbRunner(device, vmfb_path, external_weight_path)
 
-    tokenizer = CLIPTokenizer.from_pretrained(
-        hf_model_name,
-        subfolder="tokenizer",
-        token=hf_auth_token,
-    )
-    text_input = tokenizer(
-        prompt,
-        padding="max_length",
-        max_length=tokenizer.model_max_length,
-        truncation=True,
-        return_tensors="pt",
-    )
+    if "google/t5" in hf_model_name:
+        from transformers import T5Tokenizer, T5Model
+
+        tokenizer = T5Tokenizer.from_pretrained(hf_model_name)
+        text_input = tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=tokenizer.model_max_length,
+            truncation=True,
+            return_tensors="pt",
+        )
+    # TODO: Integrate with HFTransformerBuilder
+    else:
+        if "openai" in hf_model_name:
+            from transformers import CLIPProcessor
+            import requests
+
+            url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+            image = Image.open(requests.get(url, stream=True).raw)
+            tokenizer = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
+            text_input = tokenizer(
+                text=prompt,
+                images=image,
+                truncation=True,
+                padding=True,
+                return_tensors="pt",
+            )
+        else:
+            hf_subfolder = "tokenizer"
+
+            tokenizer = CLIPTokenizer.from_pretrained(
+                hf_model_name,
+                subfolder=hf_subfolder,
+                token=hf_auth_token,
+            )
+
+            text_input = tokenizer(
+                prompt,
+                padding="max_length",
+                max_length=tokenizer.model_max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
     example_input = text_input.input_ids
     inp = [ireert.asdevicearray(runner.config.device, example_input)]
 
+    if "google/t5" in hf_model_name:
+        inp += [ireert.asdevicearray(runner.config.device, example_input)]
     results = runner.ctx.modules.compiled_clip["main"](*inp)
     return results
 
 
 def run_torch_clip(hf_model_name, hf_auth_token, prompt):
+    if "google/t5" in hf_model_name:
+        from transformers import T5Tokenizer, T5Model
+
+        tokenizer = T5Tokenizer.from_pretrained(hf_model_name)
+        model = T5Model.from_pretrained(hf_model_name)
+        text_input = tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=tokenizer.model_max_length,
+            truncation=True,
+            return_tensors="pt",
+        )
     # TODO: Integrate with HFTransformerBuilder
-    from transformers import CLIPTextModel
+    else:
+        if hf_model_name == "openai/clip-vit-large-patch14":
+            from transformers import CLIPProcessor
+            import requests
 
-    model = CLIPTextModel.from_pretrained(
-        hf_model_name,
-        subfolder="text_encoder",
-        token=hf_auth_token,
-    )
-    tokenizer = CLIPTokenizer.from_pretrained(
-        hf_model_name,
-        subfolder="tokenizer",
-        token=hf_auth_token,
-    )
-    text_input = tokenizer(
-        prompt,
-        padding="max_length",
-        max_length=tokenizer.model_max_length,
-        truncation=True,
-        return_tensors="pt",
-    )
+            url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+            image = Image.open(requests.get(url, stream=True).raw)
+
+            tokenizer = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
+            hf_subfolder = ""  # CLIPProcessor does not have a subfolder
+            from transformers import CLIPTextModel
+
+            model = CLIPTextModel.from_pretrained(
+                hf_model_name,
+                subfolder=hf_subfolder,
+                token=hf_auth_token,
+            )
+            text_input = tokenizer(
+                text=prompt,
+                images=image,
+                truncation=True,
+                padding=True,
+                return_tensors="pt",
+            )
+        else:
+            hf_subfolder = "text_encoder"
+
+            tokenizer = CLIPTokenizer.from_pretrained(
+                hf_model_name,
+                subfolder="tokenizer",
+                token=hf_auth_token,
+            )
+
+            from transformers import CLIPTextModel
+
+            model = CLIPTextModel.from_pretrained(
+                hf_model_name,
+                subfolder=hf_subfolder,
+                token=hf_auth_token,
+            )
+            text_input = tokenizer(
+                prompt,
+                padding="max_length",
+                max_length=tokenizer.model_max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
     example_input = text_input.input_ids
 
-    results = model.forward(example_input)[0]
+    if "google/t5" in hf_model_name:
+        results = model.forward(example_input, decoder_input_ids=example_input)[0]
+    else:
+        results = model.forward(example_input)[0]
     np_torch_output = results.detach().cpu().numpy()
    return np_torch_output
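Because T5 is an encoder-decoder model, both the exported module and the torch reference above take decoder_input_ids in addition to the encoder ids, which is why run_clip pushes the same device array onto the input list twice. A small torch-only smoke test of that convention, separate from the commit (the checkpoint id below is only an example of a name containing "google/t5"; it assumes transformers and sentencepiece are installed):

import torch
from transformers import T5Model, T5Tokenizer

model_id = "google/t5-v1_1-small"  # example only; any "google/t5*" name takes this path

tokenizer = T5Tokenizer.from_pretrained(model_id)
model = T5Model.from_pretrained(model_id)

ids = tokenizer(
    "a photograph of an astronaut riding a horse",
    padding="max_length",
    max_length=512,  # the commit pads to tokenizer.model_max_length and exports input_len = 512
    truncation=True,
    return_tensors="pt",
).input_ids

with torch.no_grad():
    # Reuse the encoder ids as decoder_input_ids, as run_torch_clip does for T5.
    results = model.forward(ids, decoder_input_ids=ids)[0]

print(results.shape)  # (1, 512, d_model) decoder hidden states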
