-import json
 from pathlib import Path
 from typing import Literal

 import torch
 from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
 from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
 from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
-from optimum.quanto import freeze, qfloat8, quantization_map, quantize, requantize
+from optimum.quanto import qfloat8
+from optimum.quanto.models import QuantizedDiffusersModel
 from PIL import Image
-from safetensors.torch import load_file, save_file
 from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

 from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation

 FLUX_MODELS: dict[TFluxModelKeys, str] = {"flux-schnell": "black-forest-labs/FLUX.1-schnell"}


+class QuantizedFluxTransformer2DModel(QuantizedDiffusersModel):
+    base_class = FluxTransformer2DModel
+
+
 @invocation(
     "flux_text_to_image",
     title="FLUX Text to Image",
@@ -202,23 +205,16 @@ def _load_flux_text_encoder_2(path: Path) -> T5EncoderModel:
     def _load_flux_transformer(self, path: Path) -> FluxTransformer2DModel:
         if self.use_8bit:
             model_8bit_path = path / "quantized"
-            model_8bit_weights_path = model_8bit_path / "weights.safetensors"
-            model_8bit_map_path = model_8bit_path / "quantization_map.json"
             if model_8bit_path.exists():
                 # The quantized model exists, load it.
-                # TODO(ryand): Make loading from quantized model work properly.
-                # Reference: https://gist.github.com/AmericanPresidentJimmyCarter/873985638e1f3541ba8b00137e7dacd9?permalink_comment_id=5141210#gistcomment-5141210
-                model = FluxTransformer2DModel.from_pretrained(
-                    path,
-                    local_files_only=True,
-                )
-                assert isinstance(model, FluxTransformer2DModel)
-                model = model.to(device=torch.device("meta"))
-
-                state_dict = load_file(model_8bit_weights_path)
-                with open(model_8bit_map_path, "r") as f:
-                    quant_map = json.load(f)
-                requantize(model=model, state_dict=state_dict, quantization_map=quant_map)
+                # TODO(ryand): The requantize(...) operation in from_pretrained(...) is very slow. This seems like
+                # something that we should be able to make much faster.
+                q_model = QuantizedFluxTransformer2DModel.from_pretrained(model_8bit_path)
+
+                # Access the underlying wrapped model.
+                # We access the wrapped model, even though it is private, because it simplifies the type checking by
+                # always returning a FluxTransformer2DModel from this function.
+                model = q_model._wrapped
             else:
                 # The quantized model does not exist yet, quantize and save it.
                 # TODO(ryand): Loading in float16 and then quantizing seems to result in NaNs. In order to run this on
@@ -227,13 +223,13 @@ def _load_flux_transformer(self, path: Path) -> FluxTransformer2DModel:
                 model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
                 assert isinstance(model, FluxTransformer2DModel)

-                quantize(model, weights=qfloat8)
-                freeze(model)
+                q_model = QuantizedFluxTransformer2DModel.quantize(model, weights=qfloat8)

                 model_8bit_path.mkdir(parents=True, exist_ok=True)
-                save_file(model.state_dict(), model_8bit_weights_path)
-                with open(model_8bit_map_path, "w") as f:
-                    json.dump(quantization_map(model), f)
+                q_model.save_pretrained(model_8bit_path)
+
+                # (See earlier comment about accessing the wrapped model.)
+                model = q_model._wrapped
         else:
             model = FluxTransformer2DModel.from_pretrained(path, local_files_only=True, torch_dtype=torch.bfloat16)
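For context, here is a minimal, self-contained sketch of the load-or-quantize flow this change moves to, consolidated from the hunks above. It assumes only the optimum.quanto API already used in the diff (QuantizedDiffusersModel, quantize, save_pretrained, from_pretrained, _wrapped); the helper name and paths are hypothetical.

from pathlib import Path

import torch
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from optimum.quanto import qfloat8
from optimum.quanto.models import QuantizedDiffusersModel


class QuantizedFluxTransformer2DModel(QuantizedDiffusersModel):
    # Tell optimum.quanto which diffusers class the quantized wrapper targets.
    base_class = FluxTransformer2DModel


def load_or_quantize(model_path: Path, quantized_path: Path) -> FluxTransformer2DModel:
    """Hypothetical helper mirroring the 8-bit branch of _load_flux_transformer."""
    if quantized_path.exists():
        # Reload previously quantized weights (the slow requantize step noted in the TODO above).
        q_model = QuantizedFluxTransformer2DModel.from_pretrained(quantized_path)
    else:
        # First run: load bfloat16 weights, quantize to float8, and persist the result.
        model = FluxTransformer2DModel.from_pretrained(model_path, torch_dtype=torch.bfloat16)
        q_model = QuantizedFluxTransformer2DModel.quantize(model, weights=qfloat8)
        quantized_path.mkdir(parents=True, exist_ok=True)
        q_model.save_pretrained(quantized_path)
    # Return the wrapped FluxTransformer2DModel so callers get a concrete model type.
    return q_model._wrapped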