@@ -111,6 +111,7 @@ class ModelFormat(str, Enum):
111
111
T5Encoder = "t5_encoder"
112
112
T5Encoder8b = "t5_encoder_8b"
113
113
T5Encoder4b = "t5_encoder_4b"
114
+ BnbQuantizednf4b = "bnb_quantized_nf4b"
114
115
115
116
116
117
class SchedulerPredictionType (str , Enum ):
@@ -193,7 +194,7 @@ def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> N
193
194
class CheckpointConfigBase (ModelConfigBase ):
194
195
"""Model config for checkpoint-style models."""
195
196
196
- format : Literal [ModelFormat .Checkpoint ] = ModelFormat .Checkpoint
197
+ format : Literal [ModelFormat .Checkpoint , ModelFormat . BnbQuantizednf4b ] = Field ( description = "Format of the provided checkpoint model" , default = ModelFormat .Checkpoint )
197
198
config_path : str = Field (description = "path to the checkpoint model config file" )
198
199
converted_at : Optional [float ] = Field (
199
200
description = "When this model was last converted to diffusers" , default_factory = time .time
@@ -248,7 +249,6 @@ class VAECheckpointConfig(CheckpointConfigBase):
248
249
"""Model config for standalone VAE models."""
249
250
250
251
type : Literal [ModelType .VAE ] = ModelType .VAE
251
- format : Literal [ModelFormat .Checkpoint ] = ModelFormat .Checkpoint
252
252
253
253
@staticmethod
254
254
def get_tag () -> Tag :
@@ -287,7 +287,6 @@ class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase)
287
287
"""Model config for ControlNet models (diffusers version)."""
288
288
289
289
type : Literal [ModelType .ControlNet ] = ModelType .ControlNet
290
- format : Literal [ModelFormat .Checkpoint ] = ModelFormat .Checkpoint
291
290
292
291
@staticmethod
293
292
def get_tag () -> Tag :
@@ -336,6 +335,21 @@ def get_tag() -> Tag:
336
335
return Tag (f"{ ModelType .Main .value } .{ ModelFormat .Checkpoint .value } " )
337
336
338
337
338
+ class MainBnbQuantized4bCheckpointConfig (CheckpointConfigBase , MainConfigBase ):
339
+ """Model config for main checkpoint models."""
340
+
341
+ prediction_type : SchedulerPredictionType = SchedulerPredictionType .Epsilon
342
+ upcast_attention : bool = False
343
+
344
+ def __init__ (self , * args , ** kwargs ):
345
+ super ().__init__ (* args , ** kwargs )
346
+ self .format = ModelFormat .BnbQuantizednf4b
347
+
348
+ @staticmethod
349
+ def get_tag () -> Tag :
350
+ return Tag (f"{ ModelType .Main .value } .{ ModelFormat .BnbQuantizednf4b .value } " )
351
+
352
+
339
353
class MainDiffusersConfig (DiffusersConfigBase , MainConfigBase ):
340
354
"""Model config for main diffusers models."""
341
355
@@ -438,6 +452,7 @@ def get_model_discriminator_value(v: Any) -> str:
438
452
Union [
439
453
Annotated [MainDiffusersConfig , MainDiffusersConfig .get_tag ()],
440
454
Annotated [MainCheckpointConfig , MainCheckpointConfig .get_tag ()],
455
+ Annotated [MainBnbQuantized4bCheckpointConfig , MainBnbQuantized4bCheckpointConfig .get_tag ()],
441
456
Annotated [VAEDiffusersConfig , VAEDiffusersConfig .get_tag ()],
442
457
Annotated [VAECheckpointConfig , VAECheckpointConfig .get_tag ()],
443
458
Annotated [ControlNetDiffusersConfig , ControlNetDiffusersConfig .get_tag ()],
0 commit comments