@@ -67,53 +67,77 @@ def __init__(
67
67
self .dtype = self .float16_dtype if w16 else mx .float32
68
68
self .activation_dtype = self .float16_dtype if a16 else mx .float32
69
69
self .use_t5 = use_t5
70
- mmdit_ckpt = MMDIT_CKPT [model_version ]
70
+ self . mmdit_ckpt = MMDIT_CKPT [model_version ]
71
71
self .low_memory_mode = low_memory_mode
72
- self .mmdit = load_mmdit (
73
- float16 = w16 ,
74
- key = mmdit_ckpt ,
75
- model_key = model_version ,
76
- low_memory_mode = low_memory_mode ,
77
- )
72
+ self .model = model
73
+ self .model_version = model_version
78
74
self .sampler = ModelSamplingDiscreteFlow (shift = shift )
79
- self .decoder = load_vae_decoder (float16 = w16 , key = mmdit_ckpt )
80
- self .encoder = load_vae_encoder (float16 = False , key = mmdit_ckpt )
81
75
self .latent_format = SD3LatentFormat ()
76
+ self .use_clip_g = True
77
+ self .check_and_load_models ()
82
78
83
- self .clip_l = load_text_encoder (
84
- model ,
85
- w16 ,
86
- model_key = "clip_l" ,
87
- )
88
- self .tokenizer_l = load_tokenizer (
89
- model ,
90
- merges_key = "tokenizer_l_merges" ,
91
- vocab_key = "tokenizer_l_vocab" ,
92
- pad_with_eos = True ,
93
- )
94
- self .clip_g = load_text_encoder (
95
- model ,
96
- w16 ,
97
- model_key = "clip_g" ,
98
- )
99
- self .tokenizer_g = load_tokenizer (
100
- model ,
101
- merges_key = "tokenizer_g_merges" ,
102
- vocab_key = "tokenizer_g_vocab" ,
103
- pad_with_eos = False ,
104
- )
105
- self .t5_encoder = None
106
- self .t5_tokenizer = None
107
- if self .use_t5 :
79
+ def load_mmdit (self , only_modulation_dict = False ):
80
+ if only_modulation_dict :
81
+ return load_mmdit (
82
+ float16 = True if self .dtype == self .float16_dtype else False ,
83
+ key = self .mmdit_ckpt ,
84
+ model_key = self .model_version ,
85
+ low_memory_mode = self .low_memory_mode ,
86
+ only_modulation_dict = only_modulation_dict ,
87
+ )
88
+ self .mmdit = load_mmdit (
89
+ float16 = True if self .dtype == self .float16_dtype else False ,
90
+ key = self .mmdit_ckpt ,
91
+ model_key = self .model_version ,
92
+ low_memory_mode = self .low_memory_mode ,
93
+ only_modulation_dict = only_modulation_dict ,
94
+ )
95
+
96
+ def check_and_load_models (self ):
97
+ if not hasattr (self , "mmdit" ):
98
+ self .load_mmdit ()
99
+ if not hasattr (self , "decoder" ):
100
+ self .decoder = load_vae_decoder (
101
+ float16 = True if self .dtype == self .float16_dtype else False ,
102
+ key = self .mmdit_ckpt ,
103
+ )
104
+ if not hasattr (self , "encoder" ):
105
+ self .encoder = load_vae_encoder (float16 = False , key = self .mmdit_ckpt )
106
+
107
+ if not hasattr (self , "clip_l" ):
108
+ self .clip_l = load_text_encoder (
109
+ self .model ,
110
+ float16 = True if self .dtype == self .float16_dtype else False ,
111
+ model_key = "clip_l" ,
112
+ )
113
+ self .tokenizer_l = load_tokenizer (
114
+ self .model ,
115
+ merges_key = "tokenizer_l_merges" ,
116
+ vocab_key = "tokenizer_l_vocab" ,
117
+ pad_with_eos = True ,
118
+ )
119
+ if self .use_clip_g and not hasattr (self , "clip_g" ):
120
+ self .clip_g = load_text_encoder (
121
+ self .model ,
122
+ float16 = True if self .dtype == self .float16_dtype else False ,
123
+ model_key = "clip_g" ,
124
+ )
125
+ self .tokenizer_g = load_tokenizer (
126
+ self .model ,
127
+ merges_key = "tokenizer_g_merges" ,
128
+ vocab_key = "tokenizer_g_vocab" ,
129
+ pad_with_eos = False ,
130
+ )
131
+ if self .use_t5 and not hasattr (self , "t5_encoder" ):
108
132
self .set_up_t5 ()
109
133
110
134
def set_up_t5 (self ):
111
- if self .t5_encoder is None :
135
+ if not hasattr ( self , "t5_encoder" ) or self .t5_encoder is None :
112
136
self .t5_encoder = load_t5_encoder (
113
137
float16 = True if self .dtype == self .float16_dtype else False ,
114
138
low_memory_mode = self .low_memory_mode ,
115
139
)
116
- if self .t5_tokenizer is None :
140
+ if not hasattr ( self , "t5_tokenizer" ) or self .t5_tokenizer is None :
117
141
self .t5_tokenizer = load_t5_tokenizer ()
118
142
self .use_t5 = True
119
143
@@ -266,6 +290,7 @@ def generate_image(
266
290
image_path : Optional [str ] = None ,
267
291
denoise : float = 1.0 ,
268
292
):
293
+ self .check_and_load_models ()
269
294
# Start timing
270
295
start_time = time .time ()
271
296
@@ -405,10 +430,9 @@ def generate_image(
405
430
)
406
431
logger .info (f"Denoising time: { log ['denoising' ]['time' ]} s" )
407
432
408
- # unload MMDIT and Sampler models after obtaining latents in low memory mode
433
+ # unload MMDIT model after obtaining latents in low memory mode
409
434
if self .low_memory_mode :
410
435
del self .mmdit
411
- del self .sampler
412
436
gc .collect ()
413
437
414
438
logger .debug (f"Latents dtype before casting: { latents .dtype } " )
@@ -458,18 +482,19 @@ def generate_image(
458
482
f"Post decode active memory: { log ['decoding' ]['post' ]['active_memory' ]} GB"
459
483
)
460
484
461
- logger .info ("============= Summary =============" )
462
- logger .info (f"Text encoder: { log ['text_encoding' ]['time' ]:.1f} s" )
463
- logger .info (f"Denoising: { log ['denoising' ]['time' ]:.1f} s" )
464
- logger .info (f"Image decoder: { log ['decoding' ]['time' ]:.1f} s" )
465
- logger .info (f"Peak memory: { log ['peak_memory' ]:.1f} GB" )
466
-
467
- logger .info ("============= Inference Context =============" )
468
- ic = DiffusionKitInferenceContext ()
469
- logger .info ("Operating System:" )
470
- pprint (ic .os_spec ())
471
- logger .info ("Device:" )
472
- pprint (ic .device_spec ())
485
+ if verbose :
486
+ logger .info ("============= Summary =============" )
487
+ logger .info (f"Text encoder: { log ['text_encoding' ]['time' ]:.1f} s" )
488
+ logger .info (f"Denoising: { log ['denoising' ]['time' ]:.1f} s" )
489
+ logger .info (f"Image decoder: { log ['decoding' ]['time' ]:.1f} s" )
490
+ logger .info (f"Peak memory: { log ['peak_memory' ]:.1f} GB" )
491
+
492
+ logger .info ("============= Inference Context =============" )
493
+ ic = DiffusionKitInferenceContext ()
494
+ logger .info ("Operating System:" )
495
+ pprint (ic .os_spec ())
496
+ logger .info ("Device:" )
497
+ pprint (ic .device_spec ())
473
498
474
499
# unload VAE Decoder model after decoding in low memory mode
475
500
if self .low_memory_mode :
@@ -566,30 +591,28 @@ def __init__(
566
591
model_io ._FLOAT16 = self .float16_dtype
567
592
self .dtype = self .float16_dtype if w16 else mx .float32
568
593
self .activation_dtype = self .float16_dtype if a16 else mx .float32
569
- mmdit_ckpt = MMDIT_CKPT [model_version ]
594
+ self . mmdit_ckpt = MMDIT_CKPT [model_version ]
570
595
self .low_memory_mode = low_memory_mode
571
- self .mmdit = load_flux (float16 = w16 , low_memory_mode = low_memory_mode )
596
+ self .model = model
597
+ self .model_version = model_version
572
598
self .sampler = FluxSampler (shift = shift )
573
- self .decoder = load_vae_decoder (float16 = w16 , key = mmdit_ckpt )
574
- self .encoder = load_vae_encoder (float16 = False , key = mmdit_ckpt )
575
599
self .latent_format = FluxLatentFormat ()
576
600
self .use_t5 = True
601
+ self .use_clip_g = False
602
+ self .check_and_load_models ()
577
603
578
- self .clip_l = load_text_encoder (
579
- model ,
580
- w16 ,
581
- model_key = "clip_l" ,
582
- )
583
- self .tokenizer_l = load_tokenizer (
584
- model ,
585
- merges_key = "tokenizer_l_merges" ,
586
- vocab_key = "tokenizer_l_vocab" ,
587
- pad_with_eos = True ,
604
+ def load_mmdit (self , only_modulation_dict = False ):
605
+ if only_modulation_dict :
606
+ return load_flux (
607
+ float16 = True if self .dtype == self .float16_dtype else False ,
608
+ low_memory_mode = self .low_memory_mode ,
609
+ only_modulation_dict = only_modulation_dict ,
610
+ )
611
+ self .mmdit = load_flux (
612
+ float16 = True if self .dtype == self .float16_dtype else False ,
613
+ low_memory_mode = self .low_memory_mode ,
614
+ only_modulation_dict = only_modulation_dict ,
588
615
)
589
- self .t5_encoder = None
590
- self .t5_tokenizer = None
591
- if self .use_t5 :
592
- self .set_up_t5 ()
593
616
594
617
def encode_text (
595
618
self ,
@@ -634,7 +657,9 @@ def cache_modulation_params(self, pooled_text_embeddings, sigmas):
634
657
)
635
658
636
659
def clear_cache (self ):
637
- self .model .mmdit .clear_modulation_params_cache ()
660
+ self .model .mmdit .load_weights (
661
+ self .model .load_mmdit (only_modulation_dict = True ), strict = False
662
+ )
638
663
639
664
def __call__ (
640
665
self ,
@@ -731,6 +756,6 @@ def sample_euler(model: CFGDenoiser, x, sigmas, extra_args=None):
731
756
end_time = t .format_dict ["elapsed" ]
732
757
iter_time .append (round ((end_time - start_time ), 3 ))
733
758
734
- # model.clear_cache()
759
+ model .clear_cache ()
735
760
736
761
return x , iter_time
0 commit comments