
Commit e6ff748

Minor improvements to FLUX workflow.
1 parent 89a652c commit e6ff748

1 file changed: +21 -10 lines

invokeai/app/invocations/flux_text_to_image.py

Lines changed: 21 additions & 10 deletions
@@ -33,7 +33,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
 
     model: TFluxModelKeys = InputField(description="The FLUX model to use for text-to-image generation.")
     use_8bit: bool = InputField(
-        default=False, description="Whether to quantize the T5 model and transformer model to 8-bit precision."
+        default=False, description="Whether to quantize the transformer model to 8-bit precision."
     )
     positive_prompt: str = InputField(description="Positive prompt for text-to-image generation.")
     width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
@@ -56,7 +56,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
         return ImageOutput.build(image_dto)
 
     def _encode_prompt(self, context: InvocationContext, flux_model_dir: Path) -> tuple[torch.Tensor, torch.Tensor]:
-        # Determine the T5 max sequence lenght based on the model.
+        # Determine the T5 max sequence length based on the model.
         if self.model == "flux-schnell":
             max_seq_len = 256
         # elif self.model == "flux-dev":
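
The flux-dev branch is still commented out in this hunk. As a hedged sketch of how it would likely be filled in (the 512-token T5 limit for flux-dev is an assumption, not something stated in this commit):

# Hypothetical sketch only: completing the commented-out flux-dev branch,
# assuming flux-dev uses a 512-token T5 sequence length (flux-schnell uses 256).
if self.model == "flux-schnell":
    max_seq_len = 256
elif self.model == "flux-dev":
    max_seq_len = 512
else:
    raise ValueError(f"Unknown FLUX model: {self.model}")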
@@ -118,7 +118,9 @@ def _run_diffusion(
     ):
         scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(flux_model_dir / "scheduler", local_files_only=True)
 
-        # HACK(ryand): Manually empty the cache.
+        # HACK(ryand): Manually empty the cache. Currently we don't check the size of the model before loading it from
+        # disk. Since the transformer model is large (24GB), there's a good chance that it will OOM on 32GB RAM systems
+        # if the cache is not empty.
         context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)
 
         transformer_path = flux_model_dir / "transformer"
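
For context on the 24 * 2**30 figure, a quick back-of-the-envelope check (the ~12B parameter count for the FLUX.1 transformer is an assumption about the model, not taken from this diff):

# Rough arithmetic behind the make_room() argument (parameter count assumed ~12B):
params = 12e9                            # approximate FLUX.1 transformer parameter count
bytes_per_param = 2                      # bfloat16 weights
print(params * bytes_per_param / 2**30)  # ~22 GiB of weights
print(24 * 2**30)                        # 25769803776 bytes, i.e. the 24 GiB requested from the RAM cache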
@@ -137,7 +139,7 @@ def _run_diffusion(
             transformer=transformer,
         )
 
-        return flux_pipeline_with_transformer(
+        latents = flux_pipeline_with_transformer(
             height=self.height,
             width=self.width,
             num_inference_steps=self.num_steps,
@@ -149,6 +151,9 @@
             return_dict=False,
         )[0]
 
+        assert isinstance(latents, torch.Tensor)
+        return latents
+
     def _run_vae_decoding(
         self,
         context: InvocationContext,
@@ -201,16 +206,24 @@ def _load_flux_transformer(self, path: Path) -> FluxTransformer2DModel:
             model_8bit_map_path = model_8bit_path / "quantization_map.json"
             if model_8bit_path.exists():
                 # The quantized model exists, load it.
-                with torch.device("meta"):
-                    model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True)
-                assert isinstance(model, FluxTransformer2DModel)
+                # TODO(ryand): Make loading from quantized model work properly.
+                # Reference: https://gist.github.com/AmericanPresidentJimmyCarter/873985638e1f3541ba8b00137e7dacd9?permalink_comment_id=5141210#gistcomment-5141210
+                model = FluxTransformer2DModel.from_pretrained(
+                    path,
+                    local_files_only=True,
+                )
+                assert isinstance(model, FluxTransformer2DModel)
+                model = model.to(device=torch.device("meta"))
 
                 state_dict = load_file(model_8bit_weights_path)
                 with open(model_8bit_map_path, "r") as f:
                     quant_map = json.load(f)
                 requantize(model=model, state_dict=state_dict, quantization_map=quant_map)
             else:
                 # The quantized model does not exist yet, quantize and save it.
+                # TODO(ryand): Loading in float16 and then quantizing seems to result in NaNs. In order to run this on
+                # GPUs that don't support bfloat16, we would need to host the quantized model instead of generating it
+                # here.
                 model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
                 assert isinstance(model, FluxTransformer2DModel)
 
@@ -222,9 +235,7 @@ def _load_flux_transformer(self, path: Path) -> FluxTransformer2DModel:
                 with open(model_8bit_map_path, "w") as f:
                     json.dump(quantization_map(model), f)
         else:
-            model = FluxTransformer2DModel.from_pretrained(
-                path, local_files_only=True, torch_dtype=TorchDevice.choose_torch_dtype()
-            )
+            model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
 
         assert isinstance(model, FluxTransformer2DModel)
         return model
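
Taken together, the two branches in _load_flux_transformer implement a quantize-on-first-use cache: the first run quantizes the bfloat16 transformer and saves it, and later runs reload the saved weights (the TODO above notes the reload path does not work properly yet). Below is a minimal standalone sketch of that round trip using the same optimum-quanto and safetensors calls that appear in the diff; the helper names and paths are illustrative, not part of the InvokeAI code.

import json

from optimum.quanto import freeze, qint8, quantization_map, quantize, requantize
from safetensors.torch import load_file, save_file


def quantize_and_save(model, weights_path, map_path):
    # First run: quantize the bfloat16 model to int8 and persist it for later runs.
    quantize(model, weights=qint8)
    freeze(model)
    save_file(model.state_dict(), weights_path)    # e.g. quantized/weights.safetensors
    with open(map_path, "w") as f:
        json.dump(quantization_map(model), f)      # records which modules were quantized and how


def load_quantized(model, weights_path, map_path):
    # Later runs: rebuild the quantized modules on a freshly instantiated model,
    # then load the saved int8 weights into them.
    state_dict = load_file(weights_path)
    with open(map_path) as f:
        quant_map = json.load(f)
    requantize(model=model, state_dict=state_dict, quantization_map=quant_map)
    return model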
