8
8
import gc
9
9
import math
10
10
import time
11
+ from pprint import pprint
11
12
from typing import Optional , Tuple
12
13
13
14
import mlx .core as mx
14
15
import mlx .nn as nn
15
16
import numpy as np
17
+ from argmaxtools .test_utils import AppleSiliconContextMixin , InferenceContextSpec
16
18
from argmaxtools .utils import get_logger
17
19
from diffusionkit .utils import bytes2gigabytes
18
20
from PIL import Image
39
41
}
40
42
41
43
44
class DiffusionKitInferenceContext(AppleSiliconContextMixin, InferenceContextSpec):
    """Inference context describing the environment of a DiffusionKit run.

    code_spec/model_spec are intentionally empty for this pipeline;
    os_spec()/device_spec() presumably come from AppleSiliconContextMixin —
    TODO confirm against argmaxtools.
    """

    def code_spec(self):
        # No code-level specification is reported for DiffusionKit runs.
        return {}

    def model_spec(self):
        # No model-level specification is reported for DiffusionKit runs.
        return {}
50
+
51
+
42
52
class DiffusionPipeline :
43
53
def __init__ (
44
54
self ,
@@ -292,6 +302,11 @@ def generate_image(
292
302
logger .info (
293
303
f"Pre text encoding active memory: { log ['text_encoding' ]['pre' ]['active_memory' ]} GB"
294
304
)
305
+
306
+ # FIXME(arda): Need the same for CLIP models (low memory mode will not succeed a second time otherwise)
307
+ if not hasattr (self , "t5" ):
308
+ self .set_up_t5 ()
309
+
295
310
conditioning , pooled_conditioning = self .encode_text (
296
311
text , cfg_weight , negative_text
297
312
)
@@ -442,8 +457,19 @@ def generate_image(
442
457
logger .info (
443
458
f"Post decode active memory: { log ['decoding' ]['post' ]['active_memory' ]} GB"
444
459
)
445
- logger .info (f"Decoding time: { log ['decoding' ]['time' ]} s" )
446
- logger .info (f"Peak memory: { log ['peak_memory' ]} GB" )
460
+
461
+ logger .info ("============= Summary =============" )
462
+ logger .info (f"Text encoder: { log ['text_encoding' ]['time' ]:.1f} s" )
463
+ logger .info (f"Denoising: { log ['denoising' ]['time' ]:.1f} s" )
464
+ logger .info (f"Image decoder: { log ['decoding' ]['time' ]:.1f} s" )
465
+ logger .info (f"Peak memory: { log ['peak_memory' ]:.1f} GB" )
466
+
467
+ logger .info ("============= Inference Context =============" )
468
+ ic = DiffusionKitInferenceContext ()
469
+ logger .info ("Operating System:" )
470
+ pprint (ic .os_spec ())
471
+ logger .info ("Device:" )
472
+ pprint (ic .device_spec ())
447
473
448
474
# unload VAE Decoder model after decoding in low memory mode
449
475
if self .low_memory_mode :
@@ -462,24 +488,14 @@ def generate_image(
462
488
463
489
return Image .fromarray (np .array (x )), log
464
490
465
- def generate_ids (self , latent_size : Tuple [int ]):
466
- h , w = latent_size
467
- img_ids = mx .zeros ((h // 2 , w // 2 , 3 ))
468
- img_ids [..., 1 ] = img_ids [..., 1 ] + mx .arange (h // 2 )[:, None ]
469
- img_ids [..., 2 ] = img_ids [..., 2 ] + mx .arange (w // 2 )[None , :]
470
- img_ids = img_ids .reshape (1 , - 1 , 3 )
471
-
472
- txt_ids = mx .zeros ((1 , 256 , 3 )) # Hardcoded to context length of T5
473
- return img_ids , txt_ids
474
-
475
491
def read_image (self , image_path : str ):
476
492
# Read the image
477
493
img = Image .open (image_path )
478
494
479
495
# Make sure image shape is divisible by 64
480
496
W , H = (dim - dim % 64 for dim in (img .width , img .height ))
481
497
if W != img .width or H != img .height :
482
- print (
498
+ logger . warning (
483
499
f"Warning: image shape is not divisible by 64, downsampling to { W } x{ H } "
484
500
)
485
501
img = img .resize ((W , H ), Image .LANCZOS ) # use desired downsampling filter
@@ -557,9 +573,6 @@ def __init__(
557
573
self .decoder = load_vae_decoder (float16 = w16 , key = mmdit_ckpt )
558
574
self .encoder = load_vae_encoder (float16 = False , key = mmdit_ckpt )
559
575
self .latent_format = FluxLatentFormat ()
560
-
561
- if not use_t5 :
562
- logger .warning ("FLUX can not be used without T5. Loading T5.." )
563
576
self .use_t5 = True
564
577
565
578
self .clip_l = load_text_encoder (
@@ -615,8 +628,22 @@ def __init__(self, model: DiffusionPipeline):
615
628
super ().__init__ ()
616
629
self .model = model
617
630
631
def cache_modulation_params(self, pooled_text_embeddings, sigmas):
    """Precompute and cache modulation parameters on the MMDiT backbone.

    Casts `sigmas` to the pipeline's activation dtype before forwarding both
    arguments to the backbone's own caching routine.
    """
    target_dtype = self.model.activation_dtype
    backbone = self.model.mmdit
    backbone.cache_modulation_params(
        pooled_text_embeddings,
        sigmas.astype(target_dtype),
    )
636
def clear_cache(self):
    """Drop any cached modulation parameters held by the MMDiT backbone."""
    backbone = self.model.mmdit
    backbone.clear_modulation_params_cache()
618
639
def __call__ (
619
- self , x_t , t , conditioning , cfg_weight : float = 7.5 , pooled_conditioning = None
640
+ self ,
641
+ x_t ,
642
+ timestep ,
643
+ sigma ,
644
+ conditioning ,
645
+ cfg_weight : float = 7.5 ,
646
+ pooled_conditioning = None ,
620
647
):
621
648
if cfg_weight <= 0 :
622
649
logger .debug ("CFG Weight disabled" )
@@ -625,23 +652,14 @@ def __call__(
625
652
x_t_mmdit = mx .concatenate ([x_t ] * 2 , axis = 0 ).astype (
626
653
self .model .activation_dtype
627
654
)
628
- t_mmdit = mx .broadcast_to (t , [len (x_t_mmdit )])
629
- timestep = self .model .sampler .timestep (t_mmdit ).astype (
630
- self .model .activation_dtype
631
- )
632
655
mmdit_input = {
633
656
"latent_image_embeddings" : x_t_mmdit ,
634
657
"token_level_text_embeddings" : mx .expand_dims (conditioning , 2 ),
635
- "pooled_text_embeddings" : mx .expand_dims (
636
- mx .expand_dims (pooled_conditioning , 1 ), 1
637
- ),
638
- "timestep" : timestep ,
658
+ "timestep" : mx .broadcast_to (timestep , [len (x_t_mmdit )]),
639
659
}
640
660
641
661
mmdit_output = self .model .mmdit (** mmdit_input )
642
- eps_pred = self .model .sampler .calculate_denoised (
643
- t_mmdit , mmdit_output , x_t_mmdit
644
- )
662
+ eps_pred = self .model .sampler .calculate_denoised (sigma , mmdit_output , x_t_mmdit )
645
663
if cfg_weight <= 0 :
646
664
return eps_pred
647
665
else :
@@ -691,20 +709,28 @@ def to_d(x, sigma, denoised):
691
709
def sample_euler(model: CFGDenoiser, x, sigmas, extra_args=None):
    """Implements Algorithm 2 (Euler steps) from Karras et al. (2022).

    Args:
        model: CFG-wrapped denoiser; also exposes cache_modulation_params().
        x: Initial noisy latents.
        sigmas: Noise schedule; one Euler step is taken per adjacent pair.
        extra_args: Extra keyword arguments forwarded to `model` on every
            step. Must contain "pooled_conditioning", which is consumed here
            for modulation-parameter caching and NOT forwarded per step.

    Returns:
        Tuple of (final latents, list of per-step wall-clock times in seconds).
    """
    # Copy so the pop() below does not mutate the caller's dict — the
    # original popped from the caller-owned mapping, which made a second
    # call with the same extra_args fail with KeyError.
    extra_args = {} if extra_args is None else dict(extra_args)

    from tqdm import trange

    progress = trange(len(sigmas) - 1)

    # Derive all timesteps from the sigma schedule once up front and cache
    # the modulation parameters so each per-step MMDiT call can skip
    # recomputing them.
    timesteps = model.model.sampler.timestep(sigmas).astype(
        model.model.activation_dtype
    )
    model.cache_modulation_params(extra_args.pop("pooled_conditioning"), timesteps)

    iter_time = []
    for i in progress:
        start_time = progress.format_dict["elapsed"]
        denoised = model(x, timesteps[i], sigmas[i], **extra_args)
        d = to_d(x, sigmas[i], denoised)
        dt = sigmas[i + 1] - sigmas[i]
        # Euler method: step the latents along the derivative direction.
        x = x + d * dt
        mx.eval(x)  # force evaluation so per-step timing is meaningful
        end_time = progress.format_dict["elapsed"]
        iter_time.append(round(end_time - start_time, 3))

    # NOTE(review): cache clearing was deliberately left disabled upstream;
    # see CFGDenoiser.clear_cache if modulation-param memory must be reclaimed.
    # model.clear_cache()

    return x, iter_time
0 commit comments