Commit f4f5c46

Add backend functions and classes for the FLUX implementation; update how the FLUX encoders/tokenizers are loaded for prompt encoding; update how the FLUX VAE is loaded
1 parent 53052cf commit f4f5c46

File tree: 19 files changed (+1340, -197 lines)

invokeai/app/invocations/flux_text_encoder.py

Lines changed: 12 additions & 19 deletions
@@ -1,6 +1,9 @@
 import torch
+
+
+from einops import repeat
 from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
-from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
+from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
 from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
@@ -9,6 +12,7 @@
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
 from invokeai.backend.util.devices import TorchDevice
+from invokeai.backend.flux.modules.conditioner import HFEncoder
 
 
 @invocation(
@@ -69,26 +73,15 @@ def _encode_prompt(self, context: InvocationContext) -> tuple[torch.Tensor, torch.Tensor]:
         assert isinstance(clip_text_encoder, CLIPTextModel)
         assert isinstance(t5_text_encoder, T5EncoderModel)
         assert isinstance(clip_tokenizer, CLIPTokenizer)
-        assert isinstance(t5_tokenizer, T5TokenizerFast)
+        assert isinstance(t5_tokenizer, T5Tokenizer)
+
+        clip_encoder = HFEncoder(clip_text_encoder, clip_tokenizer, True, 77)
+        t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False, max_seq_len)
 
-        pipeline = FluxPipeline(
-            scheduler=None,
-            vae=None,
-            text_encoder=clip_text_encoder,
-            tokenizer=clip_tokenizer,
-            text_encoder_2=t5_text_encoder,
-            tokenizer_2=t5_tokenizer,
-            transformer=None,
-        )
+        prompt = [self.positive_prompt]
+        prompt_embeds = t5_encoder(prompt)
 
-        # prompt_embeds: T5 embeddings
-        # pooled_prompt_embeds: CLIP embeddings
-        prompt_embeds, pooled_prompt_embeds, _ = pipeline.encode_prompt(
-            prompt=self.positive_prompt,
-            prompt_2=self.positive_prompt,
-            device=TorchDevice.choose_torch_device(),
-            max_sequence_length=max_seq_len,
-        )
+        pooled_prompt_embeds = clip_encoder(prompt)
 
         assert isinstance(prompt_embeds, torch.Tensor)
         assert isinstance(pooled_prompt_embeds, torch.Tensor)
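The new encoding path drops the full diffusers FluxPipeline and instead wraps each text model in an HFEncoder: the T5 encoder produces the per-token prompt embeddings and the CLIP encoder produces the pooled embedding. The actual HFEncoder lives in invokeai/backend/flux/modules/conditioner.py and is not shown in this diff; the boolean argument presumably flags CLIP vs. T5 behaviour and the integer is the maximum token length. The sketch below is an assumption about what such a wrapper typically does (tokenize to a fixed length, run the encoder, return either the pooled CLIP output or the T5 token sequence), not the committed implementation.

# Hedged sketch of an HFEncoder-like wrapper (illustrative; not the code added by this commit).
import torch
from torch import nn


class PromptEncoderSketch(nn.Module):  # hypothetical name, for illustration only
    def __init__(self, encoder: nn.Module, tokenizer, is_clip: bool, max_length: int):
        super().__init__()
        self.encoder = encoder
        self.tokenizer = tokenizer
        self.max_length = max_length
        # CLIP conditioning uses the pooled embedding; T5 conditioning uses the full token sequence.
        self.output_key = "pooler_output" if is_clip else "last_hidden_state"

    def forward(self, prompts: list[str]) -> torch.Tensor:
        batch = self.tokenizer(
            prompts,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )
        device = next(self.encoder.parameters()).device
        outputs = self.encoder(input_ids=batch["input_ids"].to(device))
        return outputs[self.output_key]

Under these assumptions, calling the T5 wrapper on a one-element prompt list would yield a (batch, max_seq_len, hidden) tensor and the CLIP wrapper a (batch, hidden) pooled tensor, which lines up with the prompt_embeds and pooled_prompt_embeds asserts above.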

invokeai/app/invocations/flux_text_to_image.py

Lines changed: 1 addition & 3 deletions
@@ -85,17 +85,15 @@ def _run_diffusion(
         clip_embeddings: torch.Tensor,
         t5_embeddings: torch.Tensor,
     ):
-        scheduler_info = context.models.load(self.transformer.scheduler)
         transformer_info = context.models.load(self.transformer.transformer)
 
         # HACK(ryand): Manually empty the cache. Currently we don't check the size of the model before loading it from
         # disk. Since the transformer model is large (24GB), there's a good chance that it will OOM on 32GB RAM systems
         # if the cache is not empty.
         # context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)
 
-        with transformer_info as transformer, scheduler_info as scheduler:
+        with transformer_info as transformer:
             assert isinstance(transformer, FluxTransformer2DModel)
-            assert isinstance(scheduler, FlowMatchEulerDiscreteScheduler)
 
             flux_pipeline_with_transformer = FluxPipeline(
                 scheduler=scheduler,
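With the scheduler no longer loaded from the model record, the surrounding code (only partially visible in this hunk) still builds a FluxPipeline around the loaded transformer. A minimal sketch of such a transformer-only construction is shown below; the standalone FlowMatchEulerDiscreteScheduler and the None sub-models are assumptions for illustration, not necessarily what the rest of the file does.

# Sketch: a FluxPipeline that only carries the denoising transformer.
# Assumes `transformer` is the FluxTransformer2DModel obtained from transformer_info above.
from diffusers import FlowMatchEulerDiscreteScheduler, FluxPipeline

scheduler = FlowMatchEulerDiscreteScheduler()  # default config; an assumption for illustration
flux_pipeline_with_transformer = FluxPipeline(
    scheduler=scheduler,
    vae=None,
    text_encoder=None,
    tokenizer=None,
    text_encoder_2=None,
    tokenizer_2=None,
    transformer=transformer,
)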

invokeai/app/invocations/model.py

Lines changed: 66 additions & 15 deletions
@@ -1,5 +1,6 @@
 import copy
-from typing import List, Optional
+from time import sleep
+from typing import List, Optional, Literal, Dict
 
 from pydantic import BaseModel, Field
 
@@ -13,7 +14,8 @@
 from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.shared.models import FreeUConfig
-from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType, SubModelType
+from invokeai.app.services.model_records import ModelRecordChanges
+from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType, SubModelType, ModelFormat
 
 
 class ModelIdentifierField(BaseModel):
@@ -62,7 +64,6 @@ class CLIPField(BaseModel):
 
 class TransformerField(BaseModel):
     transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
-    scheduler: ModelIdentifierField = Field(description="Info to load scheduler submodel")
 
 
 class T5EncoderField(BaseModel):
@@ -131,6 +132,30 @@ def invoke(self, context: InvocationContext) -> ModelIdentifierOutput:
 
         return ModelIdentifierOutput(model=self.model)
 
+T5_ENCODER_OPTIONS = Literal["base", "16b_quantized", "8b_quantized"]
+T5_ENCODER_MAP: Dict[str, Dict[str, str]] = {
+    "base": {
+        "text_encoder_repo": "black-forest-labs/FLUX.1-schnell::text_encoder_2",
+        "tokenizer_repo": "black-forest-labs/FLUX.1-schnell::tokenizer_2",
+        "text_encoder_name": "FLUX.1-schnell_text_encoder_2",
+        "tokenizer_name": "FLUX.1-schnell_tokenizer_2",
+        "format": ModelFormat.T5Encoder,
+    },
+    "8b_quantized": {
+        "text_encoder_repo": "hf_repo1",
+        "tokenizer_repo": "hf_repo1",
+        "text_encoder_name": "hf_repo1",
+        "tokenizer_name": "hf_repo1",
+        "format": ModelFormat.T5Encoder8b,
+    },
+    "4b_quantized": {
+        "text_encoder_repo": "hf_repo2",
+        "tokenizer_repo": "hf_repo2",
+        "text_encoder_name": "hf_repo2",
+        "tokenizer_name": "hf_repo2",
+        "format": ModelFormat.T5Encoder8b,
+    },
+}
 
 @invocation_output("flux_model_loader_output")
 class FluxModelLoaderOutput(BaseInvocationOutput):
@@ -151,29 +176,55 @@ class FluxModelLoaderInvocation(BaseInvocation):
         ui_type=UIType.FluxMainModel,
         input=Input.Direct,
     )
+
+    t5_encoder: T5_ENCODER_OPTIONS = InputField(description="The T5 Encoder model to use.")
 
     def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
         model_key = self.model.key
 
-        # TODO: not found exceptions
         if not context.models.exists(model_key):
            raise Exception(f"Unknown model: {model_key}")
-
-        transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
-        scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler})
-        tokenizer = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
-        text_encoder = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
-        tokenizer2 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
-        text_encoder2 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
-        vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
+        transformer = self._get_model(context, SubModelType.Transformer)
+        tokenizer = self._get_model(context, SubModelType.Tokenizer)
+        tokenizer2 = self._get_model(context, SubModelType.Tokenizer2)
+        clip_encoder = self._get_model(context, SubModelType.TextEncoder)
+        t5_encoder = self._get_model(context, SubModelType.TextEncoder2)
+        vae = self._install_model(context, SubModelType.VAE, "FLUX.1-schnell_ae", "black-forest-labs/FLUX.1-schnell::ae.safetensors", ModelFormat.Checkpoint, ModelType.VAE, BaseModelType.Flux)
 
         return FluxModelLoaderOutput(
-            transformer=TransformerField(transformer=transformer, scheduler=scheduler),
-            clip=CLIPField(tokenizer=tokenizer, text_encoder=text_encoder, loras=[], skipped_layers=0),
-            t5Encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=text_encoder2),
+            transformer=TransformerField(transformer=transformer),
+            clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
+            t5Encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
             vae=VAEField(vae=vae),
         )
 
+    def _get_model(self, context: InvocationContext, submodel:SubModelType) -> ModelIdentifierField:
+        match(submodel):
+            case SubModelType.Transformer:
+                return self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
+            case submodel if submodel in [SubModelType.Tokenizer, SubModelType.TextEncoder]:
+                return self._install_model(context, submodel, "clip-vit-large-patch14", "openai/clip-vit-large-patch14", ModelFormat.Diffusers, ModelType.CLIPEmbed, BaseModelType.Any)
+            case SubModelType.TextEncoder2:
+                return self._install_model(context, submodel, T5_ENCODER_MAP[self.t5_encoder]["text_encoder_name"], T5_ENCODER_MAP[self.t5_encoder]["text_encoder_repo"], ModelFormat(T5_ENCODER_MAP[self.t5_encoder]["format"]), ModelType.T5Encoder, BaseModelType.Any)
+            case SubModelType.Tokenizer2:
+                return self._install_model(context, submodel, T5_ENCODER_MAP[self.t5_encoder]["tokenizer_name"], T5_ENCODER_MAP[self.t5_encoder]["tokenizer_repo"], ModelFormat(T5_ENCODER_MAP[self.t5_encoder]["format"]), ModelType.T5Encoder, BaseModelType.Any)
            case _:
                raise Exception(f"{submodel.value} is not a supported submodule for a flux model")
+
+    def _install_model(self, context: InvocationContext, submodel:SubModelType, name: str, repo_id: str, format: ModelFormat, type: ModelType, base: BaseModelType):
+        if (models := context.models.search_by_attrs(name=name, base=base, type=type)):
+            if len(models) != 1:
+                raise Exception(f"Multiple models detected for selected model with name {name}")
+            return ModelIdentifierField.from_config(models[0]).model_copy(update={"submodel_type": submodel})
+        else:
+            model_path = context.models.download_and_cache_model(repo_id)
+            config = ModelRecordChanges(name = name, base = base, type=type, format=format)
+            model_install_job = context.models.import_local_model(model_path=model_path, config=config)
+            while not model_install_job.in_terminal_state:
+                sleep(0.01)
+            if not model_install_job.config_out:
+                raise Exception(f"Failed to install {name}")
+            return ModelIdentifierField.from_config(model_install_job.config_out).model_copy(update={"submodel_type": submodel})
 
 @invocation(
     "main_model_loader",

invokeai/app/services/model_records/model_records_base.py

Lines changed: 1 addition & 0 deletions
@@ -77,6 +77,7 @@ class ModelRecordChanges(BaseModelExcludeNull):
     type: Optional[ModelType] = Field(description="Type of model", default=None)
     key: Optional[str] = Field(description="Database ID for this model", default=None)
     hash: Optional[str] = Field(description="hash of model file", default=None)
+    format: Optional[str] = Field(description="format of model file", default=None)
     trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
     default_settings: Optional[MainModelDefaultSettings | ControlAdapterDefaultSettings] = Field(
         description="Default settings for this model", default=None

invokeai/app/services/model_records/model_records_sql.py

Lines changed: 1 addition & 1 deletion
@@ -301,7 +301,7 @@ def search_by_attr(
         for row in result:
             try:
                 model_config = ModelConfigFactory.make_config(json.loads(row[0]), timestamp=row[1])
-            except pydantic.ValidationError:
+            except pydantic.ValidationError as e:
                 # We catch this error so that the app can still run if there are invalid model configs in the database.
                 # One reason that an invalid model config might be in the database is if someone had to rollback from a
                 # newer version of the app that added a new model type.

invokeai/app/services/shared/invocation_context.py

Lines changed: 15 additions & 0 deletions
@@ -13,6 +13,7 @@
 from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
 from invokeai.app.services.images.images_common import ImageDTO
 from invokeai.app.services.invocation_services import InvocationServices
+from invokeai.app.services.model_records import ModelRecordChanges
 from invokeai.app.services.model_records.model_records_base import UnknownModelException
 from invokeai.app.util.step_callback import stable_diffusion_step_callback
 from invokeai.backend.model_manager.config import (
@@ -463,6 +464,20 @@ def download_and_cache_model(
         """
         return self._services.model_manager.install.download_and_cache_model(source=source)
 
+    def import_local_model(
+        self,
+        model_path: Path,
+        config: Optional[ModelRecordChanges] = None,
+        access_token: Optional[str] = None,
+        inplace: Optional[bool] = False,
+    ):
+        """
+        TODO: Fill out description of this method
+        """
+        if not model_path.exists():
+            raise Exception("Models provided to import_local_model must already exist on disk")
+        return self._services.model_manager.install.heuristic_import(str(model_path), config=config, access_token=access_token, inplace=inplace)
+
     def load_local_model(
         self,
         model_path: Path,
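A hedged usage sketch of the new method, mirroring how FluxModelLoaderInvocation._install_model calls it in the model.py diff above. The local path is hypothetical, and polling on in_terminal_state / config_out follows the contract shown in that caller rather than anything documented here.

# Sketch: registering an already-downloaded file with the model manager (illustrative only).
from pathlib import Path
from time import sleep

from invokeai.app.services.model_records import ModelRecordChanges
from invokeai.backend.model_manager.config import BaseModelType, ModelFormat, ModelType

model_path = Path("/models/downloads/ae.safetensors")  # hypothetical path; must already exist on disk
config = ModelRecordChanges(name="FLUX.1-schnell_ae", base=BaseModelType.Flux, type=ModelType.VAE, format=ModelFormat.Checkpoint)

job = context.models.import_local_model(model_path=model_path, config=config)
while not job.in_terminal_state:  # heuristic_import runs asynchronously; block until it settles
    sleep(0.01)
if not job.config_out:
    raise Exception("Import failed")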

invokeai/backend/flux/math.py

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
+import torch
+from einops import rearrange
+from torch import Tensor
+
+
+def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
+    q, k = apply_rope(q, k, pe)
+
+    x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
+    x = rearrange(x, "B H L D -> B L (H D)")
+
+    return x
+
+
+def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
+    assert dim % 2 == 0
+    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
+    omega = 1.0 / (theta**scale)
+    out = torch.einsum("...n,d->...nd", pos, omega)
+    out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
+    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
+    return out.float()
+
+
+def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
+    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
+    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
+    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
+    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
+    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
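For orientation, a small shape-check of these helpers: rope() builds a table of 2x2 rotation matrices per position and per pair of channels, and attention() rotates the query/key heads with it before scaled dot-product attention. The shapes below are assumptions chosen to satisfy the broadcasting in apply_rope (the model's EmbedND is expected to supply the extra head axis); they are illustrative, not taken from the commit.

# Shape-check sketch for rope()/attention() (illustrative; assumed shapes).
import torch

from invokeai.backend.flux.math import attention, rope

B, H, L, D = 1, 4, 8, 16                             # batch, heads, tokens, head dim (D must be even)
pos = torch.arange(L, dtype=torch.float64)[None, :]  # (1, 8) positions; float64 matches omega's dtype
pe = rope(pos, dim=D, theta=10_000)                  # (1, 8, 8, 2, 2): a 2x2 rotation per position and channel pair
pe = pe.unsqueeze(1)                                 # (1, 1, 8, 8, 2, 2): extra axis so it broadcasts over heads

q = torch.randn(B, H, L, D)
k = torch.randn(B, H, L, D)
v = torch.randn(B, H, L, D)
out = attention(q, k, v, pe)                         # (1, 8, 64): heads merged back into the channel dim
print(out.shape)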

invokeai/backend/flux/model.py

Lines changed: 111 additions & 0 deletions
@@ -0,0 +1,111 @@
+from dataclasses import dataclass
+
+import torch
+from torch import Tensor, nn
+
+from invokeai.backend.flux.modules.layers import (DoubleStreamBlock, EmbedND, LastLayer,
+                                                  MLPEmbedder, SingleStreamBlock,
+                                                  timestep_embedding)
+
+@dataclass
+class FluxParams:
+    in_channels: int
+    vec_in_dim: int
+    context_in_dim: int
+    hidden_size: int
+    mlp_ratio: float
+    num_heads: int
+    depth: int
+    depth_single_blocks: int
+    axes_dim: list[int]
+    theta: int
+    qkv_bias: bool
+    guidance_embed: bool
+
+
+class Flux(nn.Module):
+    """
+    Transformer model for flow matching on sequences.
+    """
+
+    def __init__(self, params: FluxParams):
+        super().__init__()
+
+        self.params = params
+        self.in_channels = params.in_channels
+        self.out_channels = self.in_channels
+        if params.hidden_size % params.num_heads != 0:
+            raise ValueError(
+                f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
+            )
+        pe_dim = params.hidden_size // params.num_heads
+        if sum(params.axes_dim) != pe_dim:
+            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
+        self.hidden_size = params.hidden_size
+        self.num_heads = params.num_heads
+        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
+        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
+        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
+        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
+        self.guidance_in = (
+            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
+        )
+        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
+
+        self.double_blocks = nn.ModuleList(
+            [
+                DoubleStreamBlock(
+                    self.hidden_size,
+                    self.num_heads,
+                    mlp_ratio=params.mlp_ratio,
+                    qkv_bias=params.qkv_bias,
+                )
+                for _ in range(params.depth)
+            ]
+        )
+
+        self.single_blocks = nn.ModuleList(
+            [
+                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
+                for _ in range(params.depth_single_blocks)
+            ]
+        )
+
+        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
+
+    def forward(
+        self,
+        img: Tensor,
+        img_ids: Tensor,
+        txt: Tensor,
+        txt_ids: Tensor,
+        timesteps: Tensor,
+        y: Tensor,
+        guidance: Tensor | None = None,
+    ) -> Tensor:
+        if img.ndim != 3 or txt.ndim != 3:
+            raise ValueError("Input img and txt tensors must have 3 dimensions.")
+
+        # running on sequences img
+        img = self.img_in(img)
+        vec = self.time_in(timestep_embedding(timesteps, 256))
+        if self.params.guidance_embed:
+            if guidance is None:
+                raise ValueError("Didn't get guidance strength for guidance distilled model.")
+            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
+        vec = vec + self.vector_in(y)
+        txt = self.txt_in(txt)
+
+        ids = torch.cat((txt_ids, img_ids), dim=1)
+        pe = self.pe_embedder(ids)
+
+        for block in self.double_blocks:
+            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
+
+        img = torch.cat((txt, img), 1)
+        for block in self.single_blocks:
+            img = block(img, vec=vec, pe=pe)
+        img = img[:, txt.shape[1] :, ...]
+
+        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
+        return img
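For scale, FluxParams would be populated roughly as follows for a schnell-sized model. These values are taken from the publicly released FLUX.1 configuration and are an assumption here, not something defined by this commit.

# Assumed FLUX.1-schnell-sized configuration (illustrative; verify against the published config).
from invokeai.backend.flux.model import Flux, FluxParams

params = FluxParams(
    in_channels=64,           # 16 latent channels packed into 2x2 patches
    vec_in_dim=768,           # pooled CLIP embedding width
    context_in_dim=4096,      # T5-XXL hidden width
    hidden_size=3072,
    mlp_ratio=4.0,
    num_heads=24,             # 3072 / 24 = 128 = sum(axes_dim)
    depth=19,                 # double-stream blocks
    depth_single_blocks=38,   # single-stream blocks
    axes_dim=[16, 56, 56],
    theta=10_000,
    qkv_bias=True,
    guidance_embed=False,     # schnell is not guidance-distilled; the dev variant sets this to True
)
model = Flux(params)          # note: at this scale, instantiation allocates the full ~12B-parameter module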
