Skip to content

Commit b227b90

Browse files
committed
Use the FluxPipeline.encode_prompt() API rather than trying to run the two text encoders separately.
1 parent 3599a4a commit b227b90

File tree

1 file changed

+41
-45
lines changed

1 file changed

+41
-45
lines changed

invokeai/app/invocations/flux_text_to_image.py

Lines changed: 41 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -43,70 +43,66 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
4343
def invoke(self, context: InvocationContext) -> ImageOutput:
4444
model_path = context.models.download_and_cache_model(FLUX_MODELS[self.model])
4545

46-
clip_embeddings = self._run_clip_text_encoder(context, model_path)
47-
t5_embeddings = self._run_t5_text_encoder(context, model_path)
46+
t5_embeddings, clip_embeddings = self._encode_prompt(context, model_path)
4847
latents = self._run_diffusion(context, model_path, clip_embeddings, t5_embeddings)
4948
image = self._run_vae_decoding(context, model_path, latents)
5049
image_dto = context.images.save(image=image)
5150
return ImageOutput.build(image_dto)
5251

53-
def _run_clip_text_encoder(self, context: InvocationContext, flux_model_dir: Path) -> torch.Tensor:
54-
"""Run the CLIP text encoder."""
55-
tokenizer_path = flux_model_dir / "tokenizer"
56-
tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path, local_files_only=True)
57-
assert isinstance(tokenizer, CLIPTokenizer)
58-
59-
text_encoder_path = flux_model_dir / "text_encoder"
60-
with context.models.load_local_model(
61-
model_path=text_encoder_path, loader=self._load_flux_text_encoder
62-
) as text_encoder:
63-
assert isinstance(text_encoder, CLIPTextModel)
64-
flux_pipeline_with_te = FluxPipeline(
65-
scheduler=None,
66-
vae=None,
67-
text_encoder=text_encoder,
68-
tokenizer=tokenizer,
69-
text_encoder_2=None,
70-
tokenizer_2=None,
71-
transformer=None,
72-
)
73-
74-
return flux_pipeline_with_te._get_clip_prompt_embeds(
75-
prompt=self.positive_prompt, device=TorchDevice.choose_torch_device()
76-
)
77-
78-
def _run_t5_text_encoder(self, context: InvocationContext, flux_model_dir: Path) -> torch.Tensor:
79-
"""Run the T5 text encoder."""
80-
52+
def _encode_prompt(self, context: InvocationContext, flux_model_dir: Path) -> tuple[torch.Tensor, torch.Tensor]:
53+
# Determine the T5 max sequence length based on the model.
8154
if self.model == "flux-schnell":
8255
max_seq_len = 256
8356
# elif self.model == "flux-dev":
8457
# max_seq_len = 512
8558
else:
8659
raise ValueError(f"Unknown model: {self.model}")
8760

88-
tokenizer_path = flux_model_dir / "tokenizer_2"
89-
tokenizer_2 = T5TokenizerFast.from_pretrained(tokenizer_path, local_files_only=True)
90-
assert isinstance(tokenizer_2, T5TokenizerFast)
91-
92-
text_encoder_path = flux_model_dir / "text_encoder_2"
93-
with context.models.load_local_model(
94-
model_path=text_encoder_path, loader=self._load_flux_text_encoder_2
95-
) as text_encoder_2:
96-
flux_pipeline_with_te2 = FluxPipeline(
61+
# Load the CLIP tokenizer.
62+
clip_tokenizer_path = flux_model_dir / "tokenizer"
63+
clip_tokenizer = CLIPTokenizer.from_pretrained(clip_tokenizer_path, local_files_only=True)
64+
assert isinstance(clip_tokenizer, CLIPTokenizer)
65+
66+
# Load the T5 tokenizer.
67+
t5_tokenizer_path = flux_model_dir / "tokenizer_2"
68+
t5_tokenizer = T5TokenizerFast.from_pretrained(t5_tokenizer_path, local_files_only=True)
69+
assert isinstance(t5_tokenizer, T5TokenizerFast)
70+
71+
clip_text_encoder_path = flux_model_dir / "text_encoder"
72+
t5_text_encoder_path = flux_model_dir / "text_encoder_2"
73+
with (
74+
context.models.load_local_model(
75+
model_path=clip_text_encoder_path, loader=self._load_flux_text_encoder
76+
) as clip_text_encoder,
77+
context.models.load_local_model(
78+
model_path=t5_text_encoder_path, loader=self._load_flux_text_encoder_2
79+
) as t5_text_encoder,
80+
):
81+
assert isinstance(clip_text_encoder, CLIPTextModel)
82+
assert isinstance(t5_text_encoder, T5EncoderModel)
83+
pipeline = FluxPipeline(
9784
scheduler=None,
9885
vae=None,
99-
text_encoder=None,
100-
tokenizer=None,
101-
text_encoder_2=text_encoder_2,
102-
tokenizer_2=tokenizer_2,
86+
text_encoder=clip_text_encoder,
87+
tokenizer=clip_tokenizer,
88+
text_encoder_2=t5_text_encoder,
89+
tokenizer_2=t5_tokenizer,
10390
transformer=None,
10491
)
10592

106-
return flux_pipeline_with_te2._get_t5_prompt_embeds(
107-
prompt=self.positive_prompt, max_sequence_length=max_seq_len, device=TorchDevice.choose_torch_device()
93+
# prompt_embeds: T5 embeddings
94+
# pooled_prompt_embeds: CLIP embeddings
95+
prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
96+
prompt=self.positive_prompt,
97+
prompt_2=self.positive_prompt,
98+
device=TorchDevice.choose_torch_device(),
99+
max_sequence_length=max_seq_len,
108100
)
109101

102+
assert isinstance(prompt_embeds, torch.Tensor)
103+
assert isinstance(pooled_prompt_embeds, torch.Tensor)
104+
return prompt_embeds, pooled_prompt_embeds
105+
110106
def _run_diffusion(
111107
self,
112108
context: InvocationContext,

0 commit comments

Comments (0)