
Commit f0baf88

Split a FluxTextEncoderInvocation out from the FluxTextToImageInvocation. This has the advantage that we benefit from automatic caching when the prompt isn't changed.
1 parent a8a2fc1 commit f0baf88
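Why the split helps: the commit message's "automatic caching" refers to invocation-level caching, where a node's output is reused whenever the node runs again with identical inputs. With text encoding in its own node, its inputs are just the model, the quantization flag, and the prompt, so edits to seed, steps, or image size no longer force a re-run of the T5/CLIP encoders. A minimal sketch of that idea, assuming a content-addressed cache keyed on node type and field values (an illustration, not InvokeAI's actual cache implementation):

import hashlib
import json
from typing import Any, Callable

_cache: dict[str, Any] = {}

def cached_invoke(node_type: str, inputs: dict[str, Any], compute: Callable[[], Any]) -> Any:
    # Key on the node's type and field values; any change invalidates the entry.
    key = hashlib.sha256(
        json.dumps({"type": node_type, "inputs": inputs}, sort_keys=True).encode()
    ).hexdigest()
    if key not in _cache:
        _cache[key] = compute()
    return _cache[key]

# Changing only the seed re-keys the image node, not the encoder node, so the
# cached conditioning is returned without re-running the text encoders.
cond = cached_invoke("flux_text_encoder", {"positive_prompt": "a cat"}, lambda: "embeddings")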

File tree

3 files changed: +165 -103 lines changed
Lines changed: 135 additions & 0 deletions
@@ -0,0 +1,135 @@
+from pathlib import Path
+
+import torch
+from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
+from optimum.quanto import qfloat8
+from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
+
+from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
+from invokeai.app.invocations.fields import InputField
+from invokeai.app.invocations.flux_text_to_image import FLUX_MODELS, QuantizedModelForTextEncoding, TFluxModelKeys
+from invokeai.app.invocations.primitives import ConditioningOutput
+from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
+from invokeai.backend.util.devices import TorchDevice
+
+
+@invocation(
+    "flux_text_encoder",
+    title="FLUX Text Encoding",
+    tags=["image"],
+    category="image",
+    version="1.0.0",
+)
+class FluxTextEncoderInvocation(BaseInvocation):
+    model: TFluxModelKeys = InputField(description="The FLUX model to use for text-to-image generation.")
+    use_8bit: bool = InputField(
+        default=False, description="Whether to quantize the transformer model to 8-bit precision."
+    )
+    positive_prompt: str = InputField(description="Positive prompt for text-to-image generation.")
+
+    # TODO(ryand): Should we create a new return type for this invocation? This ConditioningOutput is clearly not
+    # compatible with other ConditioningOutputs.
+    @torch.no_grad()
+    def invoke(self, context: InvocationContext) -> ConditioningOutput:
+        model_path = context.models.download_and_cache_model(FLUX_MODELS[self.model])
+
+        t5_embeddings, clip_embeddings = self._encode_prompt(context, model_path)
+        conditioning_data = ConditioningFieldData(
+            conditionings=[FLUXConditioningInfo(clip_embeds=clip_embeddings, t5_embeds=t5_embeddings)]
+        )
+
+        conditioning_name = context.conditioning.save(conditioning_data)
+        return ConditioningOutput.build(conditioning_name)
+
+    def _encode_prompt(self, context: InvocationContext, flux_model_dir: Path) -> tuple[torch.Tensor, torch.Tensor]:
+        # Determine the T5 max sequence length based on the model.
+        if self.model == "flux-schnell":
+            max_seq_len = 256
+        # elif self.model == "flux-dev":
+        #     max_seq_len = 512
+        else:
+            raise ValueError(f"Unknown model: {self.model}")
+
+        # Load the CLIP tokenizer.
+        clip_tokenizer_path = flux_model_dir / "tokenizer"
+        clip_tokenizer = CLIPTokenizer.from_pretrained(clip_tokenizer_path, local_files_only=True)
+        assert isinstance(clip_tokenizer, CLIPTokenizer)
+
+        # Load the T5 tokenizer.
+        t5_tokenizer_path = flux_model_dir / "tokenizer_2"
+        t5_tokenizer = T5TokenizerFast.from_pretrained(t5_tokenizer_path, local_files_only=True)
+        assert isinstance(t5_tokenizer, T5TokenizerFast)
+
+        clip_text_encoder_path = flux_model_dir / "text_encoder"
+        t5_text_encoder_path = flux_model_dir / "text_encoder_2"
+        with (
+            context.models.load_local_model(
+                model_path=clip_text_encoder_path, loader=self._load_flux_text_encoder
+            ) as clip_text_encoder,
+            context.models.load_local_model(
+                model_path=t5_text_encoder_path, loader=self._load_flux_text_encoder_2
+            ) as t5_text_encoder,
+        ):
+            assert isinstance(clip_text_encoder, CLIPTextModel)
+            assert isinstance(t5_text_encoder, T5EncoderModel)
+            pipeline = FluxPipeline(
+                scheduler=None,
+                vae=None,
+                text_encoder=clip_text_encoder,
+                tokenizer=clip_tokenizer,
+                text_encoder_2=t5_text_encoder,
+                tokenizer_2=t5_tokenizer,
+                transformer=None,
+            )
+
+            # prompt_embeds: T5 embeddings
+            # pooled_prompt_embeds: CLIP embeddings
+            prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
+                prompt=self.positive_prompt,
+                prompt_2=self.positive_prompt,
+                device=TorchDevice.choose_torch_device(),
+                max_sequence_length=max_seq_len,
+            )
+
+        assert isinstance(prompt_embeds, torch.Tensor)
+        assert isinstance(pooled_prompt_embeds, torch.Tensor)
+        return prompt_embeds, pooled_prompt_embeds
+
+    @staticmethod
+    def _load_flux_text_encoder(path: Path) -> CLIPTextModel:
+        model = CLIPTextModel.from_pretrained(path, local_files_only=True)
+        assert isinstance(model, CLIPTextModel)
+        return model
+
+    def _load_flux_text_encoder_2(self, path: Path) -> T5EncoderModel:
+        if self.use_8bit:
+            model_8bit_path = path / "quantized"
+            if model_8bit_path.exists():
+                # The quantized model exists, load it.
+                # TODO(ryand): The requantize(...) operation in from_pretrained(...) is very slow. This seems like
+                # something that we should be able to make much faster.
+                q_model = QuantizedModelForTextEncoding.from_pretrained(model_8bit_path)
+
+                # Access the underlying wrapped model.
+                # We access the wrapped model, even though it is private, because it simplifies the type checking by
+                # always returning a T5EncoderModel from this function.
+                model = q_model._wrapped
+            else:
+                # The quantized model does not exist yet, quantize and save it.
+                # TODO(ryand): dtype?
+                model = T5EncoderModel.from_pretrained(path, local_files_only=True)
+                assert isinstance(model, T5EncoderModel)
+
+                q_model = QuantizedModelForTextEncoding.quantize(model, weights=qfloat8)
+
+                model_8bit_path.mkdir(parents=True, exist_ok=True)
+                q_model.save_pretrained(model_8bit_path)
+
+                # (See earlier comment about accessing the wrapped model.)
+                model = q_model._wrapped
+        else:
+            model = T5EncoderModel.from_pretrained(path, local_files_only=True)
+
+        assert isinstance(model, T5EncoderModel)
+        return model
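A note on the 8-bit path above: _load_flux_text_encoder_2 quantizes the T5 encoder to qfloat8 on first use and saves the result beside the model, so later runs only pay the (still slow) requantize-on-load cost rather than quantizing from scratch. Beneath InvokeAI's QuantizedModelForTextEncoding wrapper, optimum-quanto's in-place API looks roughly like this (a sketch; the small "google/t5-v1_1-small" checkpoint is a stand-in, not the T5-XXL encoder FLUX actually uses):

from optimum.quanto import freeze, qfloat8, quantize
from transformers import T5EncoderModel

# Load a full-precision encoder (a small stand-in checkpoint for illustration).
model = T5EncoderModel.from_pretrained("google/t5-v1_1-small")

quantize(model, weights=qfloat8)  # swap Linear weights for qfloat8 tensors, in place
freeze(model)  # materialize the quantized weights so they can be serialized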

invokeai/app/invocations/flux_text_to_image.py

Lines changed: 19 additions & 98 deletions
@@ -7,16 +7,22 @@
 from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
 from optimum.quanto import qfloat8
 from PIL import Image
-from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
 from transformers.models.auto import AutoModelForTextEncoding
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
-from invokeai.app.invocations.fields import InputField, WithBoard, WithMetadata
+from invokeai.app.invocations.fields import (
+    ConditioningField,
+    FieldDescriptions,
+    Input,
+    InputField,
+    WithBoard,
+    WithMetadata,
+)
 from invokeai.app.invocations.primitives import ImageOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.quantization.fast_quantized_diffusion_model import FastQuantizedDiffusersModel
 from invokeai.backend.quantization.fast_quantized_transformers_model import FastQuantizedTransformersModel
-from invokeai.backend.util.devices import TorchDevice
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
 
 TFluxModelKeys = Literal["flux-schnell"]
 FLUX_MODELS: dict[TFluxModelKeys, str] = {"flux-schnell": "black-forest-labs/FLUX.1-schnell"}
@@ -44,7 +50,9 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     use_8bit: bool = InputField(
         default=False, description="Whether to quantize the transformer model to 8-bit precision."
     )
-    positive_prompt: str = InputField(description="Positive prompt for text-to-image generation.")
+    positive_text_conditioning: ConditioningField = InputField(
+        description=FieldDescriptions.positive_cond, input=Input.Connection
+    )
     width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
     height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
     num_steps: int = InputField(default=4, description="Number of diffusion steps.")
@@ -58,66 +66,17 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     def invoke(self, context: InvocationContext) -> ImageOutput:
         model_path = context.models.download_and_cache_model(FLUX_MODELS[self.model])
 
-        t5_embeddings, clip_embeddings = self._encode_prompt(context, model_path)
-        latents = self._run_diffusion(context, model_path, clip_embeddings, t5_embeddings)
+        # Load the conditioning data.
+        cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
+        assert len(cond_data.conditionings) == 1
+        flux_conditioning = cond_data.conditionings[0]
+        assert isinstance(flux_conditioning, FLUXConditioningInfo)
+
+        latents = self._run_diffusion(context, model_path, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds)
         image = self._run_vae_decoding(context, model_path, latents)
         image_dto = context.images.save(image=image)
         return ImageOutput.build(image_dto)
 
-    def _encode_prompt(self, context: InvocationContext, flux_model_dir: Path) -> tuple[torch.Tensor, torch.Tensor]:
-        # Determine the T5 max sequence length based on the model.
-        if self.model == "flux-schnell":
-            max_seq_len = 256
-        # elif self.model == "flux-dev":
-        #     max_seq_len = 512
-        else:
-            raise ValueError(f"Unknown model: {self.model}")
-
-        # Load the CLIP tokenizer.
-        clip_tokenizer_path = flux_model_dir / "tokenizer"
-        clip_tokenizer = CLIPTokenizer.from_pretrained(clip_tokenizer_path, local_files_only=True)
-        assert isinstance(clip_tokenizer, CLIPTokenizer)
-
-        # Load the T5 tokenizer.
-        t5_tokenizer_path = flux_model_dir / "tokenizer_2"
-        t5_tokenizer = T5TokenizerFast.from_pretrained(t5_tokenizer_path, local_files_only=True)
-        assert isinstance(t5_tokenizer, T5TokenizerFast)
-
-        clip_text_encoder_path = flux_model_dir / "text_encoder"
-        t5_text_encoder_path = flux_model_dir / "text_encoder_2"
-        with (
-            context.models.load_local_model(
-                model_path=clip_text_encoder_path, loader=self._load_flux_text_encoder
-            ) as clip_text_encoder,
-            context.models.load_local_model(
-                model_path=t5_text_encoder_path, loader=self._load_flux_text_encoder_2
-            ) as t5_text_encoder,
-        ):
-            assert isinstance(clip_text_encoder, CLIPTextModel)
-            assert isinstance(t5_text_encoder, T5EncoderModel)
-            pipeline = FluxPipeline(
-                scheduler=None,
-                vae=None,
-                text_encoder=clip_text_encoder,
-                tokenizer=clip_tokenizer,
-                text_encoder_2=t5_text_encoder,
-                tokenizer_2=t5_tokenizer,
-                transformer=None,
-            )
-
-            # prompt_embeds: T5 embeddings
-            # pooled_prompt_embeds: CLIP embeddings
-            prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
-                prompt=self.positive_prompt,
-                prompt_2=self.positive_prompt,
-                device=TorchDevice.choose_torch_device(),
-                max_sequence_length=max_seq_len,
-            )
-
-        assert isinstance(prompt_embeds, torch.Tensor)
-        assert isinstance(pooled_prompt_embeds, torch.Tensor)
-        return prompt_embeds, pooled_prompt_embeds
-
     def _run_diffusion(
         self,
         context: InvocationContext,
@@ -199,44 +158,6 @@ def _run_vae_decoding(
         assert isinstance(image, Image.Image)
         return image
 
-    @staticmethod
-    def _load_flux_text_encoder(path: Path) -> CLIPTextModel:
-        model = CLIPTextModel.from_pretrained(path, local_files_only=True)
-        assert isinstance(model, CLIPTextModel)
-        return model
-
-    def _load_flux_text_encoder_2(self, path: Path) -> T5EncoderModel:
-        if self.use_8bit:
-            model_8bit_path = path / "quantized"
-            if model_8bit_path.exists():
-                # The quantized model exists, load it.
-                # TODO(ryand): The requantize(...) operation in from_pretrained(...) is very slow. This seems like
-                # something that we should be able to make much faster.
-                q_model = QuantizedModelForTextEncoding.from_pretrained(model_8bit_path)
-
-                # Access the underlying wrapped model.
-                # We access the wrapped model, even though it is private, because it simplifies the type checking by
-                # always returning a T5EncoderModel from this function.
-                model = q_model._wrapped
-            else:
-                # The quantized model does not exist yet, quantize and save it.
-                # TODO(ryand): dtype?
-                model = T5EncoderModel.from_pretrained(path, local_files_only=True)
-                assert isinstance(model, T5EncoderModel)
-
-                q_model = QuantizedModelForTextEncoding.quantize(model, weights=qfloat8)
-
-                model_8bit_path.mkdir(parents=True, exist_ok=True)
-                q_model.save_pretrained(model_8bit_path)
-
-                # (See earlier comment about accessing the wrapped model.)
-                model = q_model._wrapped
-        else:
-            model = T5EncoderModel.from_pretrained(path, local_files_only=True)
-
-        assert isinstance(model, T5EncoderModel)
-        return model
-
     def _load_flux_transformer(self, path: Path) -> FluxTransformer2DModel:
         if self.use_8bit:
             model_8bit_path = path / "quantized"
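The reworked invoke above consumes conditioning by name instead of encoding inline, so the handoff between the two nodes is just a save/load round trip through the conditioning store. A self-contained sketch of that flow (the in-memory store and the tensor shapes, 768-wide CLIP pooled embeddings and 4096-wide T5 embeddings at sequence length 256, are illustrative assumptions, not part of this diff):

import torch

from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
    ConditioningFieldData,
    FLUXConditioningInfo,
)

# Toy stand-in for the store behind context.conditioning; the real service
# persists tensors to disk and generates the name itself.
_store: dict[str, ConditioningFieldData] = {}

def save(data: ConditioningFieldData) -> str:
    name = f"conditioning_{len(_store)}"
    _store[name] = data
    return name

# Encoder node: save the embeddings; the returned name flows over the graph edge.
name = save(
    ConditioningFieldData(
        conditionings=[FLUXConditioningInfo(clip_embeds=torch.zeros(1, 768), t5_embeds=torch.zeros(1, 256, 4096))]
    )
)

# Text-to-image node: load them back via positive_text_conditioning.conditioning_name.
flux_conditioning = _store[name].conditionings[0]
assert isinstance(flux_conditioning, FLUXConditioningInfo)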

invokeai/backend/stable_diffusion/diffusion/conditioning_data.py

Lines changed: 11 additions & 5 deletions
@@ -25,11 +25,6 @@ def to(self, device, dtype=None):
         return self
 
 
-@dataclass
-class ConditioningFieldData:
-    conditionings: List[BasicConditioningInfo]
-
-
 @dataclass
 class SDXLConditioningInfo(BasicConditioningInfo):
     """SDXL text conditioning information produced by Compel."""
@@ -43,6 +38,17 @@ def to(self, device, dtype=None):
         return super().to(device=device, dtype=dtype)
 
 
+@dataclass
+class FLUXConditioningInfo:
+    clip_embeds: torch.Tensor
+    t5_embeds: torch.Tensor
+
+
+@dataclass
+class ConditioningFieldData:
+    conditionings: List[BasicConditioningInfo] | List[SDXLConditioningInfo] | List[FLUXConditioningInfo]
+
+
 @dataclass
 class IPAdapterConditioningInfo:
     cond_image_prompt_embeds: torch.Tensor
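Note that the new ConditioningFieldData annotation is a union of list types rather than a list of a union: a given instance holds conditionings of a single kind, so a caller can check the type of the first element and treat the rest uniformly, exactly as the new FluxTextToImageInvocation.invoke does with its isinstance assert.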
