Commit 4181ab6: "WIP - experimentation"
1 parent 1c97360

2 files changed: +183 additions, -0 deletions


invokeai/backend/load_flux_model.py (new file: 129 additions, 0 deletions)

```python
import json
import os
import time
from pathlib import Path
from typing import Union

import torch
from diffusers.models.model_loading_utils import load_state_dict
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from diffusers.utils import (
    CONFIG_NAME,
    SAFE_WEIGHTS_INDEX_NAME,
    SAFETENSORS_WEIGHTS_NAME,
    _get_checkpoint_shard_files,
    is_accelerate_available,
)
from optimum.quanto import qfloat8
from optimum.quanto.models import QuantizedDiffusersModel
from optimum.quanto.models.shared_dict import ShardedStateDict

from invokeai.backend.requantize import requantize


class QuantizedFluxTransformer2DModel(QuantizedDiffusersModel):
    base_class = FluxTransformer2DModel

    @classmethod
    def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike]):
        if cls.base_class is None:
            raise ValueError("The `base_class` attribute needs to be configured.")

        if not is_accelerate_available():
            raise ValueError("Reloading a quantized diffusers model requires the accelerate library.")
        from accelerate import init_empty_weights

        if os.path.isdir(model_name_or_path):
            # Look for a quantization map
            qmap_path = os.path.join(model_name_or_path, cls._qmap_name())
            if not os.path.exists(qmap_path):
                raise ValueError(f"No quantization map found in {model_name_or_path}: is this a quantized model ?")

            # Look for original model config file.
            model_config_path = os.path.join(model_name_or_path, CONFIG_NAME)
            if not os.path.exists(model_config_path):
                raise ValueError(f"{CONFIG_NAME} not found in {model_name_or_path}.")

            with open(qmap_path, "r", encoding="utf-8") as f:
                qmap = json.load(f)

            with open(model_config_path, "r", encoding="utf-8") as f:
                original_model_cls_name = json.load(f)["_class_name"]
            configured_cls_name = cls.base_class.__name__
            if configured_cls_name != original_model_cls_name:
                raise ValueError(
                    f"Configured base class ({configured_cls_name}) differs from what was derived from the provided configuration ({original_model_cls_name})."
                )

            # Create an empty model
            config = cls.base_class.load_config(model_name_or_path)
            with init_empty_weights():
                model = cls.base_class.from_config(config)

            # Look for the index of a sharded checkpoint
            checkpoint_file = os.path.join(model_name_or_path, SAFE_WEIGHTS_INDEX_NAME)
            if os.path.exists(checkpoint_file):
                # Convert the checkpoint path to a list of shards
                _, sharded_metadata = _get_checkpoint_shard_files(model_name_or_path, checkpoint_file)
                # Create a mapping for the sharded safetensor files
                state_dict = ShardedStateDict(model_name_or_path, sharded_metadata["weight_map"])
            else:
                # Look for a single checkpoint file
                checkpoint_file = os.path.join(model_name_or_path, SAFETENSORS_WEIGHTS_NAME)
                if not os.path.exists(checkpoint_file):
                    raise ValueError(f"No safetensor weights found in {model_name_or_path}.")
                # Get state_dict from model checkpoint
                state_dict = load_state_dict(checkpoint_file)

            # Requantize and load quantized weights from state_dict
            requantize(model, state_dict=state_dict, quantization_map=qmap)
            model.eval()
            return cls(model)
        else:
            raise NotImplementedError("Reloading quantized models directly from the hub is not supported yet.")


def load_flux_transformer(path: Path) -> FluxTransformer2DModel:
    # model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
    model_8bit_path = path / "quantized"
    if model_8bit_path.exists():
        # The quantized model exists, load it.
        # TODO(ryand): The requantize(...) operation in from_pretrained(...) is very slow. This seems like
        # something that we should be able to make much faster.
        q_model = QuantizedFluxTransformer2DModel.from_pretrained(model_8bit_path)

        # Access the underlying wrapped model.
        # We access the wrapped model, even though it is private, because it simplifies the type checking by
        # always returning a FluxTransformer2DModel from this function.
        model = q_model._wrapped
    else:
        # The quantized model does not exist yet, quantize and save it.
        # TODO(ryand): Loading in float16 and then quantizing seems to result in NaNs. In order to run this on
        # GPUs that don't support bfloat16, we would need to host the quantized model instead of generating it
        # here.
        model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
        assert isinstance(model, FluxTransformer2DModel)

        q_model = QuantizedFluxTransformer2DModel.quantize(model, weights=qfloat8)

        model_8bit_path.mkdir(parents=True, exist_ok=True)
        q_model.save_pretrained(model_8bit_path)

        # (See earlier comment about accessing the wrapped model.)
        model = q_model._wrapped

    assert isinstance(model, FluxTransformer2DModel)
    return model


def main():
    start = time.time()
    model = load_flux_transformer(
        Path("/data/invokeai/models/.download_cache/black-forest-labs_flux.1-schnell/FLUX.1-schnell/transformer/")
    )
    print(f"Time to load: {time.time() - start}s")
    print("hi")


if __name__ == "__main__":
    main()
```
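For orientation, a minimal usage sketch of the loader above follows. The model directory is a hypothetical local path; on the first call the bfloat16 transformer is quantized to qfloat8 and cached under `<path>/quantized`, and later calls reload that cached copy. The quantized-submodule count is just a quick sanity check, not part of the commit.

```python
# Hedged usage sketch (hypothetical path): load the FLUX transformer and report
# how many of its submodules ended up as quanto-quantized modules, as a quick
# check that a qfloat8 copy (rather than the plain bfloat16 model) was produced.
from pathlib import Path

from optimum.quanto.nn import QModuleMixin

from invokeai.backend.load_flux_model import load_flux_transformer

transformer = load_flux_transformer(Path("/path/to/FLUX.1-schnell/transformer"))
num_quantized = sum(1 for m in transformer.modules() if isinstance(m, QModuleMixin))
print(f"{num_quantized} quantized submodules")
```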

invokeai/backend/requantize.py (new file: 54 additions, 0 deletions)

```python
from typing import Any, Dict

import torch
from optimum.quanto.nn import QModuleMixin
from optimum.quanto.quantize import _quantize_submodule, freeze


def custom_freeze(model: torch.nn.Module):
    for name, m in model.named_modules():
        if isinstance(m, QModuleMixin):
            m.freeze()


def requantize(
    model: torch.nn.Module,
    state_dict: Dict[str, Any],
    quantization_map: Dict[str, Dict[str, str]],
    device: torch.device = None,
):
    if device is None:
        device = next(model.parameters()).device
        if device.type == "meta":
            device = torch.device("cpu")

    # Quantize the model with parameters from the quantization map
    for name, m in model.named_modules():
        qconfig = quantization_map.get(name, None)
        if qconfig is not None:
            weights = qconfig["weights"]
            if weights == "none":
                weights = None
            activations = qconfig["activations"]
            if activations == "none":
                activations = None
            _quantize_submodule(model, name, m, weights=weights, activations=activations)

    # Move model parameters and buffers to CPU before materializing quantized weights
    for name, m in model.named_modules():

        def move_tensor(t, device):
            if t.device.type == "meta":
                return torch.empty_like(t, device=device)
            return t.to(device)

        for name, param in m.named_parameters(recurse=False):
            setattr(m, name, torch.nn.Parameter(move_tensor(param, "cpu")))
        for name, param in m.named_buffers(recurse=False):
            setattr(m, name, move_tensor(param, "cpu"))
    # Freeze model and move to target device
    freeze(model)
    model.to(device)

    # Load the quantized model weights
    model.load_state_dict(state_dict, strict=False)
```
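For context, a hedged sketch of the round trip that `requantize(...)` is meant for, using the public optimum.quanto helpers on a small stand-in module rather than the FLUX transformer. `TinyModel` and the file names are illustrative only; the save/reload pattern follows the serialization workflow documented by optimum.quanto (safetensors state dict plus a JSON quantization map).

```python
# Hedged sketch, assuming a tiny stand-in model: quantize, freeze, and persist the
# weights and quantization map, then rebuild an empty copy and requantize it.
import json

import torch
from accelerate import init_empty_weights
from optimum.quanto import freeze, qfloat8, quantization_map, quantize
from safetensors.torch import load_file, save_file

from invokeai.backend.requantize import requantize


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(16, 16)

    def forward(self, x):
        return self.proj(x)


# Producer side: quantize and freeze, then save the state dict and quantization map.
model = TinyModel()
quantize(model, weights=qfloat8)
freeze(model)
save_file(model.state_dict(), "tiny_qfloat8.safetensors")
with open("tiny_qmap.json", "w", encoding="utf-8") as f:
    json.dump(quantization_map(model), f)

# Consumer side: build an empty (meta-device) copy and requantize it from disk,
# mirroring how QuantizedFluxTransformer2DModel.from_pretrained uses requantize above.
with init_empty_weights():
    empty_model = TinyModel()
with open("tiny_qmap.json", "r", encoding="utf-8") as f:
    qmap = json.load(f)
state_dict = load_file("tiny_qfloat8.safetensors")
requantize(empty_model, state_dict=state_dict, quantization_map=qmap)
empty_model.eval()
```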
