Commit 45f5af3

Add nf4 bnb quantized format
1 parent 9c4576c commit 45f5af3

File tree

3 files changed (+26 -7 lines)

invokeai/app/invocations/model.py

Lines changed: 2 additions & 2 deletions
@@ -136,12 +136,12 @@ def invoke(self, context: InvocationContext) -> ModelIdentifierOutput:
 T5_ENCODER_OPTIONS = Literal["base", "8b_quantized"]
 T5_ENCODER_MAP: Dict[str, Dict[str, str]] = {
     "base": {
-        "repo": "invokeai/flux_dev::t5_xxl_encoder/base",
+        "repo": "InvokeAI/flux_schnell::t5_xxl_encoder/base",
         "name": "t5_base_encoder",
         "format": ModelFormat.T5Encoder,
     },
     "8b_quantized": {
-        "repo": "invokeai/flux_dev::t5_xxl_encoder/8b_quantized",
+        "repo": "invokeai/flux_dev::t5_xxl_encoder/optimum_quanto_qfloat8",
         "name": "t5_8b_quantized_encoder",
         "format": ModelFormat.T5Encoder,
     },
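
A note on the "repo" values above (not part of the diff): they appear to combine a Hugging Face repo id with a path inside that repo, joined by "::". A minimal sketch of how such a source string could be split before download follows; the helper name is hypothetical and not InvokeAI's actual installer API.

from typing import Tuple


def split_repo_source(source: str) -> Tuple[str, str]:
    """Split an 'org/repo::subfolder' source string into (repo_id, subfolder)."""
    repo_id, _, subfolder = source.partition("::")
    return repo_id, subfolder


repo_id, subfolder = split_repo_source("invokeai/flux_dev::t5_xxl_encoder/optimum_quanto_qfloat8")
assert repo_id == "invokeai/flux_dev"
assert subfolder == "t5_xxl_encoder/optimum_quanto_qfloat8"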

invokeai/backend/model_manager/config.py

Lines changed: 18 additions & 3 deletions
@@ -111,6 +111,7 @@ class ModelFormat(str, Enum):
     T5Encoder = "t5_encoder"
     T5Encoder8b = "t5_encoder_8b"
     T5Encoder4b = "t5_encoder_4b"
+    BnbQuantizednf4b = "bnb_quantized_nf4b"


 class SchedulerPredictionType(str, Enum):
@@ -193,7 +194,7 @@ def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> N
 class CheckpointConfigBase(ModelConfigBase):
     """Model config for checkpoint-style models."""

-    format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
+    format: Literal[ModelFormat.Checkpoint, ModelFormat.BnbQuantizednf4b] = Field(description="Format of the provided checkpoint model", default=ModelFormat.Checkpoint)
     config_path: str = Field(description="path to the checkpoint model config file")
     converted_at: Optional[float] = Field(
         description="When this model was last converted to diffusers", default_factory=time.time
@@ -248,7 +249,6 @@ class VAECheckpointConfig(CheckpointConfigBase):
     """Model config for standalone VAE models."""

     type: Literal[ModelType.VAE] = ModelType.VAE
-    format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint

     @staticmethod
     def get_tag() -> Tag:
@@ -287,7 +287,6 @@ class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase)
     """Model config for ControlNet models (diffusers version)."""

     type: Literal[ModelType.ControlNet] = ModelType.ControlNet
-    format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint

     @staticmethod
     def get_tag() -> Tag:
@@ -336,6 +335,21 @@ def get_tag() -> Tag:
         return Tag(f"{ModelType.Main.value}.{ModelFormat.Checkpoint.value}")


+class MainBnbQuantized4bCheckpointConfig(CheckpointConfigBase, MainConfigBase):
+    """Model config for main checkpoint models."""
+
+    prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
+    upcast_attention: bool = False
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.format = ModelFormat.BnbQuantizednf4b
+
+    @staticmethod
+    def get_tag() -> Tag:
+        return Tag(f"{ModelType.Main.value}.{ModelFormat.BnbQuantizednf4b.value}")
+
+
 class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase):
     """Model config for main diffusers models."""

@@ -438,6 +452,7 @@ def get_model_discriminator_value(v: Any) -> str:
     Union[
         Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()],
         Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()],
+        Annotated[MainBnbQuantized4bCheckpointConfig, MainBnbQuantized4bCheckpointConfig.get_tag()],
         Annotated[VAEDiffusersConfig, VAEDiffusersConfig.get_tag()],
         Annotated[VAECheckpointConfig, VAECheckpointConfig.get_tag()],
         Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
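
The new config class is selected at validation time by the "<type>.<format>" string that get_tag() returns. Below is a minimal, self-contained sketch of that pydantic v2 Tag/Discriminator dispatch; the class and field names are illustrative stand-ins, not InvokeAI's real config classes.

from typing import Annotated, Any, Union

from pydantic import BaseModel, Discriminator, Tag, TypeAdapter


class CheckpointCfg(BaseModel):
    type: str = "main"
    format: str = "checkpoint"


class BnbNf4Cfg(BaseModel):
    type: str = "main"
    format: str = "bnb_quantized_nf4b"


def discriminator_value(v: Any) -> str:
    # Plays the role of get_model_discriminator_value(): derive "<type>.<format>" from the payload.
    if isinstance(v, dict):
        return f"{v['type']}.{v['format']}"
    return f"{v.type}.{v.format}"


AnyCfg = Annotated[
    Union[
        Annotated[CheckpointCfg, Tag("main.checkpoint")],
        Annotated[BnbNf4Cfg, Tag("main.bnb_quantized_nf4b")],
    ],
    Discriminator(discriminator_value),
]

adapter = TypeAdapter(AnyCfg)
cfg = adapter.validate_python({"type": "main", "format": "bnb_quantized_nf4b"})
assert isinstance(cfg, BnbNf4Cfg)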

invokeai/backend/model_manager/probe.py

Lines changed: 6 additions & 2 deletions
@@ -162,7 +162,7 @@ def probe(
         fields["description"] = (
             fields.get("description") or f"{fields['base'].value} {model_type.value} model {fields['name']}"
         )
-        fields["format"] = ModelFormat(fields.get("format")) or probe.get_format()
+        fields["format"] = ModelFormat(fields.get("format")) if "format" in fields else probe.get_format()
         fields["hash"] = fields.get("hash") or ModelHash(algorithm=hash_algo).hash(model_path)

         fields["default_settings"] = fields.get("default_settings")
@@ -179,7 +179,7 @@ def probe(
         # additional fields needed for main and controlnet models
         if (
             fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.VAE]
-            and fields["format"] is ModelFormat.Checkpoint
+            and fields["format"] in [ModelFormat.Checkpoint, ModelFormat.BnbQuantizednf4b]
         ):
             ckpt_config_path = cls._get_checkpoint_config_path(
                 model_path,
@@ -323,6 +323,7 @@ def _get_checkpoint_config_path(

         if model_type is ModelType.Main:
             if base_type == BaseModelType.Flux:
+                # TODO: Decide between dev/schnell
                 config_file = "flux/flux1-schnell.yaml"
             else:
                 config_file = LEGACY_CONFIGS[base_type][variant_type]
@@ -422,6 +423,9 @@ def __init__(self, model_path: Path):
         self.checkpoint = ModelProbe._scan_and_load_checkpoint(model_path)

     def get_format(self) -> ModelFormat:
+        state_dict = self.checkpoint.get("state_dict") or self.checkpoint
+        if "double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4" in state_dict:
+            return ModelFormat.BnbQuantizednf4b
         return ModelFormat("checkpoint")

     def get_variant_type(self) -> ModelVariantType:
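
The probe keys off a bitsandbytes serialization detail: checkpoints saved with 4-bit nf4 quantized layers carry per-weight quantization state under keys ending in ".quant_state.bitsandbytes__nf4", as shown by the single FLUX tensor checked above. A hedged sketch of the general pattern; this illustrative helper scans for the suffix anywhere in the state dict rather than checking one known key.

from typing import Any, Dict


def looks_like_bnb_nf4(state_dict: Dict[str, Any]) -> bool:
    # bitsandbytes stores nf4 quantization state alongside each quantized weight tensor.
    return any(key.endswith(".quant_state.bitsandbytes__nf4") for key in state_dict)


assert looks_like_bnb_nf4(
    {"double_blocks.0.img_attn.proj.weight.quant_state.bitsandbytes__nf4": b"..."}
)
assert not looks_like_bnb_nf4({"double_blocks.0.img_attn.proj.weight": None})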
