
Commit 5dd619e

First draft of FluxTextToImageInvocation.
1 parent 7d447cb commit 5dd619e

1 file changed: 201 additions & 0 deletions

@@ -0,0 +1,201 @@
from pathlib import Path
from typing import Literal

import torch
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from diffusers.pipelines.flux import FluxPipeline
from PIL import Image
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import InputField, WithBoard, WithMetadata
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.util.devices import TorchDevice

TFluxModelKeys = Literal["flux-schnell"]
FLUX_MODELS: dict[TFluxModelKeys, str] = {"flux-schnell": "black-forest-labs/FLUX.1-schnell"}
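
# FLUX.1-schnell is the timestep-distilled FLUX.1 variant, designed to generate images
# in very few sampling steps (hence the num_steps default of 4 below).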


@invocation(
    "flux_text_to_image",
    title="FLUX Text to Image",
    tags=["image"],
    category="image",
    version="1.0.0",
)
class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Text-to-image generation using a FLUX model."""

    model: TFluxModelKeys = InputField(description="The FLUX model to use for text-to-image generation.")
    positive_prompt: str = InputField(description="Positive prompt for text-to-image generation.")
    width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
    height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
    num_steps: int = InputField(default=4, description="Number of diffusion steps.")
    guidance: float = InputField(
        default=4.0,
        description="The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images.",
    )
    seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
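
    # NOTE: FLUX.1-schnell was distilled without guidance embeddings, so the `guidance`
    # field above likely has no effect for this model (it does for FLUX.1-dev).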

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> ImageOutput:
        model_path = context.models.download_and_cache_model(FLUX_MODELS[self.model])

        clip_embeddings = self._run_clip_text_encoder(context, model_path)
        t5_embeddings = self._run_t5_text_encoder(context, model_path)
        latents = self._run_diffusion(context, model_path, clip_embeddings, t5_embeddings)
        image = self._run_vae_decoding(context, model_path, latents)
        image_dto = context.images.save(image=image)
        return ImageOutput.build(image_dto)
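
    # Each _run_* helper below loads only the sub-model it needs (via
    # context.models.load_local_model) and wraps it in a FluxPipeline whose other
    # components are None, so the pipeline's helpers can be reused one stage at a time.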

    def _run_clip_text_encoder(self, context: InvocationContext, flux_model_dir: Path) -> torch.Tensor:
        """Run the CLIP text encoder."""
        tokenizer_path = flux_model_dir / "tokenizer"
        tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path, local_files_only=True)
        assert isinstance(tokenizer, CLIPTokenizer)

        text_encoder_path = flux_model_dir / "text_encoder"
        with context.models.load_local_model(
            model_path=text_encoder_path, loader=self._load_flux_text_encoder
        ) as text_encoder:
            assert isinstance(text_encoder, CLIPTextModel)
            flux_pipeline_with_te = FluxPipeline(
                scheduler=None,
                vae=None,
                text_encoder=text_encoder,
                tokenizer=tokenizer,
                text_encoder_2=None,
                tokenizer_2=None,
                transformer=None,
            )

            return flux_pipeline_with_te._get_clip_prompt_embeds(
                prompt=self.positive_prompt, device=TorchDevice.choose_torch_device()
            )
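
    # _get_clip_prompt_embeds returns CLIP's pooled text embedding; _run_diffusion
    # passes it as pooled_prompt_embeds, while the T5 output below is the per-token
    # prompt_embeds sequence.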

    def _run_t5_text_encoder(self, context: InvocationContext, flux_model_dir: Path) -> torch.Tensor:
        """Run the T5 text encoder."""

        if self.model == "flux-schnell":
            max_seq_len = 256
        # elif self.model == "flux-dev":
        #     max_seq_len = 512
        else:
            raise ValueError(f"Unknown model: {self.model}")

        tokenizer_path = flux_model_dir / "tokenizer_2"
        tokenizer_2 = T5TokenizerFast.from_pretrained(tokenizer_path, local_files_only=True)
        assert isinstance(tokenizer_2, T5TokenizerFast)

        text_encoder_path = flux_model_dir / "text_encoder_2"
        with context.models.load_local_model(
            model_path=text_encoder_path, loader=self._load_flux_text_encoder_2
        ) as text_encoder_2:
            assert isinstance(text_encoder_2, T5EncoderModel)
            flux_pipeline_with_te2 = FluxPipeline(
                scheduler=None,
                vae=None,
                text_encoder=None,
                tokenizer=None,
                text_encoder_2=text_encoder_2,
                tokenizer_2=tokenizer_2,
                transformer=None,
            )

            return flux_pipeline_with_te2._get_t5_prompt_embeds(
                prompt=self.positive_prompt, max_sequence_length=max_seq_len, device=TorchDevice.choose_torch_device()
            )

    def _run_diffusion(
        self,
        context: InvocationContext,
        flux_model_dir: Path,
        clip_embeddings: torch.Tensor,
        t5_embeddings: torch.Tensor,
    ) -> torch.Tensor:
        """Run the FLUX transformer to denoise latents for the prompt embeddings."""
        # NOTE: this constructs a scheduler with default parameters instead of loading
        # the config from flux_model_dir / "scheduler"; the two may not match exactly.
        scheduler = FlowMatchEulerDiscreteScheduler()

        transformer_path = flux_model_dir / "transformer"
        with context.models.load_local_model(
            model_path=transformer_path, loader=self._load_flux_transformer
        ) as transformer:
            assert isinstance(transformer, FluxTransformer2DModel)

            flux_pipeline_with_transformer = FluxPipeline(
                scheduler=scheduler,
                vae=None,
                text_encoder=None,
                tokenizer=None,
                text_encoder_2=None,
                tokenizer_2=None,
                transformer=transformer,
            )

            return flux_pipeline_with_transformer(
                height=self.height,
                width=self.width,
                num_inference_steps=self.num_steps,
                guidance_scale=self.guidance,
                generator=torch.Generator().manual_seed(self.seed),
                prompt_embeds=t5_embeddings,
                pooled_prompt_embeds=clip_embeddings,
                output_type="latent",
                return_dict=False,
            )[0]
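
    # The tensor returned above is still in FLUX's packed latent layout, since
    # output_type="latent" skips VAE decoding; _run_vae_decoding unpacks it first.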

    def _run_vae_decoding(
        self,
        context: InvocationContext,
        flux_model_dir: Path,
        latents: torch.Tensor,
    ) -> Image.Image:
        """Decode the generated latents into a PIL image."""
        vae_path = flux_model_dir / "vae"
        with context.models.load_local_model(model_path=vae_path, loader=self._load_flux_vae) as vae:
            assert isinstance(vae, AutoencoderKL)

            flux_pipeline_with_vae = FluxPipeline(
                scheduler=None,
                vae=vae,
                text_encoder=None,
                tokenizer=None,
                text_encoder_2=None,
                tokenizer_2=None,
                transformer=None,
            )

            latents = flux_pipeline_with_vae._unpack_latents(
                latents, self.height, self.width, flux_pipeline_with_vae.vae_scale_factor
            )
            # Undo the VAE's latent normalization before decoding.
            latents = (
                latents / flux_pipeline_with_vae.vae.config.scaling_factor
            ) + flux_pipeline_with_vae.vae.config.shift_factor
            image = flux_pipeline_with_vae.vae.decode(latents, return_dict=False)[0]
            # postprocess(..., output_type="pil") returns a list; take the single image.
            image = flux_pipeline_with_vae.image_processor.postprocess(image, output_type="pil")[0]

        assert isinstance(image, Image.Image)
        return image

    @staticmethod
    def _load_flux_text_encoder(path: Path) -> CLIPTextModel:
        model = CLIPTextModel.from_pretrained(path, local_files_only=True)
        assert isinstance(model, CLIPTextModel)
        return model

    @staticmethod
    def _load_flux_text_encoder_2(path: Path) -> T5EncoderModel:
        model = T5EncoderModel.from_pretrained(path, local_files_only=True)
        assert isinstance(model, T5EncoderModel)
        return model

    @staticmethod
    def _load_flux_transformer(path: Path) -> FluxTransformer2DModel:
        model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True)
        assert isinstance(model, FluxTransformer2DModel)
        return model

    @staticmethod
    def _load_flux_vae(path: Path) -> AutoencoderKL:
        model = AutoencoderKL.from_pretrained(path, local_files_only=True)
        assert isinstance(model, AutoencoderKL)
        return model
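
For reference, the staged flow above collapses to a few lines against the public diffusers
API. A minimal sketch, not part of this commit, assuming a diffusers build with FLUX
support and enough memory to hold the whole pipeline at once (which is exactly what the
per-stage loading above avoids); the prompt and output filename are placeholders:

import torch
from diffusers import FluxPipeline

# Load the full pipeline (all sub-models) from the same checkpoint used above.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()  # optional: offload idle sub-models to CPU to save VRAM

image = pipe(
    prompt="a photo of a forest with mist",
    height=1024,
    width=1024,
    num_inference_steps=4,  # schnell is distilled for very few steps
    guidance_scale=0.0,  # schnell lacks guidance embeddings, so this is effectively unused
    generator=torch.Generator().manual_seed(0),
).images[0]
image.save("flux_schnell.png")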
