Commit 5e2351f

RyanJDick authored and brandonrising committed
Fix FLUX output image clamping, plus a few other minor fixes to make inference work with the full bfloat16 FLUX transformer model.
1 parent d705c3c · commit 5e2351f

File tree: 3 files changed, +25 -10 lines


invokeai/app/invocations/flux_text_to_image.py

Lines changed: 7 additions & 3 deletions
@@ -17,9 +17,9 @@
 from invokeai.backend.flux.model import Flux
 from invokeai.backend.flux.modules.autoencoder import AutoEncoder
 from invokeai.backend.flux.sampling import denoise, get_noise, get_schedule, unpack
+from invokeai.backend.model_manager.config import CheckpointConfigBase
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
 from invokeai.backend.util.devices import TorchDevice
-from invokeai.backend.model_manager.config import CheckpointConfigBase
 
 
 @invocation(
@@ -90,7 +90,11 @@ def _run_diffusion(
         img, img_ids = self._prepare_latent_img_patches(x)
 
         # HACK(ryand): Find a better way to determine if this is a schnell model or not.
-        is_schnell = "schnell" in transformer_info.config.config_path if transformer_info.config and isinstance(transformer_info.config, CheckpointConfigBase) else ""
+        is_schnell = (
+            "schnell" in transformer_info.config.config_path
+            if transformer_info.config and isinstance(transformer_info.config, CheckpointConfigBase)
+            else ""
+        )
         timesteps = get_schedule(
             num_steps=self.num_steps,
             image_seq_len=img.shape[1],
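
A note on the reformatted expression: its semantics are unchanged, including the quirk that the fallback value is the empty string rather than False. Both are falsy, so downstream truthiness checks treat a model without a checkpoint config as non-schnell. A standalone sketch with hypothetical values:

# Standalone illustration; config_path and has_checkpoint_config are hypothetical
# stand-ins for transformer_info.config.config_path and the isinstance() check.
config_path = "flux1-schnell.yaml"
has_checkpoint_config = True

is_schnell = "schnell" in config_path if has_checkpoint_config else ""
print(bool(is_schnell))  # True here; "" (falsy) when no checkpoint config exists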
@@ -161,7 +165,7 @@ def _run_vae_decoding(
         latents.to(torch.float32)
         img = vae.decode(latents)
 
-        img.clamp(-1, 1)
+        img = img.clamp(-1, 1)
         img = rearrange(img[0], "c h w -> h w c")
         img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())
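
The clamping fix addresses the fact that torch.Tensor.clamp is not an in-place operation: it returns a clamped copy, so the old call discarded its result and out-of-range values flowed into the uint8 conversion below it. A minimal sketch of the difference:

import torch

x = torch.tensor([-1.7, 0.2, 1.9])

x.clamp(-1, 1)      # returns a clamped copy; x itself is unchanged
print(x)            # tensor([-1.7000,  0.2000,  1.9000])

x = x.clamp(-1, 1)  # reassigning, as the fix does, keeps the clamped values
print(x)            # tensor([-1.0000,  0.2000,  1.0000])

# The in-place variant x.clamp_(-1, 1) would have worked as well.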
167171

invokeai/backend/flux/sampling.py

Lines changed: 10 additions & 1 deletion
@@ -104,9 +104,18 @@ def denoise(
     timesteps: list[float],
     guidance: float = 4.0,
 ):
+    dtype = model.txt_in.bias.dtype
+
+    # TODO(ryand): This shouldn't be necessary if we manage the dtypes properly in the caller.
+    img = img.to(dtype=dtype)
+    img_ids = img_ids.to(dtype=dtype)
+    txt = txt.to(dtype=dtype)
+    txt_ids = txt_ids.to(dtype=dtype)
+    vec = vec.to(dtype=dtype)
+
     # this is ignored for schnell
     guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
-    for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:], strict=False):
+    for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:], strict=True):
         t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
         pred = model(
             img=img,
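
Two behaviors in this hunk are worth spelling out. Casting every conditioning tensor to the transformer's parameter dtype (read off model.txt_in.bias) is what lets the full bfloat16 transformer run without dtype-mismatch errors, and zip(..., strict=True) (Python 3.10+) raises instead of silently truncating if its arguments ever differ in length; here the two slices of the same list are always equal-length, so it mainly documents intent and guards future refactors. A minimal sketch of both, independent of the FLUX code:

import torch

# Matching input dtype to parameter dtype, as denoise() now does.
layer = torch.nn.Linear(4, 4, dtype=torch.bfloat16)
x = torch.randn(1, 4)                    # float32 by default
y = layer(x.to(dtype=layer.bias.dtype))  # the cast avoids a dtype-mismatch error

# strict=True makes length mismatches loud instead of silent.
timesteps = [1.0, 0.75, 0.5, 0.25, 0.0]
pairs = list(zip(timesteps[:-1], timesteps[1:], strict=True))  # fine: equal lengths
# list(zip([1, 2, 3], [1, 2], strict=True)) would raise ValueError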

invokeai/backend/model_manager/load/model_loaders/flux.py

Lines changed: 8 additions & 6 deletions
@@ -1,12 +1,12 @@
 # Copyright (c) 2024, Brandon W. Rising and the InvokeAI Development Team
 """Class for Flux model loading in InvokeAI."""
 
-import accelerate
-import torch
 from dataclasses import fields
 from pathlib import Path
 from typing import Any, Optional
 
+import accelerate
+import torch
 import yaml
 from safetensors.torch import load_file
 from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
@@ -25,15 +25,15 @@
 from invokeai.backend.model_manager.config import (
     CheckpointConfigBase,
     CLIPEmbedDiffusersConfig,
-    MainCheckpointConfig,
     MainBnbQuantized4bCheckpointConfig,
+    MainCheckpointConfig,
     T5EncoderConfig,
     VAECheckpointConfig,
 )
 from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
 from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
-from invokeai.backend.util.silence_warnings import SilenceWarnings
 from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
+from invokeai.backend.util.silence_warnings import SilenceWarnings
 
 app_config = get_config()

@@ -109,7 +109,9 @@ def _load_model(
             case SubModelType.Tokenizer2:
                 return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
             case SubModelType.TextEncoder2:
-                return T5EncoderModel.from_pretrained(Path(config.path) / "text_encoder_2") #TODO: Fix hf subfolder install
+                return T5EncoderModel.from_pretrained(
+                    Path(config.path) / "text_encoder_2"
+                )  # TODO: Fix hf subfolder install
 
         raise Exception("Only Checkpoint Flux models are currently supported.")
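
The remaining TODO presumably refers to transformers' subfolder argument, which lets from_pretrained resolve the subdirectory itself instead of the loader joining the path manually. A hedged sketch of that alternative (model_dir is a hypothetical install location):

from pathlib import Path

from transformers import T5EncoderModel

model_dir = Path("/path/to/flux-model")  # hypothetical install location

# Lets transformers handle the subdirectory instead of joining it by hand.
t5 = T5EncoderModel.from_pretrained(model_dir, subfolder="text_encoder_2")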

@@ -153,7 +155,7 @@ def _load_from_singlefile(
         params = FluxParams(**filtered_data)
 
         with SilenceWarnings():
-            model = load_class(params).to(self._torch_dtype)
+            model = load_class(params)
             sd = load_file(model_path)
             model.load_state_dict(sd, strict=False, assign=True)
         return model
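
Dropping the eager .to(self._torch_dtype) works because the subsequent load_state_dict(..., assign=True) replaces the module's parameters with the checkpoint tensors themselves, so the model ends up in the checkpoint's dtype (bfloat16 for the full FLUX transformer) without first materializing and casting a separate set of weights. A minimal sketch of the assign=True behavior on a toy module (requires PyTorch 2.1+):

import torch

model = torch.nn.Linear(4, 4)  # parameters start out as float32
sd = {
    "weight": torch.randn(4, 4, dtype=torch.bfloat16),
    "bias": torch.zeros(4, dtype=torch.bfloat16),
}

# assign=True swaps the checkpoint tensors in rather than copying them into
# the existing float32 parameters, so the module adopts the checkpoint dtype.
model.load_state_dict(sd, assign=True)
print(model.weight.dtype)  # torch.bfloat16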
