Skip to content

Commit cf57983

Browse files
authored
Merge pull request #16 from argmaxinc/benchmark_fix
Benchmark mode fix
2 parents 14ccbce + d1fbe7e commit cf57983

File tree

4 files changed

+110
-75
lines changed

4 files changed

+110
-75
lines changed

python/src/diffusionkit/mlx/__init__.py

Lines changed: 95 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -67,53 +67,77 @@ def __init__(
6767
self.dtype = self.float16_dtype if w16 else mx.float32
6868
self.activation_dtype = self.float16_dtype if a16 else mx.float32
6969
self.use_t5 = use_t5
70-
mmdit_ckpt = MMDIT_CKPT[model_version]
70+
self.mmdit_ckpt = MMDIT_CKPT[model_version]
7171
self.low_memory_mode = low_memory_mode
72-
self.mmdit = load_mmdit(
73-
float16=w16,
74-
key=mmdit_ckpt,
75-
model_key=model_version,
76-
low_memory_mode=low_memory_mode,
77-
)
72+
self.model = model
73+
self.model_version = model_version
7874
self.sampler = ModelSamplingDiscreteFlow(shift=shift)
79-
self.decoder = load_vae_decoder(float16=w16, key=mmdit_ckpt)
80-
self.encoder = load_vae_encoder(float16=False, key=mmdit_ckpt)
8175
self.latent_format = SD3LatentFormat()
76+
self.use_clip_g = True
77+
self.check_and_load_models()
8278

83-
self.clip_l = load_text_encoder(
84-
model,
85-
w16,
86-
model_key="clip_l",
87-
)
88-
self.tokenizer_l = load_tokenizer(
89-
model,
90-
merges_key="tokenizer_l_merges",
91-
vocab_key="tokenizer_l_vocab",
92-
pad_with_eos=True,
93-
)
94-
self.clip_g = load_text_encoder(
95-
model,
96-
w16,
97-
model_key="clip_g",
98-
)
99-
self.tokenizer_g = load_tokenizer(
100-
model,
101-
merges_key="tokenizer_g_merges",
102-
vocab_key="tokenizer_g_vocab",
103-
pad_with_eos=False,
104-
)
105-
self.t5_encoder = None
106-
self.t5_tokenizer = None
107-
if self.use_t5:
79+
def load_mmdit(self, only_modulation_dict=False):
80+
if only_modulation_dict:
81+
return load_mmdit(
82+
float16=True if self.dtype == self.float16_dtype else False,
83+
key=self.mmdit_ckpt,
84+
model_key=self.model_version,
85+
low_memory_mode=self.low_memory_mode,
86+
only_modulation_dict=only_modulation_dict,
87+
)
88+
self.mmdit = load_mmdit(
89+
float16=True if self.dtype == self.float16_dtype else False,
90+
key=self.mmdit_ckpt,
91+
model_key=self.model_version,
92+
low_memory_mode=self.low_memory_mode,
93+
only_modulation_dict=only_modulation_dict,
94+
)
95+
96+
def check_and_load_models(self):
97+
if not hasattr(self, "mmdit"):
98+
self.load_mmdit()
99+
if not hasattr(self, "decoder"):
100+
self.decoder = load_vae_decoder(
101+
float16=True if self.dtype == self.float16_dtype else False,
102+
key=self.mmdit_ckpt,
103+
)
104+
if not hasattr(self, "encoder"):
105+
self.encoder = load_vae_encoder(float16=False, key=self.mmdit_ckpt)
106+
107+
if not hasattr(self, "clip_l"):
108+
self.clip_l = load_text_encoder(
109+
self.model,
110+
float16=True if self.dtype == self.float16_dtype else False,
111+
model_key="clip_l",
112+
)
113+
self.tokenizer_l = load_tokenizer(
114+
self.model,
115+
merges_key="tokenizer_l_merges",
116+
vocab_key="tokenizer_l_vocab",
117+
pad_with_eos=True,
118+
)
119+
if self.use_clip_g and not hasattr(self, "clip_g"):
120+
self.clip_g = load_text_encoder(
121+
self.model,
122+
float16=True if self.dtype == self.float16_dtype else False,
123+
model_key="clip_g",
124+
)
125+
self.tokenizer_g = load_tokenizer(
126+
self.model,
127+
merges_key="tokenizer_g_merges",
128+
vocab_key="tokenizer_g_vocab",
129+
pad_with_eos=False,
130+
)
131+
if self.use_t5 and not hasattr(self, "t5_encoder"):
108132
self.set_up_t5()
109133

110134
def set_up_t5(self):
111-
if self.t5_encoder is None:
135+
if not hasattr(self, "t5_encoder") or self.t5_encoder is None:
112136
self.t5_encoder = load_t5_encoder(
113137
float16=True if self.dtype == self.float16_dtype else False,
114138
low_memory_mode=self.low_memory_mode,
115139
)
116-
if self.t5_tokenizer is None:
140+
if not hasattr(self, "t5_tokenizer") or self.t5_tokenizer is None:
117141
self.t5_tokenizer = load_t5_tokenizer()
118142
self.use_t5 = True
119143

@@ -266,6 +290,7 @@ def generate_image(
266290
image_path: Optional[str] = None,
267291
denoise: float = 1.0,
268292
):
293+
self.check_and_load_models()
269294
# Start timing
270295
start_time = time.time()
271296

@@ -405,10 +430,9 @@ def generate_image(
405430
)
406431
logger.info(f"Denoising time: {log['denoising']['time']}s")
407432

408-
# unload MMDIT and Sampler models after obtaining latents in low memory mode
433+
# unload MMDIT model after obtaining latents in low memory mode
409434
if self.low_memory_mode:
410435
del self.mmdit
411-
del self.sampler
412436
gc.collect()
413437

414438
logger.debug(f"Latents dtype before casting: {latents.dtype}")
@@ -458,18 +482,19 @@ def generate_image(
458482
f"Post decode active memory: {log['decoding']['post']['active_memory']}GB"
459483
)
460484

461-
logger.info("============= Summary =============")
462-
logger.info(f"Text encoder: {log['text_encoding']['time']:.1f}s")
463-
logger.info(f"Denoising: {log['denoising']['time']:.1f}s")
464-
logger.info(f"Image decoder: {log['decoding']['time']:.1f}s")
465-
logger.info(f"Peak memory: {log['peak_memory']:.1f}GB")
466-
467-
logger.info("============= Inference Context =============")
468-
ic = DiffusionKitInferenceContext()
469-
logger.info("Operating System:")
470-
pprint(ic.os_spec())
471-
logger.info("Device:")
472-
pprint(ic.device_spec())
485+
if verbose:
486+
logger.info("============= Summary =============")
487+
logger.info(f"Text encoder: {log['text_encoding']['time']:.1f}s")
488+
logger.info(f"Denoising: {log['denoising']['time']:.1f}s")
489+
logger.info(f"Image decoder: {log['decoding']['time']:.1f}s")
490+
logger.info(f"Peak memory: {log['peak_memory']:.1f}GB")
491+
492+
logger.info("============= Inference Context =============")
493+
ic = DiffusionKitInferenceContext()
494+
logger.info("Operating System:")
495+
pprint(ic.os_spec())
496+
logger.info("Device:")
497+
pprint(ic.device_spec())
473498

474499
# unload VAE Decoder model after decoding in low memory mode
475500
if self.low_memory_mode:
@@ -566,30 +591,28 @@ def __init__(
566591
model_io._FLOAT16 = self.float16_dtype
567592
self.dtype = self.float16_dtype if w16 else mx.float32
568593
self.activation_dtype = self.float16_dtype if a16 else mx.float32
569-
mmdit_ckpt = MMDIT_CKPT[model_version]
594+
self.mmdit_ckpt = MMDIT_CKPT[model_version]
570595
self.low_memory_mode = low_memory_mode
571-
self.mmdit = load_flux(float16=w16, low_memory_mode=low_memory_mode)
596+
self.model = model
597+
self.model_version = model_version
572598
self.sampler = FluxSampler(shift=shift)
573-
self.decoder = load_vae_decoder(float16=w16, key=mmdit_ckpt)
574-
self.encoder = load_vae_encoder(float16=False, key=mmdit_ckpt)
575599
self.latent_format = FluxLatentFormat()
576600
self.use_t5 = True
601+
self.use_clip_g = False
602+
self.check_and_load_models()
577603

578-
self.clip_l = load_text_encoder(
579-
model,
580-
w16,
581-
model_key="clip_l",
582-
)
583-
self.tokenizer_l = load_tokenizer(
584-
model,
585-
merges_key="tokenizer_l_merges",
586-
vocab_key="tokenizer_l_vocab",
587-
pad_with_eos=True,
604+
def load_mmdit(self, only_modulation_dict=False):
605+
if only_modulation_dict:
606+
return load_flux(
607+
float16=True if self.dtype == self.float16_dtype else False,
608+
low_memory_mode=self.low_memory_mode,
609+
only_modulation_dict=only_modulation_dict,
610+
)
611+
self.mmdit = load_flux(
612+
float16=True if self.dtype == self.float16_dtype else False,
613+
low_memory_mode=self.low_memory_mode,
614+
only_modulation_dict=only_modulation_dict,
588615
)
589-
self.t5_encoder = None
590-
self.t5_tokenizer = None
591-
if self.use_t5:
592-
self.set_up_t5()
593616

594617
def encode_text(
595618
self,
@@ -634,7 +657,9 @@ def cache_modulation_params(self, pooled_text_embeddings, sigmas):
634657
)
635658

636659
def clear_cache(self):
637-
self.model.mmdit.clear_modulation_params_cache()
660+
self.model.mmdit.load_weights(
661+
self.model.load_mmdit(only_modulation_dict=True), strict=False
662+
)
638663

639664
def __call__(
640665
self,
@@ -731,6 +756,6 @@ def sample_euler(model: CFGDenoiser, x, sigmas, extra_args=None):
731756
end_time = t.format_dict["elapsed"]
732757
iter_time.append(round((end_time - start_time), 3))
733758

734-
# model.clear_cache()
759+
model.clear_cache()
735760

736761
return x, iter_time

python/src/diffusionkit/mlx/mmdit.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import numpy as np
1111
from argmaxtools.utils import get_logger
1212
from beartype.typing import Dict, List, Optional, Tuple
13+
from mlx.utils import tree_map
1314

1415
from .config import MMDiTConfig, PositionalEncoding
1516

@@ -163,7 +164,8 @@ def cache_modulation_params(
163164

164165
self.to_offload = to_offload
165166
for x in self.to_offload:
166-
x.clear()
167+
x.update(tree_map(lambda _: mx.array([]), x.parameters()))
168+
# x.clear()
167169

168170
logger.info(f"Cached modulation_params for timesteps={timesteps}")
169171
logger.info(

python/src/diffusionkit/mlx/model_io.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -654,8 +654,10 @@ def load_mmdit(
654654
float16: bool = False,
655655
model_key: str = "mmdit_2b",
656656
low_memory_mode: bool = True,
657+
only_modulation_dict: bool = False,
657658
):
658659
"""Load the MM-DiT model from the checkpoint file."""
660+
"""only_modulation_dict: Only returns the modulation dictionary"""
659661
dtype = _FLOAT16 if float16 else mx.float32
660662
config = SD3_2b
661663
config.low_memory_mode = low_memory_mode
@@ -666,6 +668,9 @@ def load_mmdit(
666668
weights = mx.load(mmdit_weights_ckpt)
667669
weights = mmdit_state_dict_adjustments(weights, prefix="model.diffusion_model.")
668670
weights = {k: v.astype(dtype) for k, v in weights.items()}
671+
if only_modulation_dict:
672+
weights = {k: v for k, v in weights.items() if "adaLN" in k}
673+
return tree_flatten(weights)
669674
model.update(tree_unflatten(tree_flatten(weights)))
670675

671676
return model
@@ -676,6 +681,7 @@ def load_flux(
676681
float16: bool = False,
677682
model_key: str = "FLUX.1-schnell",
678683
low_memory_mode: bool = True,
684+
only_modulation_dict: bool = False,
679685
):
680686
"""Load the MM-DiT Flux model from the checkpoint file."""
681687
dtype = _FLOAT16 if float16 else mx.float32
@@ -691,6 +697,9 @@ def load_flux(
691697
weights, prefix="", hidden_size=config.hidden_size, mlp_ratio=config.mlp_ratio
692698
)
693699
weights = {k: v.astype(dtype) for k, v in weights.items()}
700+
if only_modulation_dict:
701+
weights = {k: v for k, v in weights.items() if "adaLN" in k}
702+
return tree_flatten(weights)
694703
model.update(tree_unflatten(tree_flatten(weights)))
695704

696705
return model

python/src/diffusionkit/mlx/scripts/generate_images.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -112,10 +112,9 @@ def cli():
112112
args.cfg = 0.0
113113

114114
if args.benchmark_mode:
115-
raise NotImplementedError
116-
# if args.low_memory_mode:
117-
# logger.warning("Benchmark mode is enabled, disabling low memory mode.")
118-
# args.low_memory_mode = False
115+
if args.low_memory_mode:
116+
logger.warning("Benchmark mode is enabled, disabling low memory mode.")
117+
args.low_memory_mode = False
119118

120119
if args.denoise < 0.0 or args.denoise > 1.0:
121120
raise ValueError("Denoising factor must be between 0.0 and 1.0")

0 commit comments

Comments (0)