Skip to content

Commit 36d104c

Browse files
committed
Add controlnet and controlled unet
1 parent c1dc94c commit 36d104c

File tree

2 files changed

+364
-5
lines changed

2 files changed

+364
-5
lines changed
Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
# Copyright 2023 Nod Labs, Inc
2+
#
3+
# Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
# See https://llvm.org/LICENSE.txt for license information.
5+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
import os
8+
import sys
9+
10+
from iree import runtime as ireert
11+
from iree.compiler.ir import Context
12+
import numpy as np
13+
from shark_turbine.aot import *
14+
from turbine_models.custom_models.sd_inference import utils
15+
import torch
16+
import torch._dynamo as dynamo
17+
from diffusers import ControlNetModel as CNetModel
18+
19+
import safetensors
20+
import argparse
21+
import re
22+
23+
# Command-line options for the ControlNet export script.
# Each entry pairs a flag with the keyword arguments passed to
# ``ArgumentParser.add_argument``; options are registered in the order listed.
_CLI_OPTIONS = (
    (
        "--hf_auth_token",
        {"type": str, "help": "The Hugging Face auth token, required"},
    ),
    (
        "--hf_model_name",
        {
            "type": str,
            "help": "HF model name",
            "default": "lllyasviel/control_v11p_sd15_canny",
        },
    ),
    (
        "--batch_size",
        {"type": int, "default": 1, "help": "Batch size for inference"},
    ),
    (
        "--height",
        {"type": int, "default": 512, "help": "Height of Stable Diffusion"},
    ),
    (
        "--width",
        {"type": int, "default": 512, "help": "Width of Stable Diffusion"},
    ),
    ("--compile_to", {"type": str, "help": "torch, linalg, vmfb"}),
    ("--external_weight_path", {"type": str, "default": ""}),
    (
        "--external_weights",
        {
            "type": str,
            "default": None,
            "help": "saves ir without global weights for size and readability, options [safetensors]",
        },
    ),
    (
        "--device",
        {"type": str, "default": "cpu", "help": "cpu, cuda, vulkan, rocm"},
    ),
    # TODO: Bring in detection for target triple
    (
        "--iree_target_triple",
        {
            "type": str,
            "default": "",
            "help": "Specify vulkan target triple or rocm/cuda target device.",
        },
    ),
    # Kept as a string; it is forwarded unchanged to the export function.
    ("--vulkan_max_allocation", {"type": str, "default": "4294967296"}),
)

parser = argparse.ArgumentParser()
for _flag, _settings in _CLI_OPTIONS:
    parser.add_argument(_flag, **_settings)
57+
58+
59+
class ControlNetModel(torch.nn.Module):
    """Thin ``torch.nn.Module`` wrapper around a pretrained diffusers ControlNet.

    The wrapped network is loaded with ``CNetModel.from_pretrained`` and the
    wrapper is switched to inference mode as soon as it is constructed.
    """

    def __init__(
        self, model_id="lllyasviel/control_v11p_sd15_canny", low_cpu_mem_usage=False
    ):
        """Load the pretrained ControlNet identified by ``model_id``.

        Args:
            model_id: Identifier handed to ``CNetModel.from_pretrained``.
            low_cpu_mem_usage: Forwarded unchanged to ``from_pretrained``.
        """
        super().__init__()
        pretrained = CNetModel.from_pretrained(
            model_id, low_cpu_mem_usage=low_cpu_mem_usage
        )
        self.cnet = pretrained
        self.in_channels = pretrained.config.in_channels
        # Inference only: turn off training-mode behaviour.
        self.eval()

    def forward(self, latent, timestep, text_embedding, stencil_image_input):
        """Run the ControlNet and return its residuals as one flat tuple.

        ``latent`` and ``stencil_image_input`` are each concatenated with
        themselves (along the first dimension) before being passed on.

        Returns:
            A tuple holding every down-block residual sample followed by the
            mid-block residual sample as the last element.
        """
        # Expand the latents if we are doing classifier-free guidance to avoid
        # doing two forward passes.
        # TODO: guidance NOT NEEDED change in `get_input_info` later
        # Both doubled inputs need to be the same as the controlledUNET latents.
        doubled_latent = torch.cat((latent, latent))
        doubled_stencil = torch.cat((stencil_image_input, stencil_image_input))

        cnet_outputs = self.cnet.forward(
            doubled_latent,
            timestep,
            encoder_hidden_states=text_embedding,
            controlnet_cond=doubled_stencil,
            return_dict=False,
        )
        down_samples, mid_sample = cnet_outputs
        return (*down_samples, mid_sample)
95+
96+
97+
def export_controlnet_model(
    controlnet_model,
    hf_model_name,
    batch_size,
    height,
    width,
    hf_auth_token=None,
    compile_to="torch",
    external_weights=None,
    external_weight_path=None,
    device=None,
    target_triple=None,
    max_alloc=None,
):
    """Export ``controlnet_model`` to MLIR text, or compile it to a vmfb.

    Args:
        controlnet_model: Module whose ``forward`` is traced and whose
            parameters are exported.
        hf_model_name: Model identifier; the part after the last ``/`` (with
            ``-`` replaced by ``_``) is used as the name passed to the vmfb
            compile step.
        batch_size, height, width, hf_auth_token: Accepted but not read by
            this function.
            NOTE(review): the input shapes in ``main`` are hard-coded and do
            not follow these arguments; confirm whether that is intended.
        compile_to: ``"linalg"`` selects ``import_to="INPUT"``; any other value
            selects ``"IMPORT"``. ``"vmfb"`` additionally triggers compilation.
        external_weights: When truthy, parameters are exported as external.
        external_weight_path: Forwarded to ``utils.save_external_weights``.
        device, target_triple, max_alloc: Forwarded to
            ``utils.compile_to_vmfb``.

    Returns:
        The MLIR module as a string when ``compile_to`` is not ``"vmfb"``;
        otherwise ``None`` after ``utils.compile_to_vmfb`` has been called.
    """
    weight_names = {}
    utils.save_external_weights(
        weight_names, controlnet_model, external_weights, external_weight_path
    )

    class CompiledControlnet(CompiledModule):
        params = (
            export_parameters(
                controlnet_model,
                external=True,
                external_scope="",
                name_mapper=weight_names.get,
            )
            if external_weights
            else export_parameters(controlnet_model)
        )

        def main(
            self,
            latent=AbstractTensor(1, 4, 512, 512, dtype=torch.float32),
            timestep=AbstractTensor(1, dtype=torch.float32),
            text_embedding=AbstractTensor(2, 72, 768, dtype=torch.float32),
            stencil_image_input=AbstractTensor(1, 3, 4096, 4096, dtype=torch.float32),
        ):
            traced_forward = jittable(controlnet_model.forward)
            return traced_forward(
                latent,
                timestep,
                text_embedding,
                stencil_image_input,
            )

    if compile_to == "linalg":
        import_to = "INPUT"
    else:
        import_to = "IMPORT"
    compiled = CompiledControlnet(context=Context(), import_to=import_to)
    mlir_text = str(CompiledModule.get_mlir_module(compiled))

    # Derived before branching, exactly as the caller-visible behaviour
    # requires: a malformed ``hf_model_name`` fails on every path.
    safe_name = re.sub("-", "_", hf_model_name.split("/")[-1].strip())

    if compile_to == "vmfb":
        utils.compile_to_vmfb(mlir_text, device, target_triple, max_alloc, safe_name)
        return None
    return mlir_text
151+
152+
153+
if __name__ == "__main__":
    # Script entry point: load the ControlNet named on the command line,
    # export it, and save the MLIR text when the export produced any.
    args = parser.parse_args()
    controlnet_model = ControlNetModel(
        args.hf_model_name,
    )
    mod_str = export_controlnet_model(
        controlnet_model,
        args.hf_model_name,
        args.batch_size,
        args.height,
        args.width,
        args.hf_auth_token,
        args.compile_to,
        args.external_weights,
        args.external_weight_path,
        args.device,
        args.iree_target_triple,
        args.vulkan_max_allocation,
    )

    # The export returns the MLIR text unless it compiled to vmfb, in which
    # case it returns None and there is nothing to write. The previous check
    # (`is None`) was inverted: it skipped saving real output and tried to
    # write None on the vmfb path.
    if mod_str is not None:
        safe_name = args.hf_model_name.split("/")[-1].strip()
        safe_name = re.sub("-", "_", safe_name)
        with open(f"{safe_name}.mlir", "w+") as f:
            f.write(mod_str)
        print("Saved to", safe_name + ".mlir")

0 commit comments

Comments
 (0)