
Commit 9c4576c

Run ruff, setup initial text to image node

1 parent 1822057 commit 9c4576c

15 files changed: +290, -123 lines

invokeai/app/invocations/flux_text_encoder.py
Lines changed: 1 addition & 6 deletions

@@ -1,18 +1,13 @@
 import torch
-
-
-from einops import repeat
-from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
 from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
 from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
 from invokeai.app.invocations.model import CLIPField, T5EncoderField
 from invokeai.app.invocations.primitives import ConditioningOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
-from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.flux.modules.conditioner import HFEncoder
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
 
 
 @invocation(
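
Note: this hunk is pure ruff import cleanup. The unused einops, diffusers, and TorchDevice imports are removed, and the conditioning_data import is re-sorted below the HFEncoder import. For orientation, a minimal hedged sketch of how a FLUX text-encoder node packages its two embeddings, assuming the usual InvokeAI conditioning pattern (the function and variable names here are illustrative, not from this commit):

    # Hedged sketch, not the commit's code: package a pooled CLIP embedding and
    # the full T5 token-sequence embedding the way the imports above suggest.
    from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
        ConditioningFieldData,
        FLUXConditioningInfo,
    )

    def package_flux_conditioning(clip_embeds, t5_embeds) -> ConditioningFieldData:
        # The clip_embeds/t5_embeds attribute names match what
        # flux_text_to_image.py reads back below.
        return ConditioningFieldData(
            conditionings=[FLUXConditioningInfo(clip_embeds=clip_embeds, t5_embeds=t5_embeds)]
        )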

invokeai/app/invocations/flux_text_to_image.py
Lines changed: 38 additions & 37 deletions

@@ -1,12 +1,6 @@
-from typing import Literal
-
-import accelerate
 import torch
-from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
-from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
+from einops import rearrange, repeat
 from PIL import Image
-from safetensors.torch import load_file
-from transformers.models.auto import AutoModelForTextEncoding
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
 from invokeai.app.invocations.fields import (
@@ -20,23 +14,12 @@
 from invokeai.app.invocations.model import TransformerField, VAEField
 from invokeai.app.invocations.primitives import ImageOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.backend.quantization.bnb_nf4 import quantize_model_nf4
-from invokeai.backend.quantization.fast_quantized_diffusion_model import FastQuantizedDiffusersModel
-from invokeai.backend.quantization.fast_quantized_transformers_model import FastQuantizedTransformersModel
+from invokeai.backend.flux.model import Flux
+from invokeai.backend.flux.modules.autoencoder import AutoEncoder
+from invokeai.backend.flux.sampling import denoise, get_noise, get_schedule, unpack
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
 from invokeai.backend.util.devices import TorchDevice
 
-TFluxModelKeys = Literal["flux-schnell"]
-FLUX_MODELS: dict[TFluxModelKeys, str] = {"flux-schnell": "black-forest-labs/FLUX.1-schnell"}
-
-
-class QuantizedFluxTransformer2DModel(FastQuantizedDiffusersModel):
-    base_class = FluxTransformer2DModel
-
-
-class QuantizedModelForTextEncoding(FastQuantizedTransformersModel):
-    auto_class = AutoModelForTextEncoding
-
 
 @invocation(
     "flux_text_to_image",
@@ -78,7 +61,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
         assert isinstance(flux_conditioning, FLUXConditioningInfo)
 
         latents = self._run_diffusion(context, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds)
-        image = self._run_vae_decoding(context, latents)
+        image = self._run_vae_decoding(context, flux_ae_path, latents)
         image_dto = context.images.save(image=image)
         return ImageOutput.build(image_dto)
 
@@ -89,14 +72,40 @@ def _run_diffusion(
         t5_embeddings: torch.Tensor,
     ):
         transformer_info = context.models.load(self.transformer.transformer)
+        inference_dtype = TorchDevice.choose_torch_dtype()
+
+        # Prepare input noise.
+        # TODO(ryand): Does the seed behave the same on different devices? Should we re-implement this to always use a
+        # CPU RNG?
+        x = get_noise(
+            num_samples=1,
+            height=self.height,
+            width=self.width,
+            device=TorchDevice.choose_torch_device(),
+            dtype=inference_dtype,
+            seed=self.seed,
+        )
+
+        img, img_ids = self._prepare_latent_img_patches(x)
+
+        # HACK(ryand): Find a better way to determine if this is a schnell model or not.
+        is_schnell = "shnell" in transformer_info.config.path if transformer_info.config else ""
+        timesteps = get_schedule(
+            num_steps=self.num_steps,
+            image_seq_len=img.shape[1],
+            shift=not is_schnell,
+        )
+
+        bs, t5_seq_len, _ = t5_embeddings.shape
+        txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())
 
         # HACK(ryand): Manually empty the cache. Currently we don't check the size of the model before loading it from
         # disk. Since the transformer model is large (24GB), there's a good chance that it will OOM on 32GB RAM systems
         # if the cache is not empty.
-        # context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)
+        context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)
 
         with transformer_info as transformer:
-            assert isinstance(transformer, FluxTransformer2DModel)
+            assert isinstance(transformer, Flux)
 
             x = denoise(
                 model=transformer,
@@ -144,21 +153,13 @@ def _run_vae_decoding(
     ) -> Image.Image:
         vae_info = context.models.load(self.vae.vae)
         with vae_info as vae:
-            assert isinstance(vae, AutoencoderKL)
+            assert isinstance(vae, AutoEncoder)
+            # TODO(ryand): Test that this works with both float16 and bfloat16.
+            with torch.autocast(device_type=latents.device.type, dtype=TorchDevice.choose_torch_dtype()):
+                img = vae.decode(latents)
 
         img.clamp(-1, 1)
         img = rearrange(img[0], "c h w -> h w c")
         img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())
 
-        latents = flux_pipeline_with_vae._unpack_latents(
-            latents, self.height, self.width, flux_pipeline_with_vae.vae_scale_factor
-        )
-        latents = (
-            latents / flux_pipeline_with_vae.vae.config.scaling_factor
-        ) + flux_pipeline_with_vae.vae.config.shift_factor
-        latents = latents.to(dtype=vae.dtype)
-        image = flux_pipeline_with_vae.vae.decode(latents, return_dict=False)[0]
-        image = flux_pipeline_with_vae.image_processor.postprocess(image, output_type="pil")[0]
-
-        assert isinstance(image, Image.Image)
-        return image
+        return img_pil
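
Note: the rewritten node follows the reference FLUX sampling flow: seeded noise, 2x2 patch packing, a timestep schedule (shifted for dev, unshifted for schnell, selected by the "shnell" substring check that the HACK comment above flags, apparently a variant spelling of "schnell"), denoise, then unpack for the VAE. The make_room(24 * 2**30) call reserves 24 GiB, since 2**30 bytes is one GiB. A condensed, hedged sketch of the same flow: only get_noise, get_schedule, and the model= keyword of denoise are confirmed by this diff; the remaining denoise keywords and the patch packing mirror the reference implementation, and the guidance value is a placeholder.

    # Hedged sketch, not the node's exact code. `flux`, `clip_embeds`,
    # `t5_embeds`, etc. are placeholders supplied by the caller.
    import torch
    from einops import rearrange, repeat

    from invokeai.backend.flux.sampling import denoise, get_noise, get_schedule, unpack

    def sample_flux_latents(flux, clip_embeds, t5_embeds, height, width, num_steps, seed, device, dtype, is_schnell):
        x = get_noise(num_samples=1, height=height, width=width, device=device, dtype=dtype, seed=seed)

        # Pack 2x2 latent patches into a token sequence with per-token (b, y, x) ids,
        # standing in for self._prepare_latent_img_patches.
        b, c, h, w = x.shape
        img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
        img_ids = torch.zeros(h // 2, w // 2, 3, device=device)
        img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=device)[:, None]
        img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=device)[None, :]
        img_ids = repeat(img_ids, "h w c -> b (h w) c", b=b)

        # schnell is step-distilled and uses an unshifted schedule.
        timesteps = get_schedule(num_steps=num_steps, image_seq_len=img.shape[1], shift=not is_schnell)

        bs, t5_seq_len, _ = t5_embeds.shape
        txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=dtype, device=device)

        x = denoise(model=flux, img=img, img_ids=img_ids, txt=t5_embeds, txt_ids=txt_ids,
                    vec=clip_embeds, timesteps=timesteps, guidance=4.0)
        return unpack(x.float(), height, width)  # latents for AutoEncoder.decode()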

invokeai/app/invocations/model.py
Lines changed: 60 additions & 35 deletions

@@ -1,6 +1,6 @@
 import copy
 from time import sleep
-from typing import List, Optional, Literal, Dict
+from typing import Dict, List, Literal, Optional
 
 from pydantic import BaseModel, Field
 
@@ -12,10 +12,10 @@
     invocation_output,
 )
 from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
+from invokeai.app.services.model_records import ModelRecordChanges
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.shared.models import FreeUConfig
-from invokeai.app.services.model_records import ModelRecordChanges
-from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType, SubModelType, ModelFormat
+from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
 
 
 class ModelIdentifierField(BaseModel):
@@ -132,31 +132,22 @@ def invoke(self, context: InvocationContext) -> ModelIdentifierOutput:
 
         return ModelIdentifierOutput(model=self.model)
 
-T5_ENCODER_OPTIONS = Literal["base", "16b_quantized", "8b_quantized"]
+
+T5_ENCODER_OPTIONS = Literal["base", "8b_quantized"]
 T5_ENCODER_MAP: Dict[str, Dict[str, str]] = {
     "base": {
-        "text_encoder_repo": "black-forest-labs/FLUX.1-schnell::text_encoder_2",
-        "tokenizer_repo": "black-forest-labs/FLUX.1-schnell::tokenizer_2",
-        "text_encoder_name": "FLUX.1-schnell_text_encoder_2",
-        "tokenizer_name": "FLUX.1-schnell_tokenizer_2",
+        "repo": "invokeai/flux_dev::t5_xxl_encoder/base",
+        "name": "t5_base_encoder",
         "format": ModelFormat.T5Encoder,
     },
     "8b_quantized": {
-        "text_encoder_repo": "hf_repo1",
-        "tokenizer_repo": "hf_repo1",
-        "text_encoder_name": "hf_repo1",
-        "tokenizer_name": "hf_repo1",
-        "format": ModelFormat.T5Encoder8b,
-    },
-    "4b_quantized": {
-        "text_encoder_repo": "hf_repo2",
-        "tokenizer_repo": "hf_repo2",
-        "text_encoder_name": "hf_repo2",
-        "tokenizer_name": "hf_repo2",
-        "format": ModelFormat.T5Encoder8b,
+        "repo": "invokeai/flux_dev::t5_xxl_encoder/8b_quantized",
+        "name": "t5_8b_quantized_encoder",
+        "format": ModelFormat.T5Encoder,
     },
 }
 
+
 @invocation_output("flux_model_loader_output")
 class FluxModelLoaderOutput(BaseInvocationOutput):
     """Flux base model loader output"""
@@ -176,7 +167,7 @@ class FluxModelLoaderInvocation(BaseInvocation):
         ui_type=UIType.FluxMainModel,
         input=Input.Direct,
     )
-    
+
     t5_encoder: T5_ENCODER_OPTIONS = InputField(description="The T5 Encoder model to use.")
 
     def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
@@ -189,7 +180,15 @@ def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
         tokenizer2 = self._get_model(context, SubModelType.Tokenizer2)
         clip_encoder = self._get_model(context, SubModelType.TextEncoder)
         t5_encoder = self._get_model(context, SubModelType.TextEncoder2)
-        vae = self._install_model(context, SubModelType.VAE, "FLUX.1-schnell_ae", "black-forest-labs/FLUX.1-schnell::ae.safetensors", ModelFormat.Checkpoint, ModelType.VAE, BaseModelType.Flux)
+        vae = self._install_model(
+            context,
+            SubModelType.VAE,
+            "FLUX.1-schnell_ae",
+            "black-forest-labs/FLUX.1-schnell::ae.safetensors",
+            ModelFormat.Checkpoint,
+            ModelType.VAE,
+            BaseModelType.Flux,
+        )
 
         return FluxModelLoaderOutput(
             transformer=TransformerField(transformer=transformer),
@@ -198,33 +197,59 @@ def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
             vae=VAEField(vae=vae),
         )
 
-    def _get_model(self, context: InvocationContext, submodel:SubModelType) -> ModelIdentifierField:
-        match(submodel):
+    def _get_model(self, context: InvocationContext, submodel: SubModelType) -> ModelIdentifierField:
+        match submodel:
             case SubModelType.Transformer:
                 return self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
             case submodel if submodel in [SubModelType.Tokenizer, SubModelType.TextEncoder]:
-                return self._install_model(context, submodel, "clip-vit-large-patch14", "openai/clip-vit-large-patch14", ModelFormat.Diffusers, ModelType.CLIPEmbed, BaseModelType.Any)
-            case SubModelType.TextEncoder2:
-                return self._install_model(context, submodel, T5_ENCODER_MAP[self.t5_encoder]["text_encoder_name"], T5_ENCODER_MAP[self.t5_encoder]["text_encoder_repo"], ModelFormat(T5_ENCODER_MAP[self.t5_encoder]["format"]), ModelType.T5Encoder, BaseModelType.Any)
-            case SubModelType.Tokenizer2:
-                return self._install_model(context, submodel, T5_ENCODER_MAP[self.t5_encoder]["tokenizer_name"], T5_ENCODER_MAP[self.t5_encoder]["tokenizer_repo"], ModelFormat(T5_ENCODER_MAP[self.t5_encoder]["format"]), ModelType.T5Encoder, BaseModelType.Any)
+                return self._install_model(
+                    context,
+                    submodel,
+                    "clip-vit-large-patch14",
+                    "openai/clip-vit-large-patch14",
+                    ModelFormat.Diffusers,
+                    ModelType.CLIPEmbed,
+                    BaseModelType.Any,
+                )
+            case submodel if submodel in [SubModelType.Tokenizer2, SubModelType.TextEncoder2]:
+                return self._install_model(
+                    context,
+                    submodel,
+                    T5_ENCODER_MAP[self.t5_encoder]["name"],
+                    T5_ENCODER_MAP[self.t5_encoder]["repo"],
+                    ModelFormat(T5_ENCODER_MAP[self.t5_encoder]["format"]),
+                    ModelType.T5Encoder,
+                    BaseModelType.Any,
+                )
             case _:
-                raise Exception(f"{submodel.value} is not a supported submodule for a flux model")
-
-    def _install_model(self, context: InvocationContext, submodel:SubModelType, name: str, repo_id: str, format: ModelFormat, type: ModelType, base: BaseModelType):
-        if (models := context.models.search_by_attrs(name=name, base=base, type=type)):
+                raise Exception(f"{submodel.value} is not a supported submodule for a flux model")
+
+    def _install_model(
+        self,
+        context: InvocationContext,
+        submodel: SubModelType,
+        name: str,
+        repo_id: str,
+        format: ModelFormat,
+        type: ModelType,
+        base: BaseModelType,
+    ):
+        if models := context.models.search_by_attrs(name=name, base=base, type=type):
             if len(models) != 1:
                 raise Exception(f"Multiple models detected for selected model with name {name}")
             return ModelIdentifierField.from_config(models[0]).model_copy(update={"submodel_type": submodel})
         else:
             model_path = context.models.download_and_cache_model(repo_id)
-            config = ModelRecordChanges(name = name, base = base, type=type, format=format)
+            config = ModelRecordChanges(name=name, base=base, type=type, format=format)
             model_install_job = context.models.import_local_model(model_path=model_path, config=config)
             while not model_install_job.in_terminal_state:
                 sleep(0.01)
             if not model_install_job.config_out:
                 raise Exception(f"Failed to install {name}")
-            return ModelIdentifierField.from_config(model_install_job.config_out).model_copy(update={"submodel_type": submodel})
+            return ModelIdentifierField.from_config(model_install_job.config_out).model_copy(
+                update={"submodel_type": submodel}
+            )
+
 
 @invocation(
     "main_model_loader",

invokeai/app/services/model_records/model_records_sql.py
Lines changed: 1 addition & 1 deletion

@@ -301,7 +301,7 @@ def search_by_attr(
         for row in result:
             try:
                 model_config = ModelConfigFactory.make_config(json.loads(row[0]), timestamp=row[1])
-            except pydantic.ValidationError as e:
+            except pydantic.ValidationError:
                 # We catch this error so that the app can still run if there are invalid model configs in the database.
                 # One reason that an invalid model config might be in the database is if someone had to rollback from a
                 # newer version of the app that added a new model type.
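
Note: dropping the unused `as e` binding is ruff's F841 fix (local variable assigned but never used). A self-contained illustration of the keep-going pattern the comment describes, with a hypothetical pydantic model standing in for real model configs:

    # Runnable illustration; DummyConfig is hypothetical, not from InvokeAI.
    import pydantic

    class DummyConfig(pydantic.BaseModel):
        x: int

    rows = ['{"x": 1}', '{"x": "not-an-int"}']
    configs = []
    for raw in rows:
        try:
            configs.append(DummyConfig.model_validate_json(raw))
        except pydantic.ValidationError:
            continue  # skip the invalid record so iteration continues
    assert len(configs) == 1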

invokeai/app/services/shared/invocation_context.py
Lines changed: 8 additions & 6 deletions

@@ -465,18 +465,20 @@ def download_and_cache_model(
         return self._services.model_manager.install.download_and_cache_model(source=source)
 
     def import_local_model(
-            self,
-            model_path: Path,
-            config: Optional[ModelRecordChanges] = None,
-            access_token: Optional[str] = None,
-            inplace: Optional[bool] = False,
+        self,
+        model_path: Path,
+        config: Optional[ModelRecordChanges] = None,
+        access_token: Optional[str] = None,
+        inplace: Optional[bool] = False,
     ):
         """
         TODO: Fill out description of this method
         """
         if not model_path.exists():
             raise Exception("Models provided to import_local_model must already exist on disk")
-        return self._services.model_manager.install.heuristic_import(str(model_path), config=config, access_token=access_token, inplace=inplace)
+        return self._services.model_manager.install.heuristic_import(
+            str(model_path), config=config, access_token=access_token, inplace=inplace
+        )
 
     def load_local_model(
         self,
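
Note: formatting only; the parameters are re-indented and the heuristic_import call is wrapped. A hedged usage sketch of the helper; the path, model name, and ModelRecordChanges fields are invented for illustration, and `context` is whatever InvocationContext the caller holds:

    # Hedged usage sketch; values are hypothetical.
    from pathlib import Path

    from invokeai.app.services.model_records import ModelRecordChanges

    job = context.import_local_model(
        model_path=Path("/models/t5_base_encoder"),
        config=ModelRecordChanges(name="t5_base_encoder"),
        inplace=False,
    )
    while not job.in_terminal_state:  # same polling pattern _install_model uses above
        ...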

invokeai/backend/flux/math.py
Lines changed: 1 addition & 1 deletion

@@ -27,4 +27,4 @@ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
     xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
     xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
     xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
-    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
\ No newline at end of file
+    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
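
Note: the only change here is the newline added at end of file (ruff's W292 fix); the same fix closes the model.py and autoencoder.py hunks below. Since apply_rope is shown in full, a self-contained check of what it computes may help: assuming the usual layout where freqs_cis packs one 2x2 rotation matrix per channel pair, the two indexed columns rotate each (even, odd) pair.

    # Self-contained check of the apply_rope arithmetic above.
    import math

    import torch

    theta = math.pi / 2
    # One rotation matrix; its columns are (cos, sin) and (-sin, cos).
    freqs_cis = torch.tensor([[[math.cos(theta), -math.sin(theta)],
                               [math.sin(theta), math.cos(theta)]]])
    xq_ = torch.tensor([1.0, 0.0]).reshape(-1, 1, 2)  # one (even, odd) channel pair
    out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    assert torch.allclose(out, torch.tensor([[0.0, 1.0]]))  # (1, 0) rotated 90 degrees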

invokeai/backend/flux/model.py
Lines changed: 11 additions & 7 deletions

@@ -3,9 +3,15 @@
 import torch
 from torch import Tensor, nn
 
-from invokeai.backend.flux.modules.layers import (DoubleStreamBlock, EmbedND, LastLayer,
-                                                  MLPEmbedder, SingleStreamBlock,
-                                                  timestep_embedding)
+from invokeai.backend.flux.modules.layers import (
+    DoubleStreamBlock,
+    EmbedND,
+    LastLayer,
+    MLPEmbedder,
+    SingleStreamBlock,
+    timestep_embedding,
+)
+
 
 @dataclass
 class FluxParams:
@@ -35,9 +41,7 @@ def __init__(self, params: FluxParams):
         self.in_channels = params.in_channels
         self.out_channels = self.in_channels
         if params.hidden_size % params.num_heads != 0:
-            raise ValueError(
-                f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
-            )
+            raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
         pe_dim = params.hidden_size // params.num_heads
         if sum(params.axes_dim) != pe_dim:
             raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
@@ -108,4 +112,4 @@ def forward(
         img = img[:, txt.shape[1] :, ...]
 
         img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
-        return img
\ No newline at end of file
+        return img
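
Note: collapsing the ValueError onto one line is formatting only, but the two checks encode a real architectural constraint: the per-head dimension must divide evenly and be fully covered by the per-axis rotary-embedding dimensions. A worked example using FLUX.1's published configuration:

    # Worked example of the __init__ checks above (FLUX.1 uses hidden_size 3072,
    # 24 heads, and axes_dim (16, 56, 56)).
    hidden_size, num_heads = 3072, 24
    assert hidden_size % num_heads == 0
    pe_dim = hidden_size // num_heads  # 128 rotary channels per head
    axes_dim = (16, 56, 56)            # one entry per positional-id axis
    assert sum(axes_dim) == pe_dim     # the per-axis RoPE dims must fill pe_dim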

invokeai/backend/flux/modules/autoencoder.py
Lines changed: 1 addition & 1 deletion

@@ -309,4 +309,4 @@ def decode(self, z: Tensor) -> Tensor:
         return self.decoder(z)
 
     def forward(self, x: Tensor) -> Tensor:
-        return self.decode(self.encode(x))
\ No newline at end of file
+        return self.decode(self.encode(x))
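
Note: as in math.py above, the only change is the end-of-file newline; forward is unchanged and remains the full encode-then-decode round trip.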
