|
| 1 | +# Copyright 2023 Nod Labs, Inc |
| 2 | +# |
| 3 | +# Licensed under the Apache License v2.0 with LLVM Exceptions. |
| 4 | +# See https://llvm.org/LICENSE.txt for license information. |
| 5 | +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | + |
| 7 | +import os |
| 8 | +import sys |
| 9 | + |
| 10 | +from iree import runtime as ireert |
| 11 | +from iree.compiler.ir import Context |
| 12 | +import numpy as np |
| 13 | +from shark_turbine.aot import * |
| 14 | +from turbine_models.custom_models.sd_inference import utils |
| 15 | +import torch |
| 16 | +import torch._dynamo as dynamo |
| 17 | +from diffusers import ControlNetModel as CNetModel |
| 18 | + |
| 19 | +import safetensors |
| 20 | +import argparse |
| 21 | +import re |
| 22 | + |
| 23 | +parser = argparse.ArgumentParser() |
| 24 | +parser.add_argument( |
| 25 | + "--hf_auth_token", type=str, help="The Hugging Face auth token, required" |
| 26 | +) |
| 27 | +parser.add_argument( |
| 28 | + "--hf_model_name", |
| 29 | + type=str, |
| 30 | + help="HF model name", |
| 31 | + default="lllyasviel/control_v11p_sd15_canny", |
| 32 | +) |
| 33 | +parser.add_argument( |
| 34 | + "--batch_size", type=int, default=1, help="Batch size for inference" |
| 35 | +) |
| 36 | +parser.add_argument( |
| 37 | + "--height", type=int, default=512, help="Height of Stable Diffusion" |
| 38 | +) |
| 39 | +parser.add_argument("--width", type=int, default=512, help="Width of Stable Diffusion") |
| 40 | +parser.add_argument("--compile_to", type=str, help="torch, linalg, vmfb") |
| 41 | +parser.add_argument("--external_weight_path", type=str, default="") |
| 42 | +parser.add_argument( |
| 43 | + "--external_weights", |
| 44 | + type=str, |
| 45 | + default=None, |
| 46 | + help="saves ir without global weights for size and readability, options [safetensors]", |
| 47 | +) |
| 48 | +parser.add_argument("--device", type=str, default="cpu", help="cpu, cuda, vulkan, rocm") |
| 49 | +# TODO: Bring in detection for target triple |
| 50 | +parser.add_argument( |
| 51 | + "--iree_target_triple", |
| 52 | + type=str, |
| 53 | + default="", |
| 54 | + help="Specify vulkan target triple or rocm/cuda target device.", |
| 55 | +) |
| 56 | +parser.add_argument("--vulkan_max_allocation", type=str, default="4294967296") |
| 57 | + |
| 58 | + |
| 59 | +class ControlNetModel(torch.nn.Module): |
| 60 | + def __init__( |
| 61 | + self, model_id="lllyasviel/control_v11p_sd15_canny", low_cpu_mem_usage=False |
| 62 | + ): |
| 63 | + super().__init__() |
| 64 | + self.cnet = CNetModel.from_pretrained( |
| 65 | + model_id, |
| 66 | + low_cpu_mem_usage=low_cpu_mem_usage, |
| 67 | + ) |
| 68 | + self.in_channels = self.cnet.config.in_channels |
| 69 | + self.train(False) |
| 70 | + |
| 71 | + def forward( |
| 72 | + self, |
| 73 | + latent, |
| 74 | + timestep, |
| 75 | + text_embedding, |
| 76 | + stencil_image_input, |
| 77 | + ): |
| 78 | + # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes. |
| 79 | + # TODO: guidance NOT NEEDED change in `get_input_info` later |
| 80 | + latents = torch.cat([latent] * 2) # needs to be same as controlledUNET latents |
| 81 | + stencil_image = torch.cat( |
| 82 | + [stencil_image_input] * 2 |
| 83 | + ) # needs to be same as controlledUNET latents |
| 84 | + ( |
| 85 | + down_block_res_samples, |
| 86 | + mid_block_res_sample, |
| 87 | + ) = self.cnet.forward( |
| 88 | + latents, |
| 89 | + timestep, |
| 90 | + encoder_hidden_states=text_embedding, |
| 91 | + controlnet_cond=stencil_image, |
| 92 | + return_dict=False, |
| 93 | + ) |
| 94 | + return tuple(list(down_block_res_samples) + [mid_block_res_sample]) |
| 95 | + |
| 96 | + |
| 97 | +def export_controlnet_model( |
| 98 | + controlnet_model, |
| 99 | + hf_model_name, |
| 100 | + batch_size, |
| 101 | + height, |
| 102 | + width, |
| 103 | + hf_auth_token=None, |
| 104 | + compile_to="torch", |
| 105 | + external_weights=None, |
| 106 | + external_weight_path=None, |
| 107 | + device=None, |
| 108 | + target_triple=None, |
| 109 | + max_alloc=None, |
| 110 | +): |
| 111 | + mapper = {} |
| 112 | + utils.save_external_weights( |
| 113 | + mapper, controlnet_model, external_weights, external_weight_path |
| 114 | + ) |
| 115 | + |
| 116 | + class CompiledControlnet(CompiledModule): |
| 117 | + if external_weights: |
| 118 | + params = export_parameters( |
| 119 | + controlnet_model, |
| 120 | + external=True, |
| 121 | + external_scope="", |
| 122 | + name_mapper=mapper.get, |
| 123 | + ) |
| 124 | + else: |
| 125 | + params = export_parameters(controlnet_model) |
| 126 | + |
| 127 | + def main( |
| 128 | + self, |
| 129 | + latent=AbstractTensor(1, 4, 512, 512, dtype=torch.float32), |
| 130 | + timestep=AbstractTensor(1, dtype=torch.float32), |
| 131 | + text_embedding=AbstractTensor(2, 72, 768, dtype=torch.float32), |
| 132 | + stencil_image_input=AbstractTensor(1, 3, 4096, 4096, dtype=torch.float32), |
| 133 | + ): |
| 134 | + return jittable(controlnet_model.forward)( |
| 135 | + latent, |
| 136 | + timestep, |
| 137 | + text_embedding, |
| 138 | + stencil_image_input, |
| 139 | + ) |
| 140 | + |
| 141 | + import_to = "INPUT" if compile_to == "linalg" else "IMPORT" |
| 142 | + inst = CompiledControlnet(context=Context(), import_to=import_to) |
| 143 | + |
| 144 | + module_str = str(CompiledModule.get_mlir_module(inst)) |
| 145 | + safe_name = hf_model_name.split("/")[-1].strip() |
| 146 | + safe_name = re.sub("-", "_", safe_name) |
| 147 | + if compile_to != "vmfb": |
| 148 | + return module_str |
| 149 | + else: |
| 150 | + utils.compile_to_vmfb(module_str, device, target_triple, max_alloc, safe_name) |
| 151 | + |
| 152 | + |
| 153 | +if __name__ == "__main__": |
| 154 | + args = parser.parse_args() |
| 155 | + controlnet_model = ControlNetModel( |
| 156 | + args.hf_model_name, |
| 157 | + ) |
| 158 | + mod_str = export_controlnet_model( |
| 159 | + controlnet_model, |
| 160 | + args.hf_model_name, |
| 161 | + args.batch_size, |
| 162 | + args.height, |
| 163 | + args.width, |
| 164 | + args.hf_auth_token, |
| 165 | + args.compile_to, |
| 166 | + args.external_weights, |
| 167 | + args.external_weight_path, |
| 168 | + args.device, |
| 169 | + args.iree_target_triple, |
| 170 | + args.vulkan_max_allocation, |
| 171 | + ) |
| 172 | + |
| 173 | + if mod_str is None: |
| 174 | + safe_name = args.hf_model_name.split("/")[-1].strip() |
| 175 | + safe_name = re.sub("-", "_", safe_name) |
| 176 | + with open(f"{safe_name}.mlir", "w+") as f: |
| 177 | + f.write(mod_str) |
| 178 | + print("Saved to", safe_name + ".mlir") |
0 commit comments