
Commit 1734dda

Select dev/schnell based on state dict, use correct max seq len based on dev/schnell, and shift in inference, separate vae flux params into separate config
1 parent 5143ff2 · commit 1734dda

9 files changed: +366 −196 lines

invokeai/app/invocations/flux_text_encoder.py

Lines changed: 6 additions & 10 deletions
@@ -1,4 +1,5 @@
 import torch
+from typing import Literal
 from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
@@ -23,11 +24,12 @@ class FluxTextEncoderInvocation(BaseInvocation):
         description=FieldDescriptions.clip,
         input=Input.Connection,
     )
-    t5Encoder: T5EncoderField = InputField(
+    t5_encoder: T5EncoderField = InputField(
         title="T5Encoder",
         description=FieldDescriptions.t5Encoder,
         input=Input.Connection,
     )
+    max_seq_len: Literal[256, 512] = InputField(description="Max sequence length for the desired flux model")
     positive_prompt: str = InputField(description="Positive prompt for text-to-image generation.")
 
     # TODO(ryand): Should we create a new return type for this invocation? This ConditioningOutput is clearly not
@@ -43,21 +45,15 @@ def invoke(self, context: InvocationContext) -> ConditioningOutput:
         return ConditioningOutput.build(conditioning_name)
 
     def _encode_prompt(self, context: InvocationContext) -> tuple[torch.Tensor, torch.Tensor]:
-        # TODO: Determine the T5 max sequence length based on the model.
-        # if self.model == "flux-schnell":
-        max_seq_len = 256
-        # # elif self.model == "flux-dev":
-        # # max_seq_len = 512
-        # else:
-        #     raise ValueError(f"Unknown model: {self.model}")
+        max_seq_len = self.max_seq_len
 
         # Load CLIP.
         clip_tokenizer_info = context.models.load(self.clip.tokenizer)
         clip_text_encoder_info = context.models.load(self.clip.text_encoder)
 
         # Load T5.
-        t5_tokenizer_info = context.models.load(self.t5Encoder.tokenizer)
-        t5_text_encoder_info = context.models.load(self.t5Encoder.text_encoder)
+        t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
+        t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)
 
         with (
             clip_text_encoder_info as clip_text_encoder,
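The new max_seq_len input replaces the hard-coded 256, so a graph wired to a dev transformer can request the full 512-token T5 context. A minimal sketch of where the value ends up, using the plain transformers tokenizer API for illustration rather than the invocation's actual encoding path (which goes through the loaded tokenizer/encoder infos above):

from transformers import T5Tokenizer

# Sketch: pad/truncate the prompt to max_seq_len tokens before it reaches
# the T5 text encoder. max_seq_len is Literal[256, 512]: 256 for schnell,
# 512 for dev.
tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-xxl")
tokens = tokenizer(
    "a photo of a forest clearing at dawn",
    truncation=True,
    padding="max_length",
    max_length=256,
    return_tensors="pt",
)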

invokeai/app/invocations/flux_text_to_image.py

Lines changed: 2 additions & 1 deletion
@@ -19,6 +19,7 @@
 from invokeai.backend.flux.sampling import denoise, get_noise, get_schedule, unpack
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
 from invokeai.backend.util.devices import TorchDevice
+from invokeai.backend.model_manager.config import CheckpointConfigBase
 
 
 @invocation(
@@ -89,7 +90,7 @@ def _run_diffusion(
         img, img_ids = self._prepare_latent_img_patches(x)
 
         # HACK(ryand): Find a better way to determine if this is a schnell model or not.
-        is_schnell = "schnell" in transformer_info.config.path if transformer_info.config else ""
+        is_schnell = "schnell" in transformer_info.config.config_path if transformer_info.config and isinstance(transformer_info.config, CheckpointConfigBase) else ""
         timesteps = get_schedule(
             num_steps=self.num_steps,
             image_seq_len=img.shape[1],
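The commit message also mentions "shift in inference": presumably is_schnell feeds the shift flag of get_schedule, which toggles the resolution-dependent timestep shift that dev applies and schnell skips. A sketch of how the truncated call above would continue, assuming the reference get_schedule(num_steps, image_seq_len, ..., shift=...) signature:

timesteps = get_schedule(
    num_steps=self.num_steps,
    image_seq_len=img.shape[1],
    shift=not is_schnell,  # assumption: dev shifts the schedule, schnell does not
)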

invokeai/app/invocations/model.py

Lines changed: 15 additions & 2 deletions
@@ -1,4 +1,5 @@
 import copy
+import yaml
 from time import sleep
 from typing import Dict, List, Literal, Optional
 
@@ -16,6 +17,7 @@
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.shared.models import FreeUConfig
 from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
+from invokeai.backend.model_manager.config import CheckpointConfigBase
 
 
 class ModelIdentifierField(BaseModel):
@@ -154,8 +156,9 @@ class FluxModelLoaderOutput(BaseInvocationOutput):
 
     transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
     clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP")
-    t5Encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5Encoder, title="T5 Encoder")
+    t5_encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5Encoder, title="T5 Encoder")
     vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
+    max_seq_len: Literal[256, 512] = OutputField(description=FieldDescriptions.vae, title="Max Seq Length")
 
 
 @invocation("flux_model_loader", title="Flux Main Model", tags=["model", "flux"], category="model", version="1.0.3")
@@ -189,12 +192,22 @@ def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
             ModelType.VAE,
             BaseModelType.Flux,
         )
+        transformer_config = context.models.get_config(transformer)
+        assert isinstance(transformer_config, CheckpointConfigBase)
+        legacy_config_path = context.config.get().legacy_conf_path / transformer_config.config_path
+        config_path = legacy_config_path.as_posix()
+        with open(config_path, "r") as stream:
+            try:
+                flux_conf = yaml.safe_load(stream)
+            except:
+                raise
 
         return FluxModelLoaderOutput(
             transformer=TransformerField(transformer=transformer),
             clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
-            t5Encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
+            t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
             vae=VAEField(vae=vae),
+            max_seq_len=flux_conf['max_seq_len']
         )
 
     def _get_model(self, context: InvocationContext, submodel: SubModelType) -> ModelIdentifierField:
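The loader now resolves the transformer's legacy YAML config and forwards its max_seq_len through the new output field, so the text encoder node no longer has to guess the model variant. A standalone sketch of the lookup, with a hypothetical on-disk path standing in for context.config.get().legacy_conf_path / transformer_config.config_path:

import yaml
from pathlib import Path

# Hypothetical path for illustration; the invocation joins legacy_conf_path
# with the probed config_path (flux/flux1-dev.yaml or flux/flux1-schnell.yaml).
config_path = Path("invokeai/configs") / "flux/flux1-dev.yaml"
with open(config_path, "r") as stream:
    flux_conf = yaml.safe_load(stream)

max_seq_len = flux_conf["max_seq_len"]  # 512 for dev, 256 for schnell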

invokeai/backend/model_manager/load/model_loaders/flux.py

Lines changed: 1 addition & 2 deletions
@@ -32,7 +32,6 @@
 )
 from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
 from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
-from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.silence_warnings import SilenceWarnings
 from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
 
@@ -60,7 +59,7 @@ def _load_model(
             raise
 
         dataclass_fields = {f.name for f in fields(AutoEncoderParams)}
-        filtered_data = {k: v for k, v in flux_conf["params"]["ae_params"].items() if k in dataclass_fields}
+        filtered_data = {k: v for k, v in flux_conf["params"].items() if k in dataclass_fields}
         params = AutoEncoderParams(**filtered_data)
 
         with SilenceWarnings():
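With ae_params gone from the transformer configs, the VAE loader reads the flat params block of the new flux1-vae.yaml directly. The filter-by-declared-fields idiom keeps stray YAML keys from breaking the **kwargs constructor; a self-contained sketch with an abridged stand-in for AutoEncoderParams:

from dataclasses import dataclass, fields

import yaml

@dataclass
class AutoEncoderParams:  # abridged stand-in; the real dataclass has more fields
    resolution: int
    in_channels: int
    z_channels: int
    scale_factor: float
    shift_factor: float

yaml_text = """
params:
  resolution: 256
  in_channels: 3
  z_channels: 16
  scale_factor: 0.3611
  shift_factor: 0.1159
  extra_key: ignored
"""
flux_conf = yaml.safe_load(yaml_text)

# Keep only keys the dataclass declares; extra_key is silently dropped.
dataclass_fields = {f.name for f in fields(AutoEncoderParams)}
filtered_data = {k: v for k, v in flux_conf["params"].items() if k in dataclass_fields}
params = AutoEncoderParams(**filtered_data)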

invokeai/backend/model_manager/probe.py

Lines changed: 7 additions & 2 deletions
@@ -324,7 +324,12 @@ def _get_checkpoint_config_path(
         if model_type is ModelType.Main:
             if base_type == BaseModelType.Flux:
                 # TODO: Decide between dev/schnell
-                config_file = "flux/flux1-schnell.yaml"
+                checkpoint = ModelProbe._scan_and_load_checkpoint(model_path)
+                state_dict = checkpoint.get("state_dict") or checkpoint
+                if 'guidance_in.out_layer.weight' in state_dict:
+                    config_file = "flux/flux1-dev.yaml"
+                else:
+                    config_file = "flux/flux1-schnell.yaml"
             else:
                 config_file = LEGACY_CONFIGS[base_type][variant_type]
                 if isinstance(config_file, dict):  # need another tier for sd-2.x models
@@ -338,7 +343,7 @@ def _get_checkpoint_config_path(
             )
         elif model_type is ModelType.VAE:
             config_file = (
-                "flux/flux1-schnell.yaml"
+                "flux/flux1-vae.yaml"
                 if base_type is BaseModelType.Flux
                 else "stable-diffusion/v1-inference.yaml"
                 if base_type is BaseModelType.StableDiffusion1
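The probe now picks dev vs. schnell from the weights themselves: dev is trained with guidance_embed: True (see flux1-dev.yaml below), so its state dict carries guidance_in.* tensors, while schnell's does not. A minimal sketch of the same check against a bare safetensors file, independent of ModelProbe:

from safetensors import safe_open

def guess_flux_variant(checkpoint_path: str) -> str:
    # dev embeds a guidance MLP, so guidance_in.* tensors are present in its
    # state dict; schnell (guidance_embed: False) has no such keys.
    with safe_open(checkpoint_path, framework="pt") as f:
        keys = set(f.keys())
    return "dev" if "guidance_in.out_layer.weight" in keys else "schnell"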

invokeai/configs/flux/flux1-dev.yaml

Lines changed: 1 addition & 15 deletions
@@ -1,6 +1,6 @@
 repo_id: "black-forest-labs/FLUX.1-dev"
 repo_ae: "ae.safetensors"
-max_length: 512
+max_seq_len: 512
 params:
   in_channels: 64
   vec_in_dim: 768
@@ -17,17 +17,3 @@ params:
   theta: 10_000
   qkv_bias: True
   guidance_embed: True
-  ae_params:
-    resolution: 256
-    in_channels: 3
-    ch: 128
-    out_ch: 3
-    ch_mult:
-      - 1
-      - 2
-      - 4
-      - 4
-    num_res_blocks: 2
-    z_channels: 16
-    scale_factor: 0.3611
-    shift_factor: 0.1159
invokeai/configs/flux/flux1-schnell.yaml

Lines changed: 1 addition & 16 deletions
@@ -1,7 +1,6 @@
 repo_id: "black-forest-labs/FLUX.1-schnell"
 repo_ae: "ae.safetensors"
-t5_encoder: "google/t5-v1_1-xxl"
-max_length: 512
+max_seq_len: 256
 params:
   in_channels: 64
   vec_in_dim: 768
@@ -18,17 +17,3 @@ params:
   theta: 10_000
   qkv_bias: True
   guidance_embed: False
-  ae_params:
-    resolution: 256
-    in_channels: 3
-    ch: 128
-    out_ch: 3
-    ch_mult:
-      - 1
-      - 2
-      - 4
-      - 4
-    num_res_blocks: 2
-    z_channels: 16
-    scale_factor: 0.3611
-    shift_factor: 0.1159
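The differing max_seq_len values mirror the reference implementation: dev conditions on up to 512 T5 tokens, while the distilled schnell uses 256. The t5_encoder key is dropped here, presumably because the T5 encoder is now installed and resolved as its own model rather than through this config.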

invokeai/configs/flux/flux1-vae.yaml

Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
+repo_id: "black-forest-labs/FLUX.1-schnell"
+repo_path: "ae.safetensors"
+params:
+  resolution: 256
+  in_channels: 3
+  ch: 128
+  out_ch: 3
+  ch_mult:
+    - 1
+    - 2
+    - 4
+    - 4
+  num_res_blocks: 2
+  z_channels: 16
+  scale_factor: 0.3611
+  shift_factor: 0.1159
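Factoring the autoencoder settings out works because the two ae_params blocks removed from the dev and schnell configs above were identical: a single flux/flux1-vae.yaml (pointing at the schnell repo's ae.safetensors) can serve both variants.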
