@@ -3,9 +3,12 @@
 
 import accelerate
 import torch
-from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
 from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
-from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
+from einops import rearrange, repeat
+from flux.model import Flux
+from flux.modules.autoencoder import AutoEncoder
+from flux.sampling import denoise, get_noise, get_schedule, unpack
+from flux.util import configs as flux_configs
 from PIL import Image
 from safetensors.torch import load_file
 from transformers.models.auto import AutoModelForTextEncoding
@@ -21,11 +24,11 @@
 )
 from invokeai.app.invocations.primitives import ImageOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8
 from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
 from invokeai.backend.quantization.fast_quantized_diffusion_model import FastQuantizedDiffusersModel
 from invokeai.backend.quantization.fast_quantized_transformers_model import FastQuantizedTransformersModel
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
+from invokeai.backend.util.devices import TorchDevice
 
 TFluxModelKeys = Literal["flux-schnell"]
 FLUX_MODELS: dict[TFluxModelKeys, str] = {"flux-schnell": "black-forest-labs/FLUX.1-schnell"}
@@ -70,139 +73,182 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
 
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ImageOutput:
-        model_path = context.models.download_and_cache_model(FLUX_MODELS[self.model])
+        # model_path = context.models.download_and_cache_model(FLUX_MODELS[self.model])
+        flux_transformer_path = context.models.download_and_cache_model(
+            "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/flux1-schnell.safetensors"
+        )
+        flux_ae_path = context.models.download_and_cache_model(
+            "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/ae.safetensors"
+        )
 
         # Load the conditioning data.
         cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
         assert len(cond_data.conditionings) == 1
         flux_conditioning = cond_data.conditionings[0]
         assert isinstance(flux_conditioning, FLUXConditioningInfo)
 
-        latents = self._run_diffusion(context, model_path, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds)
-        image = self._run_vae_decoding(context, model_path, latents)
+        latents = self._run_diffusion(
+            context, flux_transformer_path, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds
+        )
+        image = self._run_vae_decoding(context, flux_ae_path, latents)
         image_dto = context.images.save(image=image)
         return ImageOutput.build(image_dto)
 
     def _run_diffusion(
         self,
         context: InvocationContext,
-        flux_model_dir: Path,
+        flux_transformer_path: Path,
         clip_embeddings: torch.Tensor,
         t5_embeddings: torch.Tensor,
     ):
-        scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(flux_model_dir / "scheduler", local_files_only=True)
+        inference_dtype = TorchDevice.choose_torch_dtype()
+
+        # Prepare input noise.
+        # TODO(ryand): Does the seed behave the same on different devices? Should we re-implement this to always use a
+        # CPU RNG?
+        x = get_noise(
+            num_samples=1,
+            height=self.height,
+            width=self.width,
+            device=TorchDevice.choose_torch_device(),
+            dtype=inference_dtype,
+            seed=self.seed,
+        )
+
+        img, img_ids = self._prepare_latent_img_patches(x)
+
+        # HACK(ryand): Find a better way to determine if this is a schnell model or not.
+        is_schnell = "schnell" in str(flux_transformer_path)
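+        # FLUX dev applies a resolution-dependent timestep shift; schnell uses the plain linear schedule,
+        # so shifting is disabled for it.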
+        timesteps = get_schedule(
+            num_steps=self.num_steps,
+            image_seq_len=img.shape[1],
+            shift=not is_schnell,
+        )
+
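+        # The text tokens all get zero position ids, as in the reference flux sampling code.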
+        bs, t5_seq_len, _ = t5_embeddings.shape
+        txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())
 
         # HACK(ryand): Manually empty the cache. Currently we don't check the size of the model before loading it from
         # disk. Since the transformer model is large (24GB), there's a good chance that it will OOM on 32GB RAM systems
         # if the cache is not empty.
         context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)
 
-        transformer_path = flux_model_dir / "transformer"
         with context.models.load_local_model(
-            model_path=transformer_path, loader=self._load_flux_transformer
+            model_path=flux_transformer_path, loader=self._load_flux_transformer
         ) as transformer:
-            assert isinstance(transformer, FluxTransformer2DModel)
-
-            flux_pipeline_with_transformer = FluxPipeline(
-                scheduler=scheduler,
-                vae=None,
-                text_encoder=None,
-                tokenizer=None,
-                text_encoder_2=None,
-                tokenizer_2=None,
-                transformer=transformer,
+            assert isinstance(transformer, Flux)
+
+            x = denoise(
+                model=transformer,
+                img=img,
+                img_ids=img_ids,
+                txt=t5_embeddings,
+                txt_ids=txt_ids,
+                vec=clip_embeddings,
+                timesteps=timesteps,
+                guidance=self.guidance,
             )
 
-            dtype = torch.bfloat16
-            t5_embeddings = t5_embeddings.to(dtype=dtype)
-            clip_embeddings = clip_embeddings.to(dtype=dtype)
-
-            latents = flux_pipeline_with_transformer(
-                height=self.height,
-                width=self.width,
-                num_inference_steps=self.num_steps,
-                guidance_scale=self.guidance,
-                generator=torch.Generator().manual_seed(self.seed),
-                prompt_embeds=t5_embeddings,
-                pooled_prompt_embeds=clip_embeddings,
-                output_type="latent",
-                return_dict=False,
-            )[0]
-
-            assert isinstance(latents, torch.Tensor)
-            return latents
+        x = unpack(x.float(), self.height, self.width)
+
+        return x
+
+    def _prepare_latent_img_patches(self, latent_img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        """Convert an input image in latent space to patches for diffusion.
+
+        This implementation was extracted from:
+        https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/sampling.py#L32
+
+        Returns:
+            tuple[Tensor, Tensor]: (img, img_ids), as defined in the original flux repo.
+        """
+        bs, c, h, w = latent_img.shape
+
+        # Pixel unshuffle with a scale of 2, and flatten the height/width dimensions to get an array of patches.
+        img = rearrange(latent_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
+        if img.shape[0] == 1 and bs > 1:
+            img = repeat(img, "1 ... -> bs ...", bs=bs)
+
+        # Generate patch position ids.
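+        # Channel 0 is left at zero; channels 1 and 2 hold each patch's row and column index.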
+        img_ids = torch.zeros(h // 2, w // 2, 3)
+        img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
+        img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
+        img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
+
+        return img, img_ids
 
     def _run_vae_decoding(
         self,
         context: InvocationContext,
-        flux_model_dir: Path,
+        flux_ae_path: Path,
         latents: torch.Tensor,
     ) -> Image.Image:
-        vae_path = flux_model_dir / "vae"
-        with context.models.load_local_model(model_path=vae_path, loader=self._load_flux_vae) as vae:
-            assert isinstance(vae, AutoencoderKL)
-
-            flux_pipeline_with_vae = FluxPipeline(
-                scheduler=None,
-                vae=vae,
-                text_encoder=None,
-                tokenizer=None,
-                text_encoder_2=None,
-                tokenizer_2=None,
-                transformer=None,
-            )
+        with context.models.load_local_model(model_path=flux_ae_path, loader=self._load_flux_vae) as vae:
+            assert isinstance(vae, AutoEncoder)
+            # TODO(ryand): Test that this works with both float16 and bfloat16.
+            with torch.autocast(device_type=latents.device.type, dtype=TorchDevice.choose_torch_dtype()):
+                img = vae.decode(latents)
 
-            latents = flux_pipeline_with_vae._unpack_latents(
-                latents, self.height, self.width, flux_pipeline_with_vae.vae_scale_factor
-            )
-            latents = (
-                latents / flux_pipeline_with_vae.vae.config.scaling_factor
-            ) + flux_pipeline_with_vae.vae.config.shift_factor
-            latents = latents.to(dtype=vae.dtype)
-            image = flux_pipeline_with_vae.vae.decode(latents, return_dict=False)[0]
-            image = flux_pipeline_with_vae.image_processor.postprocess(image, output_type="pil")[0]
+        img = img.clamp(-1, 1)
+        img = rearrange(img[0], "c h w -> h w c")
+        img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())
 
-            assert isinstance(image, Image.Image)
-            return image
+        return img_pil
 
     def _load_flux_transformer(self, path: Path) -> FluxTransformer2DModel:
+        inference_dtype = TorchDevice.choose_torch_dtype()
         if self.quantization_type == "raw":
-            model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
-        elif self.quantization_type == "NF4":
-            model_config = FluxTransformer2DModel.load_config(path, local_files_only=True)
+            # TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
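+            # flux_configs (flux.util.configs) maps model names to their parameter sets; both branches currently
+            # assume the schnell config.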
+            params = flux_configs["flux-schnell"].params
+
+            # Initialize the model on the "meta" device.
             with accelerate.init_empty_weights():
-                empty_model = FluxTransformer2DModel.from_config(model_config)
-                assert isinstance(empty_model, FluxTransformer2DModel)
+                model = Flux(params).to(inference_dtype)
 
-            model_nf4_path = path / "bnb_nf4"
-            assert model_nf4_path.exists()
+            state_dict = load_file(path)
+            # TODO(ryand): Cast the state_dict to the appropriate dtype?
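+            # assign=True replaces the "meta"-device parameters with the loaded tensors rather than copying into
+            # them, which is required because the model was initialized under init_empty_weights().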
+            model.load_state_dict(state_dict, strict=True, assign=True)
+        elif self.quantization_type == "NF4":
+            model_path = path.parent / "bnb_nf4.safetensors"
+
+            # TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
+            params = flux_configs["flux-schnell"].params
+            # Initialize the model on the "meta" device.
             with accelerate.init_empty_weights():
-                model = quantize_model_nf4(empty_model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
+                model = Flux(params)
+                model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
 
             # TODO(ryand): Right now, some of the weights are loaded in bfloat16. Think about how best to handle
             # this on GPUs without bfloat16 support.
-            sd = load_file(model_nf4_path / "model.safetensors")
-            model.load_state_dict(sd, strict=True, assign=True)
-        elif self.quantization_type == "llm_int8":
-            model_config = FluxTransformer2DModel.load_config(path, local_files_only=True)
-            with accelerate.init_empty_weights():
-                empty_model = FluxTransformer2DModel.from_config(model_config)
-                assert isinstance(empty_model, FluxTransformer2DModel)
-            model_int8_path = path / "bnb_llm_int8"
-            assert model_int8_path.exists()
-            with accelerate.init_empty_weights():
-                model = quantize_model_llm_int8(empty_model, modules_to_not_convert=set())
+            state_dict = load_file(model_path)
+            model.load_state_dict(state_dict, strict=True, assign=True)
 
-            sd = load_file(model_int8_path / "model.safetensors")
-            model.load_state_dict(sd, strict=True, assign=True)
+        elif self.quantization_type == "llm_int8":
+            raise NotImplementedError("LLM int8 quantization is not yet supported.")
+            # model_config = FluxTransformer2DModel.load_config(path, local_files_only=True)
+            # with accelerate.init_empty_weights():
+            #     empty_model = FluxTransformer2DModel.from_config(model_config)
+            #     assert isinstance(empty_model, FluxTransformer2DModel)
+            # model_int8_path = path / "bnb_llm_int8"
+            # assert model_int8_path.exists()
+            # with accelerate.init_empty_weights():
+            #     model = quantize_model_llm_int8(empty_model, modules_to_not_convert=set())
+
+            # sd = load_file(model_int8_path / "model.safetensors")
+            # model.load_state_dict(sd, strict=True, assign=True)
         else:
             raise ValueError(f"Unsupported quantization type: {self.quantization_type}")
 
         assert isinstance(model, FluxTransformer2DModel)
         return model
 
     @staticmethod
-    def _load_flux_vae(path: Path) -> AutoencoderKL:
-        model = AutoencoderKL.from_pretrained(path, local_files_only=True)
-        assert isinstance(model, AutoencoderKL)
-        return model
+    def _load_flux_vae(path: Path) -> AutoEncoder:
+        # TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
+        ae_params = flux_configs["flux-schnell"].ae_params
+        with accelerate.init_empty_weights():
+            ae = AutoEncoder(ae_params)
+
+        state_dict = load_file(path)
+        ae.load_state_dict(state_dict, strict=True, assign=True)
+        return ae