Commit b5f35ed

Manage quantization of models within the loader
1 parent 643e082 commit b5f35ed

File tree: 9 files changed, +249 -291 lines changed

invokeai/app/invocations/fields.py

Lines changed: 1 addition & 0 deletions
@@ -126,6 +126,7 @@ class FieldDescriptions:
     negative_cond = "Negative conditioning tensor"
     noise = "Noise tensor"
     clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
+    t5Encoder = "T5 tokenizer and text encoder"
     unet = "UNet (scheduler, LoRAs)"
     transformer = "Transformer"
     vae = "VAE"

invokeai/app/invocations/flux_text_encoder.py

Lines changed: 39 additions & 73 deletions
@@ -6,8 +6,10 @@
 from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
-from invokeai.app.invocations.fields import InputField
-from invokeai.app.invocations.flux_text_to_image import FLUX_MODELS, QuantizedModelForTextEncoding, TFluxModelKeys
+from invokeai.app.invocations.model import CLIPField, T5EncoderField
+from invokeai.app.invocations.fields import InputField, FieldDescriptions, Input
+from invokeai.app.invocations.flux_text_to_image import FLUX_MODELS, QuantizedModelForTextEncoding
+from invokeai.app.invocations.model import CLIPField, T5EncoderField
 from invokeai.app.invocations.primitives import ConditioningOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo
@@ -22,57 +24,59 @@
     version="1.0.0",
 )
 class FluxTextEncoderInvocation(BaseInvocation):
-    model: TFluxModelKeys = InputField(description="The FLUX model to use for text-to-image generation.")
-    use_8bit: bool = InputField(
-        default=False, description="Whether to quantize the transformer model to 8-bit precision."
+    clip: CLIPField = InputField(
+        title="CLIP",
+        description=FieldDescriptions.clip,
+        input=Input.Connection,
+    )
+    t5Encoder: T5EncoderField = InputField(
+        title="T5EncoderField",
+        description=FieldDescriptions.t5Encoder,
+        input=Input.Connection,
     )
     positive_prompt: str = InputField(description="Positive prompt for text-to-image generation.")
 
     # TODO(ryand): Should we create a new return type for this invocation? This ConditioningOutput is clearly not
     # compatible with other ConditioningOutputs.
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ConditioningOutput:
-        model_path = context.models.download_and_cache_model(FLUX_MODELS[self.model])
 
-        t5_embeddings, clip_embeddings = self._encode_prompt(context, model_path)
+        t5_embeddings, clip_embeddings = self._encode_prompt(context)
         conditioning_data = ConditioningFieldData(
             conditionings=[FLUXConditioningInfo(clip_embeds=clip_embeddings, t5_embeds=t5_embeddings)]
         )
 
         conditioning_name = context.conditioning.save(conditioning_data)
         return ConditioningOutput.build(conditioning_name)
+
+    def _encode_prompt(self, context: InvocationContext) -> tuple[torch.Tensor, torch.Tensor]:
+        # TODO: Determine the T5 max sequence length based on the model.
+        # if self.model == "flux-schnell":
+        max_seq_len = 256
+        # # elif self.model == "flux-dev":
+        # #     max_seq_len = 512
+        # else:
+        #     raise ValueError(f"Unknown model: {self.model}")
+
+        # Load CLIP.
+        clip_tokenizer_info = context.models.load(self.clip.tokenizer)
+        clip_text_encoder_info = context.models.load(self.clip.text_encoder)
+
+        # Load T5.
+        t5_tokenizer_info = context.models.load(self.t5Encoder.tokenizer)
+        t5_text_encoder_info = context.models.load(self.t5Encoder.text_encoder)
 
-    def _encode_prompt(self, context: InvocationContext, flux_model_dir: Path) -> tuple[torch.Tensor, torch.Tensor]:
-        # Determine the T5 max sequence length based on the model.
-        if self.model == "flux-schnell":
-            max_seq_len = 256
-        # elif self.model == "flux-dev":
-        #     max_seq_len = 512
-        else:
-            raise ValueError(f"Unknown model: {self.model}")
-
-        # Load the CLIP tokenizer.
-        clip_tokenizer_path = flux_model_dir / "tokenizer"
-        clip_tokenizer = CLIPTokenizer.from_pretrained(clip_tokenizer_path, local_files_only=True)
-        assert isinstance(clip_tokenizer, CLIPTokenizer)
-
-        # Load the T5 tokenizer.
-        t5_tokenizer_path = flux_model_dir / "tokenizer_2"
-        t5_tokenizer = T5TokenizerFast.from_pretrained(t5_tokenizer_path, local_files_only=True)
-        assert isinstance(t5_tokenizer, T5TokenizerFast)
-
-        clip_text_encoder_path = flux_model_dir / "text_encoder"
-        t5_text_encoder_path = flux_model_dir / "text_encoder_2"
         with (
-            context.models.load_local_model(
-                model_path=clip_text_encoder_path, loader=self._load_flux_text_encoder
-            ) as clip_text_encoder,
-            context.models.load_local_model(
-                model_path=t5_text_encoder_path, loader=self._load_flux_text_encoder_2
-            ) as t5_text_encoder,
+            clip_text_encoder_info as clip_text_encoder,
+            t5_text_encoder_info as t5_text_encoder,
+            clip_tokenizer_info as clip_tokenizer,
+            t5_tokenizer_info as t5_tokenizer,
         ):
             assert isinstance(clip_text_encoder, CLIPTextModel)
             assert isinstance(t5_text_encoder, T5EncoderModel)
+            assert isinstance(clip_tokenizer, CLIPTokenizer)
+            assert isinstance(t5_tokenizer, T5TokenizerFast)
+
             pipeline = FluxPipeline(
                 scheduler=None,
                 vae=None,
@@ -85,7 +89,7 @@ def _encode_prompt(self, context: InvocationContext, flux_model_dir: Path) -> tu
 
             # prompt_embeds: T5 embeddings
             # pooled_prompt_embeds: CLIP embeddings
-            prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
+            prompt_embeds, pooled_prompt_embeds, _ = pipeline.encode_prompt(
                 prompt=self.positive_prompt,
                 prompt_2=self.positive_prompt,
                 device=TorchDevice.choose_torch_device(),
@@ -95,41 +99,3 @@ def _encode_prompt(self, context: InvocationContext, flux_model_dir: Path) -> tu
             assert isinstance(prompt_embeds, torch.Tensor)
             assert isinstance(pooled_prompt_embeds, torch.Tensor)
             return prompt_embeds, pooled_prompt_embeds
-
-    @staticmethod
-    def _load_flux_text_encoder(path: Path) -> CLIPTextModel:
-        model = CLIPTextModel.from_pretrained(path, local_files_only=True)
-        assert isinstance(model, CLIPTextModel)
-        return model
-
-    def _load_flux_text_encoder_2(self, path: Path) -> T5EncoderModel:
-        if self.use_8bit:
-            model_8bit_path = path / "quantized"
-            if model_8bit_path.exists():
-                # The quantized model exists, load it.
-                # TODO(ryand): The requantize(...) operation in from_pretrained(...) is very slow. This seems like
-                # something that we should be able to make much faster.
-                q_model = QuantizedModelForTextEncoding.from_pretrained(model_8bit_path)
-
-                # Access the underlying wrapped model.
-                # We access the wrapped model, even though it is private, because it simplifies the type checking by
-                # always returning a T5EncoderModel from this function.
-                model = q_model._wrapped
-            else:
-                # The quantized model does not exist yet, quantize and save it.
-                # TODO(ryand): dtype?
-                model = T5EncoderModel.from_pretrained(path, local_files_only=True)
-                assert isinstance(model, T5EncoderModel)
-
-                q_model = QuantizedModelForTextEncoding.quantize(model, weights=qfloat8)
-
-                model_8bit_path.mkdir(parents=True, exist_ok=True)
-                q_model.save_pretrained(model_8bit_path)
-
-                # (See earlier comment about accessing the wrapped model.)
-                model = q_model._wrapped
-        else:
-            model = T5EncoderModel.from_pretrained(path, local_files_only=True)
-
-        assert isinstance(model, T5EncoderModel)
-        return model
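The removed _load_flux_text_encoder_2 helper captures the quantize-or-reuse pattern that this commit begins moving into the model loader. As a point of reference, a minimal standalone sketch of the same idea using the public optimum-quanto functional API (the model directory and dtype choices are illustrative, not taken from this commit):

# Sketch only: quantize a T5 encoder to qfloat8 in place, mirroring the removed helper
# but using optimum.quanto's quantize/freeze functions instead of the wrapper class.
from pathlib import Path

from optimum.quanto import freeze, qfloat8, quantize
from transformers import T5EncoderModel


def load_t5_qfloat8(model_dir: Path) -> T5EncoderModel:
    model = T5EncoderModel.from_pretrained(model_dir, local_files_only=True)
    quantize(model, weights=qfloat8)  # replace eligible Linear weights with qfloat8 tensors
    freeze(model)  # materialize the quantized weights so they can be used or serialized
    return model

The wrapper classes referenced in the removed code (QuantizedModelForTextEncoding, QuantizedFluxTransformer2DModel) layer from_pretrained/save_pretrained persistence on top of this, which is what avoids re-running the slow requantize step on every load.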

invokeai/app/invocations/flux_text_to_image.py

Lines changed: 20 additions & 61 deletions
@@ -6,7 +6,7 @@
 from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
 from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
 from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
-from invokeai.app.invocations.model import ModelIdentifierField
+from invokeai.app.invocations.model import TransformerField, VAEField
 from optimum.quanto import qfloat8
 from PIL import Image
 from transformers.models.auto import AutoModelForTextEncoding
@@ -49,14 +49,14 @@ class QuantizedModelForTextEncoding(FastQuantizedTransformersModel):
 class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
     """Text-to-image generation using a FLUX model."""
 
-    flux_model: ModelIdentifierField = InputField(
-        description="The Flux model",
-        input=Input.Any,
-        ui_type=UIType.FluxMainModel
+    transformer: TransformerField = InputField(
+        description=FieldDescriptions.unet,
+        input=Input.Connection,
+        title="Transformer",
     )
-    model: TFluxModelKeys = InputField(description="The FLUX model to use for text-to-image generation.")
-    use_8bit: bool = InputField(
-        default=False, description="Whether to quantize the transformer model to 8-bit precision."
+    vae: VAEField = InputField(
+        description=FieldDescriptions.vae,
+        input=Input.Connection,
     )
     positive_text_conditioning: ConditioningField = InputField(
         description=FieldDescriptions.positive_cond, input=Input.Connection
@@ -72,38 +72,38 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
 
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ImageOutput:
-        model_path = context.models.download_and_cache_model(FLUX_MODELS[self.model])
 
         # Load the conditioning data.
         cond_data = context.conditioning.load(self.positive_text_conditioning.conditioning_name)
         assert len(cond_data.conditionings) == 1
         flux_conditioning = cond_data.conditionings[0]
         assert isinstance(flux_conditioning, FLUXConditioningInfo)
 
-        latents = self._run_diffusion(context, model_path, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds)
-        image = self._run_vae_decoding(context, model_path, latents)
+        latents = self._run_diffusion(context, flux_conditioning.clip_embeds, flux_conditioning.t5_embeds)
+        image = self._run_vae_decoding(context, latents)
         image_dto = context.images.save(image=image)
         return ImageOutput.build(image_dto)
 
     def _run_diffusion(
         self,
         context: InvocationContext,
-        flux_model_dir: Path,
         clip_embeddings: torch.Tensor,
         t5_embeddings: torch.Tensor,
     ):
-        scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(flux_model_dir / "scheduler", local_files_only=True)
+        scheduler_info = context.models.load(self.transformer.scheduler)
+        transformer_info = context.models.load(self.transformer.transformer)
 
         # HACK(ryand): Manually empty the cache. Currently we don't check the size of the model before loading it from
         # disk. Since the transformer model is large (24GB), there's a good chance that it will OOM on 32GB RAM systems
         # if the cache is not empty.
-        context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)
+        # context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)
 
-        transformer_path = flux_model_dir / "transformer"
-        with context.models.load_local_model(
-            model_path=transformer_path, loader=self._load_flux_transformer
-        ) as transformer:
+        with (
+            transformer_info as transformer,
+            scheduler_info as scheduler
+        ):
             assert isinstance(transformer, FluxTransformer2DModel)
+            assert isinstance(scheduler, FlowMatchEulerDiscreteScheduler)
 
             flux_pipeline_with_transformer = FluxPipeline(
                 scheduler=scheduler,
@@ -136,11 +136,10 @@ def _run_diffusion(
     def _run_vae_decoding(
         self,
         context: InvocationContext,
-        flux_model_dir: Path,
         latents: torch.Tensor,
     ) -> Image.Image:
-        vae_path = flux_model_dir / "vae"
-        with context.models.load_local_model(model_path=vae_path, loader=self._load_flux_vae) as vae:
+        vae_info = context.models.load(self.vae.vae)
+        with vae_info as vae:
             assert isinstance(vae, AutoencoderKL)
 
             flux_pipeline_with_vae = FluxPipeline(
@@ -165,43 +164,3 @@ def _run_vae_decoding(
 
             assert isinstance(image, Image.Image)
             return image
-
-    def _load_flux_transformer(self, path: Path) -> FluxTransformer2DModel:
-        if self.use_8bit:
-            model_8bit_path = path / "quantized"
-            if model_8bit_path.exists():
-                # The quantized model exists, load it.
-                # TODO(ryand): The requantize(...) operation in from_pretrained(...) is very slow. This seems like
-                # something that we should be able to make much faster.
-                q_model = QuantizedFluxTransformer2DModel.from_pretrained(model_8bit_path)
-
-                # Access the underlying wrapped model.
-                # We access the wrapped model, even though it is private, because it simplifies the type checking by
-                # always returning a FluxTransformer2DModel from this function.
-                model = q_model._wrapped
-            else:
-                # The quantized model does not exist yet, quantize and save it.
-                # TODO(ryand): Loading in float16 and then quantizing seems to result in NaNs. In order to run this on
-                # GPUs that don't support bfloat16, we would need to host the quantized model instead of generating it
-                # here.
-                model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
-                assert isinstance(model, FluxTransformer2DModel)
-
-                q_model = QuantizedFluxTransformer2DModel.quantize(model, weights=qfloat8)
-
-                model_8bit_path.mkdir(parents=True, exist_ok=True)
-                q_model.save_pretrained(model_8bit_path)
-
-                # (See earlier comment about accessing the wrapped model.)
-                model = q_model._wrapped
-        else:
-            model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
-
-        assert isinstance(model, FluxTransformer2DModel)
-        return model
-
-    @staticmethod
-    def _load_flux_vae(path: Path) -> AutoencoderKL:
-        model = AutoencoderKL.from_pretrained(path, local_files_only=True)
-        assert isinstance(model, AutoencoderKL)
-        return model
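Both _run_diffusion and _run_vae_decoding assemble a FluxPipeline around only the components the step actually needs, passing None for everything else. A hedged sketch of the decode-only construction (component names follow the diffusers FluxPipeline constructor; the checkpoint id is only an example, not part of this commit):

# Sketch only: a decode-only FluxPipeline, in the spirit of _run_vae_decoding. Every
# component except the VAE can be None because only the VAE decode step is exercised.
from diffusers import AutoencoderKL, FluxPipeline

vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="vae")
flux_pipeline_with_vae = FluxPipeline(
    scheduler=None,
    vae=vae,
    text_encoder=None,
    tokenizer=None,
    text_encoder_2=None,
    tokenizer_2=None,
    transformer=None,
)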

invokeai/app/invocations/model.py

Lines changed: 7 additions & 3 deletions
@@ -65,6 +65,10 @@ class TransformerField(BaseModel):
     transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
     scheduler: ModelIdentifierField = Field(description="Info to load scheduler submodel")
 
+class T5EncoderField(BaseModel):
+    tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel")
+    text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")
+
 
 class VAEField(BaseModel):
     vae: ModelIdentifierField = Field(description="Info to load vae submodel")
@@ -133,8 +137,8 @@ class FluxModelLoaderOutput(BaseInvocationOutput):
     """Flux base model loader output"""
 
     transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
-    clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 1")
-    clip2: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP 2")
+    clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP")
+    t5Encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5Encoder, title="T5 Encoder")
     vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
 
 
@@ -166,7 +170,7 @@ def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
         return FluxModelLoaderOutput(
             transformer=TransformerField(transformer=transformer, scheduler=scheduler),
             clip=CLIPField(tokenizer=tokenizer, text_encoder=text_encoder, loras=[], skipped_layers=0),
-            clip2=CLIPField(tokenizer=tokenizer2, text_encoder=text_encoder2, loras=[], skipped_layers=0),
+            t5Encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=text_encoder2),
             vae=VAEField(vae=vae),
         )
 

invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py

Lines changed: 6 additions & 1 deletion
@@ -78,7 +78,12 @@ def get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelTy
 
     # TO DO: Add exception handling
     def _hf_definition_to_type(self, module: str, class_name: str) -> ModelMixin:  # fix with correct type
-        if module in ["diffusers", "transformers"]:
+        if module in [
+            "diffusers",
+            "transformers",
+            "invokeai.backend.quantization.fast_quantized_transformers_model",
+            "invokeai.backend.quantization.fast_quantized_diffusion_model",
+        ]:
             res_type = sys.modules[module]
         else:
             res_type = sys.modules["diffusers"].pipelines
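For context, this helper turns a (module, class name) pair from a diffusers-style model index into a Python class via sys.modules; whitelisting the two invokeai.backend.quantization modules lets the quantized wrapper classes be resolved through the same path as stock diffusers/transformers classes. A rough sketch of that resolution step (the getattr lookup below the changed lines is assumed, since it is not shown in this hunk):

# Illustrative sketch of module/class resolution in the style of _hf_definition_to_type;
# not the exact body of the InvokeAI method.
import importlib
import sys


def resolve_class(module: str, class_name: str) -> type:
    if module not in sys.modules:
        importlib.import_module(module)  # make sure the module is importable and cached
    return getattr(sys.modules[module], class_name)


# e.g. resolve_class("transformers", "T5EncoderModel") returns the T5EncoderModel class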

invokeai/backend/model_manager/load/model_util.py

Lines changed: 8 additions & 1 deletion
@@ -9,7 +9,7 @@
 import torch
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline
 from diffusers.schedulers.scheduling_utils import SchedulerMixin
-from transformers import CLIPTokenizer
+from transformers import CLIPTokenizer, T5TokenizerFast
 
 from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
 from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline
@@ -48,6 +48,13 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
         ),
     ):
         return model.calc_size()
+    elif isinstance(
+        model,
+        (
+            T5TokenizerFast,
+        ),
+    ):
+        return len(model)
     else:
         # TODO(ryand): Promote this from a log to an exception once we are confident that we are handling all of the
         # supported model types.
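Note that for a Hugging Face fast tokenizer, len(...) returns the vocabulary size, so the "size" the cache records for a T5TokenizerFast is a token count rather than a byte count; tokenizers are negligible next to the models they accompany. A quick illustration (the checkpoint id is only an example, not from this commit):

# Rough illustration: len() of a fast tokenizer is its vocabulary size.
from transformers import T5TokenizerFast

tokenizer = T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl")  # any T5 repo works
print(len(tokenizer))  # vocabulary size, e.g. 32100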

0 commit comments
