Skip to content

Commit b078e6e

Browse files
committed
Add controlnet and controlled unet
1 parent c1dc94c commit b078e6e

File tree

2 files changed

+329
-5
lines changed

2 files changed

+329
-5
lines changed
Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
# Copyright 2023 Nod Labs, Inc
2+
#
3+
# Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
# See https://llvm.org/LICENSE.txt for license information.
5+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
import os
8+
import sys
9+
10+
from iree import runtime as ireert
11+
from iree.compiler.ir import Context
12+
import numpy as np
13+
from shark_turbine.aot import *
14+
from turbine_models.custom_models.sd_inference import utils
15+
import torch
16+
import torch._dynamo as dynamo
17+
from diffusers import ControlNetModel as CNetModel
18+
19+
import safetensors
20+
import argparse
21+
import re
22+
23+
parser = argparse.ArgumentParser()
24+
parser.add_argument(
25+
"--hf_auth_token", type=str, help="The Hugging Face auth token, required"
26+
)
27+
parser.add_argument(
28+
"--hf_model_name",
29+
type=str,
30+
help="HF model name",
31+
default="lllyasviel/control_v11p_sd15_canny",
32+
)
33+
parser.add_argument(
34+
"--batch_size", type=int, default=1, help="Batch size for inference"
35+
)
36+
parser.add_argument(
37+
"--height", type=int, default=512, help="Height of Stable Diffusion"
38+
)
39+
parser.add_argument("--width", type=int, default=512, help="Width of Stable Diffusion")
40+
parser.add_argument("--compile_to", type=str, help="torch, linalg, vmfb")
41+
parser.add_argument("--external_weight_path", type=str, default="")
42+
parser.add_argument(
43+
"--external_weights",
44+
type=str,
45+
default=None,
46+
help="saves ir without global weights for size and readability, options [safetensors]",
47+
)
48+
parser.add_argument("--device", type=str, default="cpu", help="cpu, cuda, vulkan, rocm")
49+
# TODO: Bring in detection for target triple
50+
parser.add_argument(
51+
"--iree_target_triple",
52+
type=str,
53+
default="",
54+
help="Specify vulkan target triple or rocm/cuda target device.",
55+
)
56+
parser.add_argument("--vulkan_max_allocation", type=str, default="4294967296")
57+
58+
59+
class ControlNetModel(torch.nn.Module):
60+
def __init__(
61+
self, model_id="lllyasviel/control_v11p_sd15_canny", low_cpu_mem_usage=False
62+
):
63+
super().__init__()
64+
self.cnet = CNetModel.from_pretrained(
65+
model_id,
66+
low_cpu_mem_usage=low_cpu_mem_usage,
67+
)
68+
self.in_channels = self.cnet.config.in_channels
69+
self.train(False)
70+
71+
def forward(
72+
self,
73+
latent,
74+
timestep,
75+
text_embedding,
76+
stencil_image_input,
77+
):
78+
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
79+
# TODO: guidance NOT NEEDED change in `get_input_info` later
80+
latents = torch.cat([latent] * 2) # needs to be same as controlledUNET latents
81+
stencil_image = torch.cat(
82+
[stencil_image_input] * 2
83+
) # needs to be same as controlledUNET latents
84+
(
85+
down_block_res_samples,
86+
mid_block_res_sample,
87+
) = self.cnet.forward(
88+
latents,
89+
timestep,
90+
encoder_hidden_states=text_embedding,
91+
controlnet_cond=stencil_image,
92+
return_dict=False,
93+
)
94+
return tuple(list(down_block_res_samples) + [mid_block_res_sample])
95+
96+
97+
def export_controlnet_model(
98+
controlnet_model,
99+
hf_model_name,
100+
batch_size,
101+
height,
102+
width,
103+
hf_auth_token=None,
104+
compile_to="torch",
105+
external_weights=None,
106+
external_weight_path=None,
107+
device=None,
108+
target_triple=None,
109+
max_alloc=None,
110+
):
111+
mapper = {}
112+
utils.save_external_weights(
113+
mapper, controlnet_model, external_weights, external_weight_path
114+
)
115+
116+
class CompiledControlnet(CompiledModule):
117+
if external_weights:
118+
params = export_parameters(
119+
controlnet_model,
120+
external=True,
121+
external_scope="",
122+
name_mapper=mapper.get,
123+
)
124+
else:
125+
params = export_parameters(controlnet_model)
126+
127+
def main(
128+
self,
129+
latent=AbstractTensor(1, 4, 512, 512, dtype=torch.float32),
130+
timestep=AbstractTensor(1, dtype=torch.float32),
131+
text_embedding=AbstractTensor(2, 72, 768, dtype=torch.float32),
132+
stencil_image_input=AbstractTensor(1, 3, 4096, 4096, dtype=torch.float32),
133+
):
134+
return jittable(controlnet_model.forward)(
135+
latent,
136+
timestep,
137+
text_embedding,
138+
stencil_image_input,
139+
)
140+
141+
import_to = "INPUT" if compile_to == "linalg" else "IMPORT"
142+
inst = CompiledControlnet(context=Context(), import_to=import_to)
143+
144+
module_str = str(CompiledModule.get_mlir_module(inst))
145+
safe_name = hf_model_name.split("/")[-1].strip()
146+
safe_name = re.sub("-", "_", safe_name)
147+
if compile_to != "vmfb":
148+
return module_str
149+
else:
150+
utils.compile_to_vmfb(module_str, device, target_triple, max_alloc, safe_name)
151+
152+
153+
if __name__ == "__main__":
154+
args = parser.parse_args()
155+
controlnet_model = ControlNetModel(
156+
args.hf_model_name,
157+
)
158+
mod_str = export_controlnet_model(
159+
controlnet_model,
160+
args.hf_model_name,
161+
args.batch_size,
162+
args.height,
163+
args.width,
164+
args.hf_auth_token,
165+
args.compile_to,
166+
args.external_weights,
167+
args.external_weight_path,
168+
args.device,
169+
args.iree_target_triple,
170+
args.vulkan_max_allocation,
171+
)
172+
173+
if mod_str is None:
174+
safe_name = args.hf_model_name.split("/")[-1].strip()
175+
safe_name = re.sub("-", "_", safe_name)
176+
with open(f"{safe_name}.mlir", "w+") as f:
177+
f.write(mod_str)
178+
print("Saved to", safe_name + ".mlir")

python/turbine_models/custom_models/sd_inference/unet.py

Lines changed: 151 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,19 +53,26 @@
5353
help="Specify vulkan target triple or rocm/cuda target device.",
5454
)
5555
parser.add_argument("--vulkan_max_allocation", type=str, default="4294967296")
56+
parser.add_argument('--controlled', dest='controlled', action='store_true', help="Whether or not to use controlled unet (for use with controlnet)")
57+
parser.add_argument('--no-controlled', dest='controlled', action='store_false', help="Whether or not to use controlled unet (for use with controlnet)")
58+
parser.set_defaults(controlled=False)
5659

5760

5861
class UnetModel(torch.nn.Module):
59-
def __init__(self, hf_model_name, hf_auth_token):
62+
def __init__(self, hf_model_name, hf_auth_token, is_controlled):
6063
super().__init__()
6164
self.unet = UNet2DConditionModel.from_pretrained(
6265
hf_model_name,
6366
subfolder="unet",
6467
token=hf_auth_token,
6568
)
6669
self.guidance_scale = 7.5
70+
if is_controlled:
71+
self.forward = self.forward_controlled
72+
else:
73+
self.forward = self.forward_default
6774

68-
def forward(self, sample, timestep, encoder_hidden_states):
75+
def forward_default(self, sample, timestep, encoder_hidden_states):
6976
samples = torch.cat([sample] * 2)
7077
unet_out = self.unet.forward(
7178
samples, timestep, encoder_hidden_states, return_dict=False
@@ -76,6 +83,65 @@ def forward(self, sample, timestep, encoder_hidden_states):
7683
)
7784
return noise_pred
7885

86+
def forward_controlled(
87+
self,
88+
sample,
89+
timestep,
90+
encoder_hidden_states,
91+
control1,
92+
control2,
93+
control3,
94+
control4,
95+
control5,
96+
control6,
97+
control7,
98+
control8,
99+
control9,
100+
control10,
101+
control11,
102+
control12,
103+
control13,
104+
scale1,
105+
scale2,
106+
scale3,
107+
scale4,
108+
scale5,
109+
scale6,
110+
scale7,
111+
scale8,
112+
scale9,
113+
scale10,
114+
scale11,
115+
scale12,
116+
scale13,
117+
):
118+
db_res_samples = tuple(
119+
[
120+
control1 * scale1,
121+
control2 * scale2,
122+
control3 * scale3,
123+
control4 * scale4,
124+
control5 * scale5,
125+
control6 * scale6,
126+
control7 * scale7,
127+
control8 * scale8,
128+
control9 * scale9,
129+
control10 * scale10,
130+
control11 * scale11,
131+
control12 * scale12,
132+
]
133+
)
134+
mb_res_samples = control13 * scale13
135+
samples = torch.cat([sample] * 2)
136+
unet_out = self.unet.forward(
137+
samples, timestep, encoder_hidden_states, down_block_additional_residuals=db_res_samples, mid_block_additional_residual=mb_res_samples, return_dict=False
138+
)[0]
139+
noise_pred_uncond, noise_pred_text = unet_out.chunk(2)
140+
noise_pred = noise_pred_uncond + self.guidance_scale * (
141+
noise_pred_text - noise_pred_uncond
142+
)
143+
return noise_pred
144+
79145

80146
def export_unet_model(
81147
unet_model,
@@ -90,6 +156,7 @@ def export_unet_model(
90156
device=None,
91157
target_triple=None,
92158
max_alloc=None,
159+
is_controlled=False,
93160
):
94161
mapper = {}
95162
utils.save_external_weights(
@@ -100,7 +167,7 @@ def export_unet_model(
100167
if hf_model_name == "stabilityai/stable-diffusion-2-1-base":
101168
encoder_hidden_states_sizes = (2, 77, 1024)
102169

103-
sample = (batch_size, unet_model.unet.config.in_channels, height // 8, width // 8)
170+
sample = (batch_size, unet_model.unet.config.in_channels, height, width)
104171

105172
class CompiledUnet(CompiledModule):
106173
if external_weights:
@@ -120,8 +187,85 @@ def main(
120187
):
121188
return jittable(unet_model.forward)(sample, timestep, encoder_hidden_states)
122189

190+
class CompiledControlledUnet(CompiledModule):
191+
if external_weights:
192+
params = export_parameters(
193+
unet_model, external=True, external_scope="", name_mapper=mapper.get
194+
)
195+
else:
196+
params = export_parameters(unet_model)
197+
198+
def main(
199+
self,
200+
sample=AbstractTensor(*sample, dtype=torch.float32),
201+
timestep=AbstractTensor(1, dtype=torch.float32),
202+
encoder_hidden_states=AbstractTensor(
203+
*encoder_hidden_states_sizes, dtype=torch.float32
204+
),
205+
control1=AbstractTensor(2, 320, height, width, dtype=torch.float32),
206+
control2=AbstractTensor(2, 320, height, width, dtype=torch.float32),
207+
control3=AbstractTensor(2, 320, height, width, dtype=torch.float32),
208+
control4=AbstractTensor(2, 320, height//2, width//2, dtype=torch.float32),
209+
control5=AbstractTensor(2, 640, height//2, width//2, dtype=torch.float32),
210+
control6=AbstractTensor(2, 640, height//2, width//2, dtype=torch.float32),
211+
control7=AbstractTensor(2, 640, height//4, width//4, dtype=torch.float32),
212+
control8=AbstractTensor(2, 1280, height//4, width//4, dtype=torch.float32),
213+
control9=AbstractTensor(2, 1280, height//4, width//4, dtype=torch.float32),
214+
control10=AbstractTensor(2, 1280, height//8, width//8, dtype=torch.float32),
215+
control11=AbstractTensor(2, 1280, height//8, width//8, dtype=torch.float32),
216+
control12=AbstractTensor(2, 1280, height//8, width//8, dtype=torch.float32),
217+
control13=AbstractTensor(2, 1280, height//8, width//8, dtype=torch.float32),
218+
scale1=AbstractTensor(1, dtype=torch.float32),
219+
scale2=AbstractTensor(1, dtype=torch.float32),
220+
scale3=AbstractTensor(1, dtype=torch.float32),
221+
scale4=AbstractTensor(1, dtype=torch.float32),
222+
scale5=AbstractTensor(1, dtype=torch.float32),
223+
scale6=AbstractTensor(1, dtype=torch.float32),
224+
scale7=AbstractTensor(1, dtype=torch.float32),
225+
scale8=AbstractTensor(1, dtype=torch.float32),
226+
scale9=AbstractTensor(1, dtype=torch.float32),
227+
scale10=AbstractTensor(1, dtype=torch.float32),
228+
scale11=AbstractTensor(1, dtype=torch.float32),
229+
scale12=AbstractTensor(1, dtype=torch.float32),
230+
scale13=AbstractTensor(1, dtype=torch.float32),
231+
):
232+
return jittable(unet_model.forward)(
233+
sample,
234+
timestep,
235+
encoder_hidden_states,
236+
control1,
237+
control2,
238+
control3,
239+
control4,
240+
control5,
241+
control6,
242+
control7,
243+
control8,
244+
control9,
245+
control10,
246+
control11,
247+
control12,
248+
control13,
249+
scale1,
250+
scale2,
251+
scale3,
252+
scale4,
253+
scale5,
254+
scale6,
255+
scale7,
256+
scale8,
257+
scale9,
258+
scale10,
259+
scale11,
260+
scale12,
261+
scale13,
262+
)
263+
123264
import_to = "INPUT" if compile_to == "linalg" else "IMPORT"
124-
inst = CompiledUnet(context=Context(), import_to=import_to)
265+
if is_controlled:
266+
inst = CompiledControlledUnet(context=Context(), import_to=import_to)
267+
else:
268+
inst = CompiledUnet(context=Context(), import_to=import_to)
125269

126270
module_str = str(CompiledModule.get_mlir_module(inst))
127271
safe_name = utils.create_safe_name(hf_model_name, "-unet")
@@ -134,8 +278,9 @@ def main(
134278
if __name__ == "__main__":
135279
args = parser.parse_args()
136280
unet_model = UnetModel(
137-
args.hf_model_name,
281+
args.hf_model_name if not args.controlled else "CompVis/stable-diffusion-v1-4",
138282
args.hf_auth_token,
283+
args.controlled,
139284
)
140285
mod_str = export_unet_model(
141286
unet_model,
@@ -150,6 +295,7 @@ def main(
150295
args.device,
151296
args.iree_target_triple,
152297
args.vulkan_max_allocation,
298+
args.controlled,
153299
)
154300
safe_name = utils.create_safe_name(args.hf_model_name, "-unet")
155301
with open(f"{safe_name}.mlir", "w+") as f:

0 commit comments

Comments
 (0)