    help="Specify vulkan target triple or rocm/cuda target device.",
)
parser.add_argument("--vulkan_max_allocation", type=str, default="4294967296")
+parser.add_argument('--controlled', dest='controlled', action='store_true', help="Whether or not to use controlled unet (for use with controlnet)")
+parser.add_argument('--no-controlled', dest='controlled', action='store_false', help="Whether or not to use controlled unet (for use with controlnet)")
+parser.set_defaults(controlled=False)
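+# The paired flags above both write to args.controlled; set_defaults keeps the
+# plain (non-ControlNet) UNet export as the default behavior.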


class UnetModel(torch.nn.Module):
-    def __init__(self, hf_model_name, hf_auth_token):
+    def __init__(self, hf_model_name, hf_auth_token, is_controlled):
        super().__init__()
        self.unet = UNet2DConditionModel.from_pretrained(
            hf_model_name,
            subfolder="unet",
            token=hf_auth_token,
        )
        self.guidance_scale = 7.5
+        if is_controlled:
+            self.forward = self.forward_controlled
+        else:
+            self.forward = self.forward_default
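+        # Binding self.forward at construction time means downstream tracing
+        # (jittable/export) only ever sees the signature of the selected path.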

-    def forward(self, sample, timestep, encoder_hidden_states):
+    def forward_default(self, sample, timestep, encoder_hidden_states):
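+        # Classifier-free guidance: run a doubled batch (unconditional + text-
+        # conditioned) through the UNet and blend the two noise predictions.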
        samples = torch.cat([sample] * 2)
        unet_out = self.unet.forward(
            samples, timestep, encoder_hidden_states, return_dict=False
@@ -76,6 +83,65 @@ def forward(self, sample, timestep, encoder_hidden_states):
        )
        return noise_pred

+    def forward_controlled(
+        self,
+        sample,
+        timestep,
+        encoder_hidden_states,
+        control1,
+        control2,
+        control3,
+        control4,
+        control5,
+        control6,
+        control7,
+        control8,
+        control9,
+        control10,
+        control11,
+        control12,
+        control13,
+        scale1,
+        scale2,
+        scale3,
+        scale4,
+        scale5,
+        scale6,
+        scale7,
+        scale8,
+        scale9,
+        scale10,
+        scale11,
+        scale12,
+        scale13,
+    ):
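+        # ControlNet emits 12 down-block residuals plus 1 mid-block residual;
+        # each is weighted by its own conditioning scale before injection.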
+        db_res_samples = tuple(
+            [
+                control1 * scale1,
+                control2 * scale2,
+                control3 * scale3,
+                control4 * scale4,
+                control5 * scale5,
+                control6 * scale6,
+                control7 * scale7,
+                control8 * scale8,
+                control9 * scale9,
+                control10 * scale10,
+                control11 * scale11,
+                control12 * scale12,
+            ]
+        )
+        mb_res_samples = control13 * scale13
+        samples = torch.cat([sample] * 2)
+        unet_out = self.unet.forward(
+            samples, timestep, encoder_hidden_states, down_block_additional_residuals=db_res_samples, mid_block_additional_residual=mb_res_samples, return_dict=False
+        )[0]
+        noise_pred_uncond, noise_pred_text = unet_out.chunk(2)
+        noise_pred = noise_pred_uncond + self.guidance_scale * (
+            noise_pred_text - noise_pred_uncond
+        )
+        return noise_pred
+

def export_unet_model(
    unet_model,
@@ -90,6 +156,7 @@ def export_unet_model(
    device=None,
    target_triple=None,
    max_alloc=None,
+    is_controlled=False,
):
    mapper = {}
    utils.save_external_weights(
@@ -100,7 +167,7 @@ def export_unet_model(
    if hf_model_name == "stabilityai/stable-diffusion-2-1-base":
        encoder_hidden_states_sizes = (2, 77, 1024)

-    sample = (batch_size, unet_model.unet.config.in_channels, height // 8, width // 8)
+    sample = (batch_size, unet_model.unet.config.in_channels, height, width)
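+    # NOTE: height/width are now treated as latent-space dimensions; callers
+    # that previously passed pixel sizes must divide by 8 themselves (the
+    # control tensor shapes below assume the same convention).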

    class CompiledUnet(CompiledModule):
        if external_weights:
@@ -120,8 +187,85 @@ def main(
        ):
            return jittable(unet_model.forward)(sample, timestep, encoder_hidden_states)

+    class CompiledControlledUnet(CompiledModule):
+        if external_weights:
+            params = export_parameters(
+                unet_model, external=True, external_scope="", name_mapper=mapper.get
+            )
+        else:
+            params = export_parameters(unet_model)
+
+        def main(
+            self,
+            sample=AbstractTensor(*sample, dtype=torch.float32),
+            timestep=AbstractTensor(1, dtype=torch.float32),
+            encoder_hidden_states=AbstractTensor(
+                *encoder_hidden_states_sizes, dtype=torch.float32
+            ),
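+            # Control tensor shapes mirror the SD-1.x UNet feature maps: 320 to
+            # 1280 channels as spatial resolution halves from /1 down to /8,
+            # with batch 2 for the classifier-free-guidance doubled batch.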
+            control1=AbstractTensor(2, 320, height, width, dtype=torch.float32),
+            control2=AbstractTensor(2, 320, height, width, dtype=torch.float32),
+            control3=AbstractTensor(2, 320, height, width, dtype=torch.float32),
+            control4=AbstractTensor(2, 320, height // 2, width // 2, dtype=torch.float32),
+            control5=AbstractTensor(2, 640, height // 2, width // 2, dtype=torch.float32),
+            control6=AbstractTensor(2, 640, height // 2, width // 2, dtype=torch.float32),
+            control7=AbstractTensor(2, 640, height // 4, width // 4, dtype=torch.float32),
+            control8=AbstractTensor(2, 1280, height // 4, width // 4, dtype=torch.float32),
+            control9=AbstractTensor(2, 1280, height // 4, width // 4, dtype=torch.float32),
+            control10=AbstractTensor(2, 1280, height // 8, width // 8, dtype=torch.float32),
+            control11=AbstractTensor(2, 1280, height // 8, width // 8, dtype=torch.float32),
+            control12=AbstractTensor(2, 1280, height // 8, width // 8, dtype=torch.float32),
+            control13=AbstractTensor(2, 1280, height // 8, width // 8, dtype=torch.float32),
+            scale1=AbstractTensor(1, dtype=torch.float32),
+            scale2=AbstractTensor(1, dtype=torch.float32),
+            scale3=AbstractTensor(1, dtype=torch.float32),
+            scale4=AbstractTensor(1, dtype=torch.float32),
+            scale5=AbstractTensor(1, dtype=torch.float32),
+            scale6=AbstractTensor(1, dtype=torch.float32),
+            scale7=AbstractTensor(1, dtype=torch.float32),
+            scale8=AbstractTensor(1, dtype=torch.float32),
+            scale9=AbstractTensor(1, dtype=torch.float32),
+            scale10=AbstractTensor(1, dtype=torch.float32),
+            scale11=AbstractTensor(1, dtype=torch.float32),
+            scale12=AbstractTensor(1, dtype=torch.float32),
+            scale13=AbstractTensor(1, dtype=torch.float32),
+        ):
+            return jittable(unet_model.forward)(
+                sample,
+                timestep,
+                encoder_hidden_states,
+                control1,
+                control2,
+                control3,
+                control4,
+                control5,
+                control6,
+                control7,
+                control8,
+                control9,
+                control10,
+                control11,
+                control12,
+                control13,
+                scale1,
+                scale2,
+                scale3,
+                scale4,
+                scale5,
+                scale6,
+                scale7,
+                scale8,
+                scale9,
+                scale10,
+                scale11,
+                scale12,
+                scale13,
+            )
+
    import_to = "INPUT" if compile_to == "linalg" else "IMPORT"
-    inst = CompiledUnet(context=Context(), import_to=import_to)
+    if is_controlled:
+        inst = CompiledControlledUnet(context=Context(), import_to=import_to)
+    else:
+        inst = CompiledUnet(context=Context(), import_to=import_to)

    module_str = str(CompiledModule.get_mlir_module(inst))
    safe_name = utils.create_safe_name(hf_model_name, "-unet")
@@ -134,8 +278,9 @@ def main(
if __name__ == "__main__":
    args = parser.parse_args()
    unet_model = UnetModel(
-        args.hf_model_name,
+        args.hf_model_name if not args.controlled else "CompVis/stable-diffusion-v1-4",
        args.hf_auth_token,
+        args.controlled,
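+        # Note: the controlled path pins the base model to
+        # CompVis/stable-diffusion-v1-4, silently overriding --hf_model_name.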
    )
    mod_str = export_unet_model(
        unet_model,
@@ -150,6 +295,7 @@ def main(
        args.device,
        args.iree_target_triple,
        args.vulkan_max_allocation,
+        args.controlled,
    )
    safe_name = utils.create_safe_name(args.hf_model_name, "-unet")
    with open(f"{safe_name}.mlir", "w+") as f:
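With this change, the ControlNet-ready export is selected via the new flag. A minimal sketch of the invocation (only --controlled/--no-controlled are introduced in this diff; the remaining flags are assumed to be defined earlier in the file):

    python unet.py --controlled        # export the controlled UNet (pins SD-1.4)
    python unet.py --no-controlled     # plain UNet export, unchanged behavior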