
Commit a8a2fc1

Make quantized loading fast for both T5XXL and FLUX transformer.
1 parent d23ad18 commit a8a2fc1

File tree

invokeai/app/invocations/flux_text_to_image.py
invokeai/backend/quantization/fast_quantized_diffusion_model.py
invokeai/backend/quantization/fast_quantized_transformers_model.py

3 files changed: +142 −3 lines changed
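Per the docstrings in the two new modules below, the override exists so that loading runs through InvokeAI's custom `requantize()` implementation instead of optimum.quanto's stock one; the surrounding flow is kept: the model skeleton is built under accelerate's `init_empty_weights()` and the quantized weights are then filled in from the on-disk state dict. A minimal illustration of that primitive (the `Linear` module is a stand-in, not part of this commit):

import torch
from accelerate import init_empty_weights

# Stand-in module, only to show what init_empty_weights() does; not part of this commit.
with init_empty_weights():
    layer = torch.nn.Linear(4096, 4096)

# Parameters live on the meta device: they carry shape and dtype but no storage, so even a
# very large skeleton is created almost instantly and costs no memory until real weights arrive.
print(layer.weight.device)  # meta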

invokeai/app/invocations/flux_text_to_image.py

Lines changed: 4 additions & 3 deletions

@@ -6,7 +6,6 @@
 from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
 from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
 from optimum.quanto import qfloat8
-from optimum.quanto.models import QuantizedDiffusersModel, QuantizedTransformersModel
 from PIL import Image
 from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
 from transformers.models.auto import AutoModelForTextEncoding
@@ -15,17 +14,19 @@
 from invokeai.app.invocations.fields import InputField, WithBoard, WithMetadata
 from invokeai.app.invocations.primitives import ImageOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.quantization.fast_quantized_diffusion_model import FastQuantizedDiffusersModel
+from invokeai.backend.quantization.fast_quantized_transformers_model import FastQuantizedTransformersModel
 from invokeai.backend.util.devices import TorchDevice
 
 TFluxModelKeys = Literal["flux-schnell"]
 FLUX_MODELS: dict[TFluxModelKeys, str] = {"flux-schnell": "black-forest-labs/FLUX.1-schnell"}
 
 
-class QuantizedFluxTransformer2DModel(QuantizedDiffusersModel):
+class QuantizedFluxTransformer2DModel(FastQuantizedDiffusersModel):
     base_class = FluxTransformer2DModel
 
 
-class QuantizedModelForTextEncoding(QuantizedTransformersModel):
+class QuantizedModelForTextEncoding(FastQuantizedTransformersModel):
     auto_class = AutoModelForTextEncoding
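For context, a minimal consumer-side sketch (the directory path is hypothetical, and the class definition simply mirrors the one above): since only the parent class changes, callers keep loading a locally quantized FLUX transformer through `from_pretrained()`, which now takes the fast path.

from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel

from invokeai.backend.quantization.fast_quantized_diffusion_model import FastQuantizedDiffusersModel


# Mirrors the definition in flux_text_to_image.py above.
class QuantizedFluxTransformer2DModel(FastQuantizedDiffusersModel):
    base_class = FluxTransformer2DModel


# Hypothetical path to a previously quantized FLUX transformer directory
# (config.json, a quanto quantization map, and safetensors weights).
transformer = QuantizedFluxTransformer2DModel.from_pretrained("/models/quantized/flux-schnell-transformer")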
invokeai/backend/quantization/fast_quantized_diffusion_model.py

Lines changed: 77 additions & 0 deletions

@@ -0,0 +1,77 @@
+import json
+import os
+from typing import Union
+
+from diffusers.models.model_loading_utils import load_state_dict
+from diffusers.utils import (
+    CONFIG_NAME,
+    SAFE_WEIGHTS_INDEX_NAME,
+    SAFETENSORS_WEIGHTS_NAME,
+    _get_checkpoint_shard_files,
+    is_accelerate_available,
+)
+from optimum.quanto.models import QuantizedDiffusersModel
+from optimum.quanto.models.shared_dict import ShardedStateDict
+
+from invokeai.backend.requantize import requantize
+
+
+class FastQuantizedDiffusersModel(QuantizedDiffusersModel):
+    @classmethod
+    def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike]):
+        """We override the `from_pretrained()` method in order to use our custom `requantize()` implementation."""
+        if cls.base_class is None:
+            raise ValueError("The `base_class` attribute needs to be configured.")
+
+        if not is_accelerate_available():
+            raise ValueError("Reloading a quantized diffusers model requires the accelerate library.")
+        from accelerate import init_empty_weights
+
+        if os.path.isdir(model_name_or_path):
+            # Look for a quantization map
+            qmap_path = os.path.join(model_name_or_path, cls._qmap_name())
+            if not os.path.exists(qmap_path):
+                raise ValueError(f"No quantization map found in {model_name_or_path}: is this a quantized model ?")
+
+            # Look for original model config file.
+            model_config_path = os.path.join(model_name_or_path, CONFIG_NAME)
+            if not os.path.exists(model_config_path):
+                raise ValueError(f"{CONFIG_NAME} not found in {model_name_or_path}.")
+
+            with open(qmap_path, "r", encoding="utf-8") as f:
+                qmap = json.load(f)
+
+            with open(model_config_path, "r", encoding="utf-8") as f:
+                original_model_cls_name = json.load(f)["_class_name"]
+            configured_cls_name = cls.base_class.__name__
+            if configured_cls_name != original_model_cls_name:
+                raise ValueError(
+                    f"Configured base class ({configured_cls_name}) differs from what was derived from the provided configuration ({original_model_cls_name})."
+                )
+
+            # Create an empty model
+            config = cls.base_class.load_config(model_name_or_path)
+            with init_empty_weights():
+                model = cls.base_class.from_config(config)
+
+            # Look for the index of a sharded checkpoint
+            checkpoint_file = os.path.join(model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
+            if os.path.exists(checkpoint_file):
+                # Convert the checkpoint path to a list of shards
+                _, sharded_metadata = _get_checkpoint_shard_files(model_name_or_path, checkpoint_file)
+                # Create a mapping for the sharded safetensor files
+                state_dict = ShardedStateDict(model_name_or_path, sharded_metadata["weight_map"])
+            else:
+                # Look for a single checkpoint file
+                checkpoint_file = os.path.join(model_name_or_path, SAFETENSORS_WEIGHTS_NAME)
+                if not os.path.exists(checkpoint_file):
+                    raise ValueError(f"No safetensor weights found in {model_name_or_path}.")
+                # Get state_dict from model checkpoint
+                state_dict = load_state_dict(checkpoint_file)
+
+            # Requantize and load quantized weights from state_dict
+            requantize(model, state_dict=state_dict, quantization_map=qmap)
+            model.eval()
+            return cls(model)
+        else:
+            raise NotImplementedError("Reloading quantized models directly from the hub is not supported yet.")
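One way the directory this loader expects might be produced, shown as a hedged sketch: it assumes optimum.quanto's `quantize`/`freeze`/`quantization_map` API, the conventional `quanto_qmap.json` map name, and hypothetical paths; it is not part of this commit.

import json
import os

from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from optimum.quanto import freeze, qfloat8, quantization_map, quantize
from safetensors.torch import save_file

# Load the full-precision FLUX transformer (hypothetical source path).
model = FluxTransformer2DModel.from_pretrained("/models/flux-schnell/transformer")

# Quantize the weights to float8, then freeze so the quantized tensors are materialized.
quantize(model, weights=qfloat8)
freeze(model)

# Write the directory layout FastQuantizedDiffusersModel.from_pretrained() looks for:
# config.json, safetensors weights, and a quantization map (file name assumed to follow
# the optimum.quanto convention, quanto_qmap.json).
out_dir = "/models/quantized/flux-schnell-transformer"
os.makedirs(out_dir, exist_ok=True)
model.save_config(out_dir)
save_file(model.state_dict(), os.path.join(out_dir, "diffusion_pytorch_model.safetensors"))
with open(os.path.join(out_dir, "quanto_qmap.json"), "w", encoding="utf-8") as f:
    json.dump(quantization_map(model), f)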
invokeai/backend/quantization/fast_quantized_transformers_model.py

Lines changed: 61 additions & 0 deletions

@@ -0,0 +1,61 @@
+import json
+import os
+from typing import Union
+
+from optimum.quanto.models import QuantizedTransformersModel
+from optimum.quanto.models.shared_dict import ShardedStateDict
+from transformers import AutoConfig
+from transformers.modeling_utils import get_checkpoint_shard_files, load_state_dict
+from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, is_accelerate_available
+
+from invokeai.backend.requantize import requantize
+
+
+class FastQuantizedTransformersModel(QuantizedTransformersModel):
+    @classmethod
+    def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike]):
+        """We override the `from_pretrained()` method in order to use our custom `requantize()` implementation."""
+        if cls.auto_class is None:
+            raise ValueError(
+                "Quantized models cannot be reloaded using {cls}: use a specialized quantized class such as QuantizedModelForCausalLM instead."
+            )
+        if not is_accelerate_available():
+            raise ValueError("Reloading a quantized transformers model requires the accelerate library.")
+        from accelerate import init_empty_weights
+
+        if os.path.isdir(model_name_or_path):
+            # Look for a quantization map
+            qmap_path = os.path.join(model_name_or_path, cls._qmap_name())
+            if not os.path.exists(qmap_path):
+                raise ValueError(f"No quantization map found in {model_name_or_path}: is this a quantized model ?")
+            with open(qmap_path, "r", encoding="utf-8") as f:
+                qmap = json.load(f)
+            # Create an empty model
+            config = AutoConfig.from_pretrained(model_name_or_path)
+            with init_empty_weights():
+                model = cls.auto_class.from_config(config)
+            # Look for the index of a sharded checkpoint
+            checkpoint_file = os.path.join(model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
+            if os.path.exists(checkpoint_file):
+                # Convert the checkpoint path to a list of shards
+                checkpoint_file, sharded_metadata = get_checkpoint_shard_files(model_name_or_path, checkpoint_file)
+                # Create a mapping for the sharded safetensor files
+                state_dict = ShardedStateDict(model_name_or_path, sharded_metadata["weight_map"])
+            else:
+                # Look for a single checkpoint file
+                checkpoint_file = os.path.join(model_name_or_path, SAFE_WEIGHTS_NAME)
+                if not os.path.exists(checkpoint_file):
+                    raise ValueError(f"No safetensor weights found in {model_name_or_path}.")
+                # Get state_dict from model checkpoint
+                state_dict = load_state_dict(checkpoint_file)
+            # Requantize and load quantized weights from state_dict
+            requantize(model, state_dict=state_dict, quantization_map=qmap)
+            if getattr(model.config, "tie_word_embeddings", True):
+                # Tie output weight embeddings to input weight embeddings
+                # Note that if they were quantized they would NOT be tied
+                model.tie_weights()
+            # Set model in evaluation mode as it is done in transformers
+            model.eval()
+            return cls(model)
+        else:
+            raise NotImplementedError("Reloading quantized models directly from the hub is not supported yet.")
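And a matching consumer-side sketch for the T5XXL text encoder (hypothetical path; the class mirrors `QuantizedModelForTextEncoding` from flux_text_to_image.py): loading follows the same empty-skeleton-then-requantize path, with word-embedding tying applied afterwards as in the code above.

from transformers.models.auto import AutoModelForTextEncoding

from invokeai.backend.quantization.fast_quantized_transformers_model import FastQuantizedTransformersModel


# Mirrors the definition in flux_text_to_image.py above.
class QuantizedModelForTextEncoding(FastQuantizedTransformersModel):
    auto_class = AutoModelForTextEncoding


# Hypothetical path to a quantized T5XXL text-encoder directory
# (config.json, a quanto quantization map, and safetensors weights).
text_encoder = QuantizedModelForTextEncoding.from_pretrained("/models/quantized/t5xxl-encoder")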
