Skip to content

Commit 8c0ac85

Browse files
committed
Rename params for flux and flux vae, add comments explaining use of the config_path in model config
1 parent 5477950 commit 8c0ac85

File tree

5 files changed

+63
-85
lines changed

5 files changed

+63
-85
lines changed

invokeai/backend/flux/util.py

Lines changed: 40 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -25,82 +25,48 @@ class ModelSpec:
2525
}
2626

2727

28-
ae_params = AutoEncoderParams(
29-
resolution=256,
30-
in_channels=3,
31-
ch=128,
32-
out_ch=3,
33-
ch_mult=[1, 2, 4, 4],
34-
num_res_blocks=2,
35-
z_channels=16,
36-
scale_factor=0.3611,
37-
shift_factor=0.1159,
38-
)
28+
ae_params = {
29+
"flux": AutoEncoderParams(
30+
resolution=256,
31+
in_channels=3,
32+
ch=128,
33+
out_ch=3,
34+
ch_mult=[1, 2, 4, 4],
35+
num_res_blocks=2,
36+
z_channels=16,
37+
scale_factor=0.3611,
38+
shift_factor=0.1159,
39+
)
40+
}
3941

4042

41-
configs = {
42-
"flux-dev": ModelSpec(
43-
repo_id="black-forest-labs/FLUX.1-dev",
44-
repo_flow="flux1-dev.safetensors",
45-
repo_ae="ae.safetensors",
46-
ckpt_path=os.getenv("FLUX_DEV"),
47-
params=FluxParams(
48-
in_channels=64,
49-
vec_in_dim=768,
50-
context_in_dim=4096,
51-
hidden_size=3072,
52-
mlp_ratio=4.0,
53-
num_heads=24,
54-
depth=19,
55-
depth_single_blocks=38,
56-
axes_dim=[16, 56, 56],
57-
theta=10_000,
58-
qkv_bias=True,
59-
guidance_embed=True,
60-
),
61-
ae_path=os.getenv("AE"),
62-
ae_params=AutoEncoderParams(
63-
resolution=256,
64-
in_channels=3,
65-
ch=128,
66-
out_ch=3,
67-
ch_mult=[1, 2, 4, 4],
68-
num_res_blocks=2,
69-
z_channels=16,
70-
scale_factor=0.3611,
71-
shift_factor=0.1159,
72-
),
43+
params = {
44+
"flux-dev": FluxParams(
45+
in_channels=64,
46+
vec_in_dim=768,
47+
context_in_dim=4096,
48+
hidden_size=3072,
49+
mlp_ratio=4.0,
50+
num_heads=24,
51+
depth=19,
52+
depth_single_blocks=38,
53+
axes_dim=[16, 56, 56],
54+
theta=10_000,
55+
qkv_bias=True,
56+
guidance_embed=True,
7357
),
74-
"flux-schnell": ModelSpec(
75-
repo_id="black-forest-labs/FLUX.1-schnell",
76-
repo_flow="flux1-schnell.safetensors",
77-
repo_ae="ae.safetensors",
78-
ckpt_path=os.getenv("FLUX_SCHNELL"),
79-
params=FluxParams(
80-
in_channels=64,
81-
vec_in_dim=768,
82-
context_in_dim=4096,
83-
hidden_size=3072,
84-
mlp_ratio=4.0,
85-
num_heads=24,
86-
depth=19,
87-
depth_single_blocks=38,
88-
axes_dim=[16, 56, 56],
89-
theta=10_000,
90-
qkv_bias=True,
91-
guidance_embed=False,
92-
),
93-
ae_path=os.getenv("AE"),
94-
ae_params=AutoEncoderParams(
95-
resolution=256,
96-
in_channels=3,
97-
ch=128,
98-
out_ch=3,
99-
ch_mult=[1, 2, 4, 4],
100-
num_res_blocks=2,
101-
z_channels=16,
102-
scale_factor=0.3611,
103-
shift_factor=0.1159,
104-
),
58+
"flux-schnell": FluxParams(
59+
in_channels=64,
60+
vec_in_dim=768,
61+
context_in_dim=4096,
62+
hidden_size=3072,
63+
mlp_ratio=4.0,
64+
num_heads=24,
65+
depth=19,
66+
depth_single_blocks=38,
67+
axes_dim=[16, 56, 56],
68+
theta=10_000,
69+
qkv_bias=True,
70+
guidance_embed=False,
10571
),
10672
}

invokeai/backend/model_manager/load/model_loaders/flux.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from invokeai.app.services.config.config_default import get_config
1313
from invokeai.backend.flux.model import Flux
1414
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
15-
from invokeai.backend.flux.util import ae_params, configs
15+
from invokeai.backend.flux.util import ae_params, params
1616
from invokeai.backend.model_manager import (
1717
AnyModel,
1818
AnyModelConfig,
@@ -59,7 +59,7 @@ def _load_model(
5959
model_path = Path(config.path)
6060

6161
with SilenceWarnings():
62-
model = AutoEncoder(ae_params)
62+
model = AutoEncoder(ae_params[config.config_path])
6363
sd = load_file(model_path)
6464
model.load_state_dict(sd, assign=True)
6565
model.to(dtype=self._torch_dtype)
@@ -188,7 +188,7 @@ def _load_from_singlefile(
188188
model_path = Path(config.path)
189189

190190
with SilenceWarnings():
191-
model = Flux(configs[config.config_path].params)
191+
model = Flux(params[config.config_path])
192192
sd = load_file(model_path)
193193
model.load_state_dict(sd, assign=True)
194194
return model
@@ -227,7 +227,7 @@ def _load_from_singlefile(
227227

228228
with SilenceWarnings():
229229
with accelerate.init_empty_weights():
230-
model = Flux(configs[config.config_path].params)
230+
model = Flux(params[config.config_path])
231231
model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
232232
sd = load_file(model_path)
233233
model.load_state_dict(sd, assign=True)

invokeai/backend/model_manager/probe.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -329,8 +329,16 @@ def _get_checkpoint_config_path(
329329
checkpoint = ModelProbe._scan_and_load_checkpoint(model_path)
330330
state_dict = checkpoint.get("state_dict") or checkpoint
331331
if "guidance_in.out_layer.weight" in state_dict:
332+
# For flux, this is a key in invokeai.backend.flux.util.params
333+
# Due to model type and format being the discriminator for model configs this
334+
# is used rather than attempting to support flux with separate model types and format
335+
# If changed in the future, please fix me
332336
config_file = "flux-dev"
333337
else:
338+
# For flux, this is a key in invokeai.backend.flux.util.params
339+
# Due to model type and format being the discriminator for model configs this
340+
# is used rather than attempting to support flux with separate model types and format
341+
# If changed in the future, please fix me
334342
config_file = "flux-schnell"
335343
else:
336344
config_file = LEGACY_CONFIGS[base_type][variant_type]
@@ -345,7 +353,11 @@ def _get_checkpoint_config_path(
345353
)
346354
elif model_type is ModelType.VAE:
347355
config_file = (
348-
"flux/flux1-vae.yaml"
356+
# For flux, this is a key in invokeai.backend.flux.util.ae_params
357+
# Due to model type and format being the discriminator for model configs this
358+
# is used rather than attempting to support flux with separate model types and format
359+
# If changed in the future, please fix me
360+
"flux"
349361
if base_type is BaseModelType.Flux
350362
else "stable-diffusion/v1-inference.yaml"
351363
if base_type is BaseModelType.StableDiffusion1

invokeai/backend/quantization/scripts/load_flux_model_bnb_llm_int8.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from safetensors.torch import load_file, save_file
55

66
from invokeai.backend.flux.model import Flux
7-
from invokeai.backend.flux.util import configs as flux_configs
7+
from invokeai.backend.flux.util import params
88
from invokeai.backend.quantization.bnb_llm_int8 import quantize_model_llm_int8
99
from invokeai.backend.quantization.scripts.load_flux_model_bnb_nf4 import log_time
1010

@@ -22,11 +22,11 @@ def main():
2222

2323
with log_time("Intialize FLUX transformer on meta device"):
2424
# TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
25-
params = flux_configs["flux-schnell"].params
25+
p = params["flux-schnell"]
2626

2727
# Initialize the model on the "meta" device.
2828
with accelerate.init_empty_weights():
29-
model = Flux(params)
29+
model = Flux(p)
3030

3131
# TODO(ryand): We may want to add some modules to not quantize here (e.g. the proj_out layer). See the accelerate
3232
# `get_keys_to_not_convert(...)` function for a heuristic to determine which modules to not quantize.

invokeai/backend/quantization/scripts/load_flux_model_bnb_nf4.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from safetensors.torch import load_file, save_file
88

99
from invokeai.backend.flux.model import Flux
10-
from invokeai.backend.flux.util import configs as flux_configs
10+
from invokeai.backend.flux.util import params
1111
from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
1212

1313

@@ -35,11 +35,11 @@ def main():
3535
# inference_dtype = torch.bfloat16
3636
with log_time("Intialize FLUX transformer on meta device"):
3737
# TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
38-
params = flux_configs["flux-schnell"].params
38+
p = params["flux-schnell"]
3939

4040
# Initialize the model on the "meta" device.
4141
with accelerate.init_empty_weights():
42-
model = Flux(params)
42+
model = Flux(p)
4343

4444
# TODO(ryand): We may want to add some modules to not quantize here (e.g. the proj_out layer). See the accelerate
4545
# `get_keys_to_not_convert(...)` function for a heuristic to determine which modules to not quantize.

0 commit comments

Comments
 (0)