Commit 823c663

WIP on moving from diffusers to FLUX

1 parent d40c9ff commit 823c663

3 files changed: +155 −111 lines changed

invokeai/app/invocations/flux_text_to_image.py

Lines changed: 134 additions & 88 deletions
@@ -3,9 +3,12 @@
 
 import accelerate
 import torch
-from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
 from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
-from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
+from einops import rearrange, repeat
+from flux.model import Flux
+from flux.modules.autoencoder import AutoEncoder
+from flux.sampling import denoise, get_noise, get_schedule, unpack
+from flux.util import configs as flux_configs
 from PIL import Image
 from safetensors.torch import load_file
 from transformers.models.auto import AutoModelForTextEncoding
@@ -21,11 +24,11 @@
 )
 from invokeai.app.invocations.primitives import ImageOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8
 from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
 from invokeai.backend.quantization.fast_quantized_diffusion_model import FastQuantizedDiffusersModel
 from invokeai.backend.quantization.fast_quantized_transformers_model import FastQuantizedTransformersModel
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
+from invokeai.backend.util.devices import TorchDevice
 
 TFluxModelKeys = Literal["flux-schnell"]
 FLUX_MODELS: dict[TFluxModelKeys, str] = {"flux-schnell": "black-forest-labs/FLUX.1-schnell"}
@@ -70,139 +73,182 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
 
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ImageOutput:
-        model_path = context.models.download_and_cache_model(FLUX_MODELS[self.model])
+        # model_path = context.models.download_and_cache_model(FLUX_MODELS[self.model])
+        flux_transformer_path = context.models.download_and_cache_model(
+            "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/flux1-schnell.safetensors"
+        )
+        flux_ae_path = context.models.download_and_cache_model(
+            "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/ae.safetensors"
+        )
 
         # Load the conditioning data.
         cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
         assert len(cond_data.conditionings) == 1
         flux_conditioning = cond_data.conditionings[0]
         assert isinstance(flux_conditioning, FLUXConditioningInfo)
 
-        latents = self._run_diffusion(context, model_path, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds)
-        image = self._run_vae_decoding(context, model_path, latents)
+        latents = self._run_diffusion(
+            context, flux_transformer_path, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds
+        )
+        image = self._run_vae_decoding(context, flux_ae_path, latents)
         image_dto = context.images.save(image=image)
         return ImageOutput.build(image_dto)
 
     def _run_diffusion(
         self,
         context: InvocationContext,
-        flux_model_dir: Path,
+        flux_transformer_path: Path,
         clip_embeddings: torch.Tensor,
         t5_embeddings: torch.Tensor,
     ):
-        scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(flux_model_dir / "scheduler", local_files_only=True)
+        inference_dtype = TorchDevice.choose_torch_dtype()
+
+        # Prepare input noise.
+        # TODO(ryand): Does the seed behave the same on different devices? Should we re-implement this to always use a
+        # CPU RNG?
+        x = get_noise(
+            num_samples=1,
+            height=self.height,
+            width=self.width,
+            device=TorchDevice.choose_torch_device(),
+            dtype=inference_dtype,
+            seed=self.seed,
+        )
+
+        img, img_ids = self._prepare_latent_img_patches(x)
+
+        # HACK(ryand): Find a better way to determine if this is a schnell model or not.
+        is_schnell = "schnell" in str(flux_transformer_path)
+        timesteps = get_schedule(
+            num_steps=self.num_steps,
+            image_seq_len=img.shape[1],
+            shift=not is_schnell,
+        )
+
+        bs, t5_seq_len, _ = t5_embeddings.shape
+        txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())
 
         # HACK(ryand): Manually empty the cache. Currently we don't check the size of the model before loading it from
         # disk. Since the transformer model is large (24GB), there's a good chance that it will OOM on 32GB RAM systems
         # if the cache is not empty.
         context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)
 
-        transformer_path = flux_model_dir / "transformer"
         with context.models.load_local_model(
-            model_path=transformer_path, loader=self._load_flux_transformer
+            model_path=flux_transformer_path, loader=self._load_flux_transformer
         ) as transformer:
-            assert isinstance(transformer, FluxTransformer2DModel)
-
-            flux_pipeline_with_transformer = FluxPipeline(
-                scheduler=scheduler,
-                vae=None,
-                text_encoder=None,
-                tokenizer=None,
-                text_encoder_2=None,
-                tokenizer_2=None,
-                transformer=transformer,
+            assert isinstance(transformer, Flux)
+
+            x = denoise(
+                model=transformer,
+                img=img,
+                img_ids=img_ids,
+                txt=t5_embeddings,
+                txt_ids=txt_ids,
+                vec=clip_embeddings,
+                timesteps=timesteps,
+                guidance=self.guidance,
             )
 
-            dtype = torch.bfloat16
-            t5_embeddings = t5_embeddings.to(dtype=dtype)
-            clip_embeddings = clip_embeddings.to(dtype=dtype)
-
-            latents = flux_pipeline_with_transformer(
-                height=self.height,
-                width=self.width,
-                num_inference_steps=self.num_steps,
-                guidance_scale=self.guidance,
-                generator=torch.Generator().manual_seed(self.seed),
-                prompt_embeds=t5_embeddings,
-                pooled_prompt_embeds=clip_embeddings,
-                output_type="latent",
-                return_dict=False,
-            )[0]
-
-            assert isinstance(latents, torch.Tensor)
-            return latents
+        x = unpack(x.float(), self.height, self.width)
+
+        return x
+
+    def _prepare_latent_img_patches(self, latent_img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        """Convert an input image in latent space to patches for diffusion.
+
+        This implementation was extracted from:
+        https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/sampling.py#L32
+
+        Returns:
+            tuple[Tensor, Tensor]: (img, img_ids), as defined in the original flux repo.
+        """
+        bs, c, h, w = latent_img.shape
+
+        # Pixel unshuffle with a scale of 2, and flatten the height/width dimensions to get an array of patches.
+        img = rearrange(latent_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
+        if img.shape[0] == 1 and bs > 1:
+            img = repeat(img, "1 ... -> bs ...", bs=bs)
+
+        # Generate patch position ids.
+        img_ids = torch.zeros(h // 2, w // 2, 3)
+        img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
+        img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
+        img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
+
+        return img, img_ids
 
     def _run_vae_decoding(
         self,
         context: InvocationContext,
-        flux_model_dir: Path,
+        flux_ae_path: Path,
         latents: torch.Tensor,
     ) -> Image.Image:
-        vae_path = flux_model_dir / "vae"
-        with context.models.load_local_model(model_path=vae_path, loader=self._load_flux_vae) as vae:
-            assert isinstance(vae, AutoencoderKL)
-
-            flux_pipeline_with_vae = FluxPipeline(
-                scheduler=None,
-                vae=vae,
-                text_encoder=None,
-                tokenizer=None,
-                text_encoder_2=None,
-                tokenizer_2=None,
-                transformer=None,
-            )
+        with context.models.load_local_model(model_path=flux_ae_path, loader=self._load_flux_vae) as vae:
+            assert isinstance(vae, AutoEncoder)
+            # TODO(ryand): Test that this works with both float16 and bfloat16.
+            with torch.autocast(device_type=latents.device.type, dtype=TorchDevice.choose_torch_dtype()):
+                img = vae.decode(latents)
 
-            latents = flux_pipeline_with_vae._unpack_latents(
-                latents, self.height, self.width, flux_pipeline_with_vae.vae_scale_factor
-            )
-            latents = (
-                latents / flux_pipeline_with_vae.vae.config.scaling_factor
-            ) + flux_pipeline_with_vae.vae.config.shift_factor
-            latents = latents.to(dtype=vae.dtype)
-            image = flux_pipeline_with_vae.vae.decode(latents, return_dict=False)[0]
-            image = flux_pipeline_with_vae.image_processor.postprocess(image, output_type="pil")[0]
+        img = img.clamp(-1, 1)
+        img = rearrange(img[0], "c h w -> h w c")
+        img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())
 
-            assert isinstance(image, Image.Image)
-            return image
+        return img_pil
 
     def _load_flux_transformer(self, path: Path) -> FluxTransformer2DModel:
+        inference_dtype = TorchDevice.choose_torch_dtype()
         if self.quantization_type == "raw":
-            model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
-        elif self.quantization_type == "NF4":
-            model_config = FluxTransformer2DModel.load_config(path, local_files_only=True)
+            # TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
+            params = flux_configs["flux-schnell"].params
+
+            # Initialize the model on the "meta" device.
             with accelerate.init_empty_weights():
-                empty_model = FluxTransformer2DModel.from_config(model_config)
-            assert isinstance(empty_model, FluxTransformer2DModel)
+                model = Flux(params).to(inference_dtype)
 
-            model_nf4_path = path / "bnb_nf4"
-            assert model_nf4_path.exists()
+            state_dict = load_file(path)
+            # TODO(ryand): Cast the state_dict to the appropriate dtype?
+            model.load_state_dict(state_dict, strict=True, assign=True)
+        elif self.quantization_type == "NF4":
+            model_path = path.parent / "bnb_nf4.safetensors"
+
+            # TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
+            params = flux_configs["flux-schnell"].params
+            # Initialize the model on the "meta" device.
             with accelerate.init_empty_weights():
-                model = quantize_model_nf4(empty_model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
+                model = Flux(params)
+                model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
 
             # TODO(ryand): Right now, some of the weights are loaded in bfloat16. Think about how best to handle
             # this on GPUs without bfloat16 support.
-            sd = load_file(model_nf4_path / "model.safetensors")
-            model.load_state_dict(sd, strict=True, assign=True)
-        elif self.quantization_type == "llm_int8":
-            model_config = FluxTransformer2DModel.load_config(path, local_files_only=True)
-            with accelerate.init_empty_weights():
-                empty_model = FluxTransformer2DModel.from_config(model_config)
-            assert isinstance(empty_model, FluxTransformer2DModel)
-            model_int8_path = path / "bnb_llm_int8"
-            assert model_int8_path.exists()
-            with accelerate.init_empty_weights():
-                model = quantize_model_llm_int8(empty_model, modules_to_not_convert=set())
+            state_dict = load_file(model_path)
+            model.load_state_dict(state_dict, strict=True, assign=True)
 
-            sd = load_file(model_int8_path / "model.safetensors")
-            model.load_state_dict(sd, strict=True, assign=True)
+        elif self.quantization_type == "llm_int8":
+            raise NotImplementedError("LLM int8 quantization is not yet supported.")
+            # model_config = FluxTransformer2DModel.load_config(path, local_files_only=True)
+            # with accelerate.init_empty_weights():
+            #     empty_model = FluxTransformer2DModel.from_config(model_config)
+            # assert isinstance(empty_model, FluxTransformer2DModel)
+            # model_int8_path = path / "bnb_llm_int8"
+            # assert model_int8_path.exists()
+            # with accelerate.init_empty_weights():
+            #     model = quantize_model_llm_int8(empty_model, modules_to_not_convert=set())
+
+            # sd = load_file(model_int8_path / "model.safetensors")
+            # model.load_state_dict(sd, strict=True, assign=True)
         else:
            raise ValueError(f"Unsupported quantization type: {self.quantization_type}")
 
         assert isinstance(model, FluxTransformer2DModel)
         return model
 
     @staticmethod
-    def _load_flux_vae(path: Path) -> AutoencoderKL:
-        model = AutoencoderKL.from_pretrained(path, local_files_only=True)
-        assert isinstance(model, AutoencoderKL)
-        return model
+    def _load_flux_vae(path: Path) -> AutoEncoder:
+        # TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
+        ae_params = flux_configs["flux-schnell"].ae_params
+        with accelerate.init_empty_weights():
+            ae = AutoEncoder(ae_params)
+
+        state_dict = load_file(path)
+        ae.load_state_dict(state_dict, strict=True, assign=True)
+        return ae
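
Note: taken together, this file replaces the diffusers FluxPipeline with the raw sampling primitives from the flux package. The condensed sketch below is illustrative only (it is not part of the commit): it assumes an already-loaded Flux transformer, a FLUX AutoEncoder, and precomputed T5/CLIP embeddings on the target device. The function name sample_flux_sketch is invented for this sketch, and the 16-channel, 8x-downsampled latent shape mentioned in the comments is an assumption based on the FLUX reference implementation; the calls themselves mirror the ones added in the diff above.

import torch
from einops import rearrange, repeat
from flux.sampling import denoise, get_noise, get_schedule, unpack


def sample_flux_sketch(transformer, ae, t5_embeddings, clip_embeddings,
                       height, width, num_steps, guidance, seed, device, dtype):
    # device is a torch.device. Initial noise in latent space; assumed shape (1, 16, height // 8, width // 8).
    x = get_noise(num_samples=1, height=height, width=width, device=device, dtype=dtype, seed=seed)

    # Pack 2x2 latent patches into a token sequence, as in _prepare_latent_img_patches above.
    bs, _, h, w = x.shape
    img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    img_ids = torch.zeros(h // 2, w // 2, 3)
    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs).to(device)

    # Text tokens get all-zero position ids; only image tokens carry spatial ids.
    txt_ids = torch.zeros(bs, t5_embeddings.shape[1], 3, device=device)

    # schnell uses an unshifted schedule (shift=False); dev shifts it based on sequence length.
    timesteps = get_schedule(num_steps=num_steps, image_seq_len=img.shape[1], shift=False)

    # Flow-matching denoising loop driven by the Flux transformer.
    x = denoise(model=transformer, img=img, img_ids=img_ids, txt=t5_embeddings,
                txt_ids=txt_ids, vec=clip_embeddings, timesteps=timesteps, guidance=guidance)

    # Undo the patch packing and decode back to image space with the FLUX autoencoder.
    x = unpack(x.float(), height, width)
    with torch.autocast(device_type=device.type, dtype=dtype):
        return ae.decode(x)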

invokeai/backend/quantization/load_flux_model_bnb_nf4.py

Lines changed: 20 additions & 23 deletions
@@ -4,7 +4,8 @@
 
 import accelerate
 import torch
-from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
+from flux.model import Flux
+from flux.util import configs as flux_configs
 from safetensors.torch import load_file, save_file
 
 from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
@@ -22,35 +23,37 @@ def log_time(name: str):
 
 
 def main():
-    # Load the FLUX transformer model onto the meta device.
     model_path = Path(
-        "/data/invokeai/models/.download_cache/black-forest-labs_flux.1-schnell/FLUX.1-schnell/transformer/"
+        "/data/invokeai/models/.download_cache/https__huggingface.co_black-forest-labs_flux.1-schnell_resolve_main_flux1-schnell.safetensors/flux1-schnell.safetensors"
    )
 
+    # inference_dtype = torch.bfloat16
     with log_time("Intialize FLUX transformer on meta device"):
-        model_config = FluxTransformer2DModel.load_config(model_path, local_files_only=True)
+        # TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
+        params = flux_configs["flux-schnell"].params
+
+        # Initialize the model on the "meta" device.
         with accelerate.init_empty_weights():
-            empty_model = FluxTransformer2DModel.from_config(model_config)
-        assert isinstance(empty_model, FluxTransformer2DModel)
+            model = Flux(params)
 
     # TODO(ryand): We may want to add some modules to not quantize here (e.g. the proj_out layer). See the accelerate
     # `get_keys_to_not_convert(...)` function for a heuristic to determine which modules to not quantize.
     modules_to_not_convert: set[str] = set()
 
-    model_nf4_path = model_path / "bnb_nf4"
+    model_nf4_path = model_path.parent / "bnb_nf4.safetensors"
     if model_nf4_path.exists():
         # The quantized model already exists, load it and return it.
         print(f"A pre-quantized model already exists at '{model_nf4_path}'. Attempting to load it...")
 
         # Replace the linear layers with NF4 quantized linear layers (still on the meta device).
         with log_time("Replace linear layers with NF4 layers"), accelerate.init_empty_weights():
             model = quantize_model_nf4(
-                empty_model, modules_to_not_convert=modules_to_not_convert, compute_dtype=torch.bfloat16
+                model, modules_to_not_convert=modules_to_not_convert, compute_dtype=torch.bfloat16
             )
 
         with log_time("Load state dict into model"):
-            sd = load_file(model_nf4_path / "model.safetensors")
-            model.load_state_dict(sd, strict=True, assign=True)
+            state_dict = load_file(model_nf4_path)
+            model.load_state_dict(state_dict, strict=True, assign=True)
 
         with log_time("Move model to cuda"):
             model = model.to("cuda")
@@ -63,30 +66,24 @@ def main():
 
     with log_time("Replace linear layers with NF4 layers"), accelerate.init_empty_weights():
         model = quantize_model_nf4(
-            empty_model, modules_to_not_convert=modules_to_not_convert, compute_dtype=torch.bfloat16
+            model, modules_to_not_convert=modules_to_not_convert, compute_dtype=torch.bfloat16
         )
 
     with log_time("Load state dict into model"):
-        # Load sharded state dict.
-        files = list(model_path.glob("*.safetensors"))
-        state_dict = dict()
-        for file in files:
-            sd = load_file(file)
-            state_dict.update(sd)
-
+        state_dict = load_file(model_path)
+        # TODO(ryand): Cast the state_dict to the appropriate dtype?
         model.load_state_dict(state_dict, strict=True, assign=True)
 
     with log_time("Move model to cuda and quantize"):
         model = model.to("cuda")
 
     with log_time("Save quantized model"):
-        model_nf4_path.mkdir(parents=True, exist_ok=True)
-        output_path = model_nf4_path / "model.safetensors"
-        save_file(model.state_dict(), output_path)
+        model_nf4_path.parent.mkdir(parents=True, exist_ok=True)
+        save_file(model.state_dict(), model_nf4_path)
 
-    print(f"Successfully quantized and saved model to '{output_path}'.")
+    print(f"Successfully quantized and saved model to '{model_nf4_path}'.")
 
-    assert isinstance(model, FluxTransformer2DModel)
+    assert isinstance(model, Flux)
     return model
 
 
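
Note on the load pattern used in both files above: accelerate.init_empty_weights() builds the module graph on the "meta" device without allocating any weight storage, and load_state_dict(..., assign=True) then replaces those meta tensors with the loaded tensors instead of copying into them (a plain copy into meta tensors would fail, which is why assign=True appears throughout). A minimal, self-contained sketch of that pattern follows; build_model is a hypothetical stand-in for Flux(params), and assign=True assumes a recent PyTorch (2.1+).

import accelerate
import torch


def build_model() -> torch.nn.Module:
    # Hypothetical stand-in for Flux(params); any nn.Module behaves the same way here.
    return torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 64))


# 1. Obtain real weights once (stand-in for loading flux1-schnell.safetensors from disk).
state_dict = build_model().state_dict()

# 2. Re-create the module graph on the "meta" device -- no weight memory is allocated.
with accelerate.init_empty_weights():
    model = build_model()

# 3. assign=True swaps the meta tensors for the loaded ones rather than copying into them.
model.load_state_dict(state_dict, strict=True, assign=True)
print(next(model.parameters()).device)  # no longer "meta"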

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -40,6 +40,7 @@ dependencies = [
     "controlnet-aux==0.0.7",
     # TODO(ryand): Bump this once the next diffusers release is ready.
     "diffusers[torch] @ git+https://github.com/huggingface/diffusers.git@4c6152c2fb0ade468aadb417102605a07a8635d3",
+    "flux @ git+https://github.com/black-forest-labs/flux.git@c23ae247225daba30fbd56058d247cc1b1fc20a3",
     "invisible-watermark==0.2.0", # needed to install SDXL base and refiner using their repo_ids
     "mediapipe==0.10.7", # needed for "mediapipeface" controlnet model
     "numpy==1.26.4", # >1.24.0 is needed to use the 'strict' argument to np.testing.assert_array_equal()
