@@ -43,70 +43,66 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
43
43
def invoke(self, context: InvocationContext) -> ImageOutput:
    """Generate an image for the prompt: encode text, denoise latents, decode, and save.

    Args:
        context: The invocation context providing model and image services.

    Returns:
        An ImageOutput wrapping the saved image DTO.
    """
    # Resolve (downloading if necessary) the FLUX model directory for the selected variant.
    flux_dir = context.models.download_and_cache_model(FLUX_MODELS[self.model])

    # Text conditioning: T5 sequence embeddings and pooled CLIP embeddings.
    t5_emb, clip_emb = self._encode_prompt(context, flux_dir)

    denoised_latents = self._run_diffusion(context, flux_dir, clip_emb, t5_emb)
    decoded_image = self._run_vae_decoding(context, flux_dir, denoised_latents)

    saved_dto = context.images.save(image=decoded_image)
    return ImageOutput.build(saved_dto)
53
def _encode_prompt(self, context: InvocationContext, flux_model_dir: Path) -> tuple[torch.Tensor, torch.Tensor]:
    """Encode the positive prompt with both FLUX text encoders in a single pass.

    Args:
        context: The invocation context providing local model loading.
        flux_model_dir: Root directory of the cached FLUX model.

    Returns:
        A tuple of (t5_embeddings, clip_pooled_embeddings).

    Raises:
        ValueError: If `self.model` is not a supported FLUX variant.
    """
    # Determine the T5 max sequence length based on the model.
    if self.model == "flux-schnell":
        max_seq_len = 256
    # elif self.model == "flux-dev":
    #     max_seq_len = 512
    else:
        raise ValueError(f"Unknown model: {self.model}")

    # Load the CLIP tokenizer.
    clip_tokenizer_path = flux_model_dir / "tokenizer"
    clip_tokenizer = CLIPTokenizer.from_pretrained(clip_tokenizer_path, local_files_only=True)
    assert isinstance(clip_tokenizer, CLIPTokenizer)

    # Load the T5 tokenizer.
    t5_tokenizer_path = flux_model_dir / "tokenizer_2"
    t5_tokenizer = T5TokenizerFast.from_pretrained(t5_tokenizer_path, local_files_only=True)
    assert isinstance(t5_tokenizer, T5TokenizerFast)

    clip_text_encoder_path = flux_model_dir / "text_encoder"
    t5_text_encoder_path = flux_model_dir / "text_encoder_2"
    # Keep both encoders resident simultaneously so a single combined
    # FluxPipeline.encode_prompt() call can produce both embeddings.
    with (
        context.models.load_local_model(
            model_path=clip_text_encoder_path, loader=self._load_flux_text_encoder
        ) as clip_text_encoder,
        context.models.load_local_model(
            model_path=t5_text_encoder_path, loader=self._load_flux_text_encoder_2
        ) as t5_text_encoder,
    ):
        assert isinstance(clip_text_encoder, CLIPTextModel)
        assert isinstance(t5_text_encoder, T5EncoderModel)
        # Partial pipeline: only the text-encoding components are needed here.
        pipeline = FluxPipeline(
            scheduler=None,
            vae=None,
            text_encoder=clip_text_encoder,
            tokenizer=clip_tokenizer,
            text_encoder_2=t5_text_encoder,
            tokenizer_2=t5_tokenizer,
            transformer=None,
        )

        # prompt_embeds: T5 embeddings
        # pooled_prompt_embeds: CLIP embeddings
        # text_ids is intentionally discarded; only the embeddings are used downstream.
        prompt_embeds, pooled_prompt_embeds, _text_ids = pipeline.encode_prompt(
            prompt=self.positive_prompt,
            prompt_2=self.positive_prompt,
            device=TorchDevice.choose_torch_device(),
            max_sequence_length=max_seq_len,
        )

        assert isinstance(prompt_embeds, torch.Tensor)
        assert isinstance(pooled_prompt_embeds, torch.Tensor)
        return prompt_embeds, pooled_prompt_embeds
110
106
def _run_diffusion (
111
107
self ,
112
108
context : InvocationContext ,
0 commit comments