@@ -33,7 +33,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):

    model: TFluxModelKeys = InputField(description="The FLUX model to use for text-to-image generation.")
    use_8bit: bool = InputField(
-        default=False, description="Whether to quantize the T5 model and transformer model to 8-bit precision."
+        default=False, description="Whether to quantize the transformer model to 8-bit precision."
    )
    positive_prompt: str = InputField(description="Positive prompt for text-to-image generation.")
    width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
@@ -56,7 +56,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
        return ImageOutput.build(image_dto)

    def _encode_prompt(self, context: InvocationContext, flux_model_dir: Path) -> tuple[torch.Tensor, torch.Tensor]:
-        # Determine the T5 max sequence lenght based on the model.
+        # Determine the T5 max sequence length based on the model.
        if self.model == "flux-schnell":
            max_seq_len = 256
        # elif self.model == "flux-dev":
@@ -118,7 +118,9 @@ def _run_diffusion(
    ):
        scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(flux_model_dir / "scheduler", local_files_only=True)

-        # HACK(ryand): Manually empty the cache.
+        # HACK(ryand): Manually empty the cache. Currently we don't check the size of the model before loading it from
+        # disk. Since the transformer model is large (24GB), there's a good chance that it will OOM on 32GB RAM systems
+        # if the cache is not empty.
        context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)

        transformer_path = flux_model_dir / "transformer"
@@ -137,7 +139,7 @@ def _run_diffusion(
            transformer=transformer,
        )

-        return flux_pipeline_with_transformer(
+        latents = flux_pipeline_with_transformer(
            height=self.height,
            width=self.width,
            num_inference_steps=self.num_steps,
@@ -149,6 +151,9 @@ def _run_diffusion(
            return_dict=False,
        )[0]

+        assert isinstance(latents, torch.Tensor)
+        return latents
+
    def _run_vae_decoding(
        self,
        context: InvocationContext,
@@ -201,16 +206,24 @@ def _load_flux_transformer(self, path: Path) -> FluxTransformer2DModel:
            model_8bit_map_path = model_8bit_path / "quantization_map.json"
            if model_8bit_path.exists():
                # The quantized model exists, load it.
-                with torch.device("meta"):
-                    model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True)
-                    assert isinstance(model, FluxTransformer2DModel)
+                # TODO(ryand): Make loading from quantized model work properly.
+                # Reference: https://gist.github.com/AmericanPresidentJimmyCarter/873985638e1f3541ba8b00137e7dacd9?permalink_comment_id=5141210#gistcomment-5141210
+                model = FluxTransformer2DModel.from_pretrained(
+                    path,
+                    local_files_only=True,
+                )
+                assert isinstance(model, FluxTransformer2DModel)
+                model = model.to(device=torch.device("meta"))

                state_dict = load_file(model_8bit_weights_path)
                with open(model_8bit_map_path, "r") as f:
                    quant_map = json.load(f)
                requantize(model=model, state_dict=state_dict, quantization_map=quant_map)
            else:
                # The quantized model does not exist yet, quantize and save it.
+                # TODO(ryand): Loading in float16 and then quantizing seems to result in NaNs. In order to run this on
+                # GPUs that don't support bfloat16, we would need to host the quantized model instead of generating it
+                # here.
                model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
                assert isinstance(model, FluxTransformer2DModel)
@@ -222,9 +235,7 @@ def _load_flux_transformer(self, path: Path) -> FluxTransformer2DModel:
                with open(model_8bit_map_path, "w") as f:
                    json.dump(quantization_map(model), f)
        else:
-            model = FluxTransformer2DModel.from_pretrained(
-                path, local_files_only=True, torch_dtype=TorchDevice.choose_torch_dtype()
-            )
+            model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)

        assert isinstance(model, FluxTransformer2DModel)
        return model
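For context on the load path in the last two hunks: the quantized checkpoint is persisted as a safetensors state dict plus the JSON produced by optimum-quanto's quantization_map(), and it is restored by rebuilding the module and calling requantize() with both artifacts. Below is a minimal, self-contained sketch of that round trip on a toy torch module; the toy module, file names, and the qint8 choice are illustrative placeholders, not part of this PR.

import json

import torch
from optimum.quanto import freeze, qint8, quantization_map, quantize, requantize
from safetensors.torch import load_file, save_file


def build_toy_model() -> torch.nn.Module:
    # Placeholder stand-in for FluxTransformer2DModel.
    return torch.nn.Sequential(torch.nn.Linear(16, 64), torch.nn.ReLU(), torch.nn.Linear(64, 4))


# Quantize and save: weights go to a safetensors file, the layout to a JSON map.
model = build_toy_model()
quantize(model, weights=qint8)
freeze(model)
save_file(model.state_dict(), "weights.safetensors")
with open("quantization_map.json", "w") as f:
    json.dump(quantization_map(model), f)

# Load: rebuild the unquantized module (on the meta device, so no real weights are
# allocated), then requantize it from the saved state dict and quantization map.
with torch.device("meta"):
    restored = build_toy_model()
state_dict = load_file("weights.safetensors")
with open("quantization_map.json", "r") as f:
    quant_map = json.load(f)
requantize(model=restored, state_dict=state_dict, quantization_map=quant_map, device=torch.device("cpu"))

print(restored(torch.randn(1, 16)).shape)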