
Commit ce3ee76

Manage quantization of models within the loader
1 parent d7e0bd4 commit ce3ee76
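In practical terms, this commit moves the FLUX invocations off hand-rolled context.models.download_and_cache_model(...) calls and custom loader= callbacks (which also decided whether and how to quantize) and onto sub-model fields resolved through context.models.load(...), so caching and quantization are handled by the model loader. Below is a minimal sketch of the resulting invocation-side pattern, assembled only from calls that appear in the diff; the function itself is illustrative and is not part of the commit.

# Hedged sketch of the loading pattern introduced by this commit; not code from the repository.
from invokeai.app.invocations.model import CLIPField, T5EncoderField
from invokeai.app.services.shared.invocation_context import InvocationContext


def load_text_encoders(context: InvocationContext, clip: CLIPField, t5: T5EncoderField) -> None:
    # The model manager returns lazy "info" handles; download, caching, and any quantization
    # now happen behind context.models.load(...) rather than inside the invocation.
    clip_text_encoder_info = context.models.load(clip.text_encoder)
    t5_text_encoder_info = context.models.load(t5.text_encoder)
    # The weights are only guaranteed to be resident while the context managers are open.
    with clip_text_encoder_info as clip_text_encoder, t5_text_encoder_info as t5_text_encoder:
        print(type(clip_text_encoder).__name__, type(t5_text_encoder).__name__)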

9 files changed: +263, -349 lines

invokeai/app/invocations/fields.py

Lines changed: 1 addition & 0 deletions
@@ -126,6 +126,7 @@ class FieldDescriptions:
     negative_cond = "Negative conditioning tensor"
     noise = "Noise tensor"
     clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
+    t5Encoder = "T5 tokenizer and text encoder"
     unet = "UNet (scheduler, LoRAs)"
     transformer = "Transformer"
     vae = "VAE"

invokeai/app/invocations/flux_text_encoder.py

Lines changed: 39 additions & 73 deletions
@@ -6,8 +6,10 @@
 from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

 from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
-from invokeai.app.invocations.fields import InputField
-from invokeai.app.invocations.flux_text_to_image import FLUX_MODELS, QuantizedModelForTextEncoding, TFluxModelKeys
+from invokeai.app.invocations.model import CLIPField, T5EncoderField
+from invokeai.app.invocations.fields import InputField, FieldDescriptions, Input
+from invokeai.app.invocations.flux_text_to_image import FLUX_MODELS, QuantizedModelForTextEncoding
+from invokeai.app.invocations.model import CLIPField, T5EncoderField
 from invokeai.app.invocations.primitives import ConditioningOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
@@ -22,57 +24,59 @@
     version="1.0.0",
 )
 class FluxTextEncoderInvocation(BaseInvocation):
-    model: TFluxModelKeys = InputField(description="The FLUX model to use for text-to-image generation.")
-    use_8bit: bool = InputField(
-        default=False, description="Whether to quantize the transformer model to 8-bit precision."
+    clip: CLIPField = InputField(
+        title="CLIP",
+        description=FieldDescriptions.clip,
+        input=Input.Connection,
+    )
+    t5Encoder: T5EncoderField = InputField(
+        title="T5EncoderField",
+        description=FieldDescriptions.t5Encoder,
+        input=Input.Connection,
     )
     positive_prompt: str = InputField(description="Positive prompt for text-to-image generation.")

     # TODO(ryand): Should we create a new return type for this invocation? This ConditioningOutput is clearly not
     # compatible with other ConditioningOutputs.
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ConditioningOutput:
-        model_path = context.models.download_and_cache_model(FLUX_MODELS[self.model])

-        t5_embeddings, clip_embeddings = self._encode_prompt(context, model_path)
+        t5_embeddings, clip_embeddings = self._encode_prompt(context)
         conditioning_data = ConditioningFieldData(
             conditionings=[FLUXConditioningInfo(clip_embeds=clip_embeddings, t5_embeds=t5_embeddings)]
         )

         conditioning_name = context.conditioning.save(conditioning_data)
         return ConditioningOutput.build(conditioning_name)
+
+    def _encode_prompt(self, context: InvocationContext) -> tuple[torch.Tensor, torch.Tensor]:
+        # TODO: Determine the T5 max sequence length based on the model.
+        # if self.model == "flux-schnell":
+        max_seq_len = 256
+        # # elif self.model == "flux-dev":
+        # #     max_seq_len = 512
+        # else:
+        #     raise ValueError(f"Unknown model: {self.model}")
+
+        # Load CLIP.
+        clip_tokenizer_info = context.models.load(self.clip.tokenizer)
+        clip_text_encoder_info = context.models.load(self.clip.text_encoder)
+
+        # Load T5.
+        t5_tokenizer_info = context.models.load(self.t5Encoder.tokenizer)
+        t5_text_encoder_info = context.models.load(self.t5Encoder.text_encoder)

-    def _encode_prompt(self, context: InvocationContext, flux_model_dir: Path) -> tuple[torch.Tensor, torch.Tensor]:
-        # Determine the T5 max sequence length based on the model.
-        if self.model == "flux-schnell":
-            max_seq_len = 256
-        # elif self.model == "flux-dev":
-        #     max_seq_len = 512
-        else:
-            raise ValueError(f"Unknown model: {self.model}")
-
-        # Load the CLIP tokenizer.
-        clip_tokenizer_path = flux_model_dir / "tokenizer"
-        clip_tokenizer = CLIPTokenizer.from_pretrained(clip_tokenizer_path, local_files_only=True)
-        assert isinstance(clip_tokenizer, CLIPTokenizer)
-
-        # Load the T5 tokenizer.
-        t5_tokenizer_path = flux_model_dir / "tokenizer_2"
-        t5_tokenizer = T5TokenizerFast.from_pretrained(t5_tokenizer_path, local_files_only=True)
-        assert isinstance(t5_tokenizer, T5TokenizerFast)
-
-        clip_text_encoder_path = flux_model_dir / "text_encoder"
-        t5_text_encoder_path = flux_model_dir / "text_encoder_2"
         with (
-            context.models.load_local_model(
-                model_path=clip_text_encoder_path, loader=self._load_flux_text_encoder
-            ) as clip_text_encoder,
-            context.models.load_local_model(
-                model_path=t5_text_encoder_path, loader=self._load_flux_text_encoder_2
-            ) as t5_text_encoder,
+            clip_text_encoder_info as clip_text_encoder,
+            t5_text_encoder_info as t5_text_encoder,
+            clip_tokenizer_info as clip_tokenizer,
+            t5_tokenizer_info as t5_tokenizer,
         ):
             assert isinstance(clip_text_encoder, CLIPTextModel)
             assert isinstance(t5_text_encoder, T5EncoderModel)
+            assert isinstance(clip_tokenizer, CLIPTokenizer)
+            assert isinstance(t5_tokenizer, T5TokenizerFast)
+
             pipeline = FluxPipeline(
                 scheduler=None,
                 vae=None,
@@ -85,7 +89,7 @@ def _encode_prompt(self, context: InvocationContext, flux_model_dir: Path) -> tu

             # prompt_embeds: T5 embeddings
             # pooled_prompt_embeds: CLIP embeddings
-            prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
+            prompt_embeds, pooled_prompt_embeds, _ = pipeline.encode_prompt(
                 prompt=self.positive_prompt,
                 prompt_2=self.positive_prompt,
                 device=TorchDevice.choose_torch_device(),
@@ -95,41 +99,3 @@ def _encode_prompt(self, context: InvocationContext, flux_model_dir: Path) -> tu
             assert isinstance(prompt_embeds, torch.Tensor)
             assert isinstance(pooled_prompt_embeds, torch.Tensor)
             return prompt_embeds, pooled_prompt_embeds
-
-    @staticmethod
-    def _load_flux_text_encoder(path: Path) -> CLIPTextModel:
-        model = CLIPTextModel.from_pretrained(path, local_files_only=True)
-        assert isinstance(model, CLIPTextModel)
-        return model
-
-    def _load_flux_text_encoder_2(self, path: Path) -> T5EncoderModel:
-        if self.use_8bit:
-            model_8bit_path = path / "quantized"
-            if model_8bit_path.exists():
-                # The quantized model exists, load it.
-                # TODO(ryand): The requantize(...) operation in from_pretrained(...) is very slow. This seems like
-                # something that we should be able to make much faster.
-                q_model = QuantizedModelForTextEncoding.from_pretrained(model_8bit_path)
-
-                # Access the underlying wrapped model.
-                # We access the wrapped model, even though it is private, because it simplifies the type checking by
-                # always returning a T5EncoderModel from this function.
-                model = q_model._wrapped
-            else:
-                # The quantized model does not exist yet, quantize and save it.
-                # TODO(ryand): dtype?
-                model = T5EncoderModel.from_pretrained(path, local_files_only=True)
-                assert isinstance(model, T5EncoderModel)
-
-                q_model = QuantizedModelForTextEncoding.quantize(model, weights=qfloat8)
-
-                model_8bit_path.mkdir(parents=True, exist_ok=True)
-                q_model.save_pretrained(model_8bit_path)
-
-                # (See earlier comment about accessing the wrapped model.)
-                model = q_model._wrapped
-        else:
-            model = T5EncoderModel.from_pretrained(path, local_files_only=True)
-
-        assert isinstance(model, T5EncoderModel)
-        return model

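For reference, the loaded tokenizers and text encoders are only used to drive diffusers' prompt encoding: the invocation builds a FluxPipeline with every unrelated component set to None and calls encode_prompt, as the unchanged lines of this file show. A self-contained sketch of that usage follows; the function signature, the max_sequence_length value of 256 (mirroring the hard-coded max_seq_len above), and the assumption that the models are already on the target device are mine, not the commit's.

# Hedged sketch of prompt encoding with a partially constructed diffusers FluxPipeline,
# mirroring FluxTextEncoderInvocation._encode_prompt above (assumes the diffusers version
# used by this branch; in InvokeAI the components come from the model manager, not from disk).
import torch
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast


def encode_flux_prompt(
    clip_tokenizer: CLIPTokenizer,
    clip_text_encoder: CLIPTextModel,
    t5_tokenizer: T5TokenizerFast,
    t5_text_encoder: T5EncoderModel,
    prompt: str,
    device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor]:
    # Only the tokenizers and encoders are needed for encoding; everything else is passed as None.
    pipeline = FluxPipeline(
        scheduler=None,
        vae=None,
        text_encoder=clip_text_encoder,
        tokenizer=clip_tokenizer,
        text_encoder_2=t5_text_encoder,
        tokenizer_2=t5_tokenizer,
        transformer=None,
    )
    # prompt_embeds are the T5 embeddings; pooled_prompt_embeds are the CLIP embeddings.
    prompt_embeds, pooled_prompt_embeds, _ = pipeline.encode_prompt(
        prompt=prompt,
        prompt_2=prompt,
        device=device,
        max_sequence_length=256,  # flux-schnell assumption; flux-dev would presumably use 512
    )
    return prompt_embeds, pooled_prompt_embeds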
invokeai/app/invocations/flux_text_to_image.py

Lines changed: 34 additions & 119 deletions
@@ -6,7 +6,7 @@
 import torch
 from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
 from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
-from invokeai.app.invocations.model import ModelIdentifierField
+from invokeai.app.invocations.model import TransformerField, VAEField
 from optimum.quanto import qfloat8
 from PIL import Image
 from safetensors.torch import load_file
@@ -52,17 +52,14 @@ class QuantizedModelForTextEncoding(FastQuantizedTransformersModel):
 class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     """Text-to-image generation using a FLUX model."""

-    flux_model: ModelIdentifierField = InputField(
-        description="The Flux model",
-        input=Input.Any,
-        ui_type=UIType.FluxMainModel
+    transformer: TransformerField = InputField(
+        description=FieldDescriptions.unet,
+        input=Input.Connection,
+        title="Transformer",
     )
-    model: TFluxModelKeys = InputField(description="The FLUX model to use for text-to-image generation.")
-    quantization_type: Literal["raw", "NF4", "llm_int8"] = InputField(
-        default="raw", description="The type of quantization to use for the transformer model."
-    )
-    use_8bit: bool = InputField(
-        default=False, description="Whether to quantize the transformer model to 8-bit precision."
+    vae: VAEField = InputField(
+        description=FieldDescriptions.vae,
+        input=Input.Connection,
     )
     positive_text_conditioning: ConditioningField = InputField(
         description=FieldDescriptions.positive_cond, input=Input.Connection
@@ -78,70 +75,38 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):

     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ImageOutput:
-        # model_path = context.models.download_and_cache_model(FLUX_MODELS[self.model])
-        flux_transformer_path = context.models.download_and_cache_model(
-            "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/flux1-schnell.safetensors"
-        )
-        flux_ae_path = context.models.download_and_cache_model(
-            "https://huggingface.co/black-forest-labs/FLUX.1-schnell/resolve/main/ae.safetensors"
-        )

         # Load the conditioning data.
         cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
         assert len(cond_data.conditionings) == 1
         flux_conditioning = cond_data.conditionings[0]
         assert isinstance(flux_conditioning, FLUXConditioningInfo)

-        latents = self._run_diffusion(
-            context, flux_transformer_path, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds
-        )
-        image = self._run_vae_decoding(context, flux_ae_path, latents)
+        latents = self._run_diffusion(context, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds)
+        image = self._run_vae_decoding(context, latents)
         image_dto = context.images.save(image=image)
         return ImageOutput.build(image_dto)

     def _run_diffusion(
         self,
         context: InvocationContext,
-        flux_transformer_path: Path,
         clip_embeddings: torch.Tensor,
         t5_embeddings: torch.Tensor,
     ):
-        inference_dtype = TorchDevice.choose_torch_dtype()
-
-        # Prepare input noise.
-        # TODO(ryand): Does the seed behave the same on different devices? Should we re-implement this to always use a
-        # CPU RNG?
-        x = get_noise(
-            num_samples=1,
-            height=self.height,
-            width=self.width,
-            device=TorchDevice.choose_torch_device(),
-            dtype=inference_dtype,
-            seed=self.seed,
-        )
-
-        img, img_ids = self._prepare_latent_img_patches(x)
-
-        # HACK(ryand): Find a better way to determine if this is a schnell model or not.
-        is_schnell = "shnell" in str(flux_transformer_path)
-        timesteps = get_schedule(
-            num_steps=self.num_steps,
-            image_seq_len=img.shape[1],
-            shift=not is_schnell,
-        )
-
-        bs, t5_seq_len, _ = t5_embeddings.shape
-        txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())
+        scheduler_info = context.models.load(self.transformer.scheduler)
+        transformer_info = context.models.load(self.transformer.transformer)

         # HACK(ryand): Manually empty the cache. Currently we don't check the size of the model before loading it from
         # disk. Since the transformer model is large (24GB), there's a good chance that it will OOM on 32GB RAM systems
         # if the cache is not empty.
-        context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)
+        # context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)

-        with context.models.load_local_model(
-            model_path=flux_transformer_path, loader=self._load_flux_transformer
-        ) as transformer:
-            assert isinstance(transformer, Flux)
+        with (
+            transformer_info as transformer,
+            scheduler_info as scheduler
+        ):
+            assert isinstance(transformer, FluxTransformer2DModel)
+            assert isinstance(scheduler, FlowMatchEulerDiscreteScheduler)

             x = denoise(
                 model=transformer,
@@ -185,75 +150,25 @@ def _prepare_latent_img_patches(self, latent_img: torch.Tensor) -> tuple[torch.T
     def _run_vae_decoding(
         self,
         context: InvocationContext,
-        flux_ae_path: Path,
         latents: torch.Tensor,
     ) -> Image.Image:
-        with context.models.load_local_model(model_path=flux_ae_path, loader=self._load_flux_vae) as vae:
-            assert isinstance(vae, AutoEncoder)
-            # TODO(ryand): Test that this works with both float16 and bfloat16.
-            with torch.autocast(device_type=latents.device.type, dtype=TorchDevice.choose_torch_dtype()):
-                img = vae.decode(latents)
+        vae_info = context.models.load(self.vae.vae)
+        with vae_info as vae:
+            assert isinstance(vae, AutoencoderKL)

         img.clamp(-1, 1)
         img = rearrange(img[0], "c h w -> h w c")
         img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())

-        return img_pil
-
-    def _load_flux_transformer(self, path: Path) -> FluxTransformer2DModel:
-        inference_dtype = TorchDevice.choose_torch_dtype()
-        if self.quantization_type == "raw":
-            # TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
-            params = flux_configs["flux-schnell"].params
-
-            # Initialize the model on the "meta" device.
-            with accelerate.init_empty_weights():
-                model = Flux(params).to(inference_dtype)
-
-            state_dict = load_file(path)
-            # TODO(ryand): Cast the state_dict to the appropriate dtype?
-            model.load_state_dict(state_dict, strict=True, assign=True)
-        elif self.quantization_type == "NF4":
-            model_path = path.parent / "bnb_nf4.safetensors"
-
-            # TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
-            params = flux_configs["flux-schnell"].params
-            # Initialize the model on the "meta" device.
-            with accelerate.init_empty_weights():
-                model = Flux(params)
-            model = quantize_model_nf4(model, modules_to_not_convert=set(), compute_dtype=torch.bfloat16)
-
-            # TODO(ryand): Right now, some of the weights are loaded in bfloat16. Think about how best to handle
-            # this on GPUs without bfloat16 support.
-            state_dict = load_file(model_path)
-            model.load_state_dict(state_dict, strict=True, assign=True)
-
-        elif self.quantization_type == "llm_int8":
-            raise NotImplementedError("LLM int8 quantization is not yet supported.")
-            # model_config = FluxTransformer2DModel.load_config(path, local_files_only=True)
-            # with accelerate.init_empty_weights():
-            #     empty_model = FluxTransformer2DModel.from_config(model_config)
-            # assert isinstance(empty_model, FluxTransformer2DModel)
-            # model_int8_path = path / "bnb_llm_int8"
-            # assert model_int8_path.exists()
-            # with accelerate.init_empty_weights():
-            #     model = quantize_model_llm_int8(empty_model, modules_to_not_convert=set())
-
-            # sd = load_file(model_int8_path / "model.safetensors")
-            # model.load_state_dict(sd, strict=True, assign=True)
-        else:
-            raise ValueError(f"Unsupported quantization type: {self.quantization_type}")
-
-        assert isinstance(model, FluxTransformer2DModel)
-        return model
-
-    @staticmethod
-    def _load_flux_vae(path: Path) -> AutoEncoder:
-        # TODO(ryand): Determine if this is a schnell model or a dev model and load the appropriate config.
-        ae_params = flux_configs["flux1-schnell"].ae_params
-        with accelerate.init_empty_weights():
-            ae = AutoEncoder(ae_params)
-
-        state_dict = load_file(path)
-        ae.load_state_dict(state_dict, strict=True, assign=True)
-        return ae
+            latents = flux_pipeline_with_vae._unpack_latents(
+                latents, self.height, self.width, flux_pipeline_with_vae.vae_scale_factor
+            )
+            latents = (
+                latents / flux_pipeline_with_vae.vae.config.scaling_factor
+            ) + flux_pipeline_with_vae.vae.config.shift_factor
+            latents = latents.to(dtype=vae.dtype)
+            image = flux_pipeline_with_vae.vae.decode(latents, return_dict=False)[0]
+            image = flux_pipeline_with_vae.image_processor.postprocess(image, output_type="pil")[0]
+
+            assert isinstance(image, Image.Image)
+            return image

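The deleted _load_flux_text_encoder_2 and _load_flux_transformer helpers above contained the quantize-on-first-load logic (optimum-quanto qfloat8 for the T5 encoder, bitsandbytes NF4 for the transformer) that the commit title says is now the model loader's job. As a rough reference for what that work involves, here is a generic optimum-quanto sketch of 8-bit quantize-on-load for a T5 encoder. It uses only public optimum.quanto calls (quantize, freeze, qfloat8); the checkpoint id is a small stand-in, and nothing here is taken from InvokeAI's actual loader.

# Generic qfloat8 quantize-on-load sketch (illustrative; not InvokeAI's loader implementation).
import torch
from optimum.quanto import freeze, qfloat8, quantize
from transformers import T5EncoderModel


def load_t5_encoder_qfloat8(checkpoint: str) -> T5EncoderModel:
    model = T5EncoderModel.from_pretrained(checkpoint, torch_dtype=torch.bfloat16)
    quantize(model, weights=qfloat8)  # swap Linear weights for float8 quantized tensors
    freeze(model)  # materialize the quantized weights so they can be saved or reused
    return model


if __name__ == "__main__":
    # "google/t5-v1_1-small" is a small stand-in checkpoint for demonstration purposes.
    encoder = load_t5_encoder_qfloat8("google/t5-v1_1-small")
    print(type(encoder).__name__, sum(p.numel() for p in encoder.parameters()))

Persisting the frozen weights (as the removed helper did via save_pretrained on a wrapper class) is what makes subsequent loads cheap; presumably the loader now performs an equivalent caching step.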