19
19
20
20
from .model_io import (
21
21
_DEFAULT_MODEL ,
22
+ load_flux ,
22
23
load_mmdit ,
23
24
load_t5_encoder ,
24
25
load_t5_tokenizer ,
27
28
load_vae_decoder ,
28
29
load_vae_encoder ,
29
30
)
30
- from .sampler import ModelSamplingDiscreteFlow
31
+ from .sampler import FluxSampler , ModelSamplingDiscreteFlow
31
32
32
33
logger = get_logger(__name__)

# Maps the user-facing `model_version` argument to a checkpoint key.
# Values look like Hugging Face repo ids or local safetensors paths —
# NOTE(review): exact resolution happens in model_io; confirm there.
MMDIT_CKPT = {
    "stable-diffusion-3-medium": "stabilityai/stable-diffusion-3-medium",
    "sd3-8b-unreleased": "models/sd3_8b_beta.safetensors",  # unreleased
    "FLUX.1-schnell": "argmaxinc/mlx-FLUX.1-schnell",
}
38
40
39
41
@@ -44,21 +46,29 @@ def __init__(
44
46
w16 : bool = False ,
45
47
shift : float = 1.0 ,
46
48
use_t5 : bool = True ,
47
- model_size : str = "2b " ,
49
+ model_version : str = "stable-diffusion-3-medium " ,
48
50
low_memory_mode : bool = True ,
49
51
a16 : bool = False ,
50
52
local_ckpt = None ,
51
53
):
52
54
model_io .LOCAl_SD3_CKPT = local_ckpt
53
- self .dtype = mx .float16 if w16 else mx .float32
54
- self .activation_dtype = mx .float16 if a16 else mx .float32
55
+ self .float16_dtype = mx .float16
56
+ model_io ._FLOAT16 = self .float16_dtype
57
+ self .dtype = self .float16_dtype if w16 else mx .float32
58
+ self .activation_dtype = self .float16_dtype if a16 else mx .float32
55
59
self .use_t5 = use_t5
56
- mmdit_ckpt = MMDIT_CKPT [model_size ]
60
+ mmdit_ckpt = MMDIT_CKPT [model_version ]
57
61
self .low_memory_mode = low_memory_mode
58
- self .mmdit = load_mmdit (float16 = w16 , model_key = mmdit_ckpt )
62
+ self .mmdit = load_mmdit (
63
+ float16 = w16 ,
64
+ key = mmdit_ckpt ,
65
+ model_key = model_version ,
66
+ low_memory_mode = low_memory_mode ,
67
+ )
59
68
self .sampler = ModelSamplingDiscreteFlow (shift = shift )
60
- self .decoder = load_vae_decoder (float16 = w16 )
61
- self .encoder = load_vae_encoder (float16 = False )
69
+ self .decoder = load_vae_decoder (float16 = w16 , key = mmdit_ckpt )
70
+ self .encoder = load_vae_encoder (float16 = False , key = mmdit_ckpt )
71
+ self .latent_format = SD3LatentFormat ()
62
72
63
73
self .clip_l = load_text_encoder (
64
74
model ,
@@ -90,7 +100,7 @@ def __init__(
90
100
def set_up_t5 (self ):
91
101
if self .t5_encoder is None :
92
102
self .t5_encoder = load_t5_encoder (
93
- float16 = True if self .dtype == mx . float16 else False ,
103
+ float16 = True if self .dtype == self . float16_dtype else False ,
94
104
low_memory_mode = self .low_memory_mode ,
95
105
)
96
106
if self .t5_tokenizer is None :
@@ -110,9 +120,10 @@ def unload_t5(self):
110
120
def ensure_models_are_loaded (self ):
111
121
mx .eval (self .mmdit .parameters ())
112
122
mx .eval (self .clip_l .parameters ())
113
- mx .eval (self .clip_g .parameters ())
114
123
mx .eval (self .decoder .parameters ())
115
- if self .use_t5 :
124
+ if hasattr (self , "clip_g" ):
125
+ mx .eval (self .clip_g .parameters ())
126
+ if hasattr (self , "t5_encoder" ) and self .use_t5 :
116
127
mx .eval (self .t5_encoder .parameters ())
117
128
118
129
def _tokenize (self , tokenizer , text : str , negative_text : Optional [str ] = None ):
@@ -213,7 +224,7 @@ def denoise_latents(
213
224
denoise = 1.0
214
225
else :
215
226
x_T = self .encode_image_to_latents (image_path , seed = seed )
216
- x_T = SD3LatentFormat () .process_in (x_T )
227
+ x_T = self . latent_format .process_in (x_T )
217
228
noise = self .get_noise (seed , x_T )
218
229
sigmas = self .get_sigmas (self .sampler , num_steps )
219
230
sigmas = sigmas [int (num_steps * (1 - denoise )) :]
@@ -228,7 +239,9 @@ def denoise_latents(
228
239
latent , iter_time = sample_euler (
229
240
CFGDenoiser (self ), noise_scaled , sigmas , extra_args = extra_args
230
241
)
231
- latent = SD3LatentFormat ().process_out (latent )
242
+
243
+ latent = self .latent_format .process_out (latent )
244
+
232
245
return latent , iter_time
233
246
234
247
def generate_image (
@@ -305,9 +318,11 @@ def generate_image(
305
318
306
319
# unload T5 and CLIP models after obtaining conditioning in low memory mode
307
320
if self .low_memory_mode :
308
- del self .clip_g
321
+ if hasattr (self , "t5_encoder" ):
322
+ del self .t5_encoder
323
+ if hasattr (self , "clip_g" ):
324
+ del self .clip_g
309
325
del self .clip_l
310
- del self .t5_encoder
311
326
gc .collect ()
312
327
313
328
logger .debug (f"Conditioning dtype before casting: { conditioning .dtype } " )
@@ -406,7 +421,7 @@ def generate_image(
406
421
logger .info (
407
422
f"Pre decode active memory: { log ['decoding' ]['pre' ]['active_memory' ]} GB"
408
423
)
409
-
424
+ latents = latents . astype ( mx . float32 )
410
425
decoded = self .decode_latents_to_image (latents )
411
426
mx .eval (decoded )
412
427
@@ -447,6 +462,16 @@ def generate_image(
447
462
448
463
return Image .fromarray (np .array (x )), log
449
464
465
    def generate_ids(self, latent_size: Tuple[int, int]):
        """Build positional id tensors for the FLUX transformer.

        Args:
            latent_size: (height, width) of the latent grid.

        Returns:
            (img_ids, txt_ids) where img_ids has shape (1, (h//2)*(w//2), 3)
            holding per-patch (0, row, col) coordinates, and txt_ids is an
            all-zero (1, 256, 3) tensor.
        """
        h, w = latent_size
        # One id triplet per 2x2 latent patch; channel 0 stays zero,
        # channels 1/2 carry the patch row/column index.
        img_ids = mx.zeros((h // 2, w // 2, 3))
        img_ids[..., 1] = img_ids[..., 1] + mx.arange(h // 2)[:, None]
        img_ids[..., 2] = img_ids[..., 2] + mx.arange(w // 2)[None, :]
        # Flatten the spatial grid into a single sequence dimension.
        img_ids = img_ids.reshape(1, -1, 3)

        txt_ids = mx.zeros((1, 256, 3))  # Hardcoded to context length of T5
        return img_ids, txt_ids
474
+
450
475
def read_image (self , image_path : str ):
451
476
# Read the image
452
477
img = Image .open (image_path )
@@ -473,12 +498,15 @@ def get_noise(self, seed, x_T):
473
498
def get_sigmas (self , sampler , num_steps : int ):
474
499
start = sampler .timestep (sampler .sigma_max ).item ()
475
500
end = sampler .timestep (sampler .sigma_min ).item ()
501
+ if isinstance (sampler , FluxSampler ):
502
+ num_steps += 1
476
503
timesteps = mx .linspace (start , end , num_steps )
477
504
sigs = []
478
505
for x in range (len (timesteps )):
479
506
ts = timesteps [x ]
480
507
sigs .append (sampler .sigma (ts ))
481
- sigs += [0.0 ]
508
+ if not isinstance (sampler , FluxSampler ):
509
+ sigs += [0.0 ]
482
510
return mx .array (sigs )
483
511
484
512
def get_empty_latent (self , * shape ):
@@ -505,6 +533,81 @@ def encode_image_to_latents(self, image_path: str, seed):
505
533
return mean + std * noise
506
534
507
535
536
class FluxPipeline(DiffusionPipeline):
    """Diffusion pipeline specialized for FLUX models.

    Differs from the SD3 base pipeline in that it uses bfloat16 for
    half precision, a FluxSampler, the Flux latent format, and requires
    T5 conditioning (CLIP-G is not loaded at all).
    """

    def __init__(
        self,
        model: str = _DEFAULT_MODEL,
        w16: bool = False,
        shift: float = 1.0,
        use_t5: bool = True,
        model_version: str = "FLUX.1-schnell",
        low_memory_mode: bool = True,
        a16: bool = False,
        local_ckpt=None,
    ):
        model_io.LOCAl_SD3_CKPT = local_ckpt
        # FLUX uses bfloat16 (not float16) as its half-precision dtype.
        self.float16_dtype = mx.bfloat16
        model_io._FLOAT16 = self.float16_dtype
        self.dtype = self.float16_dtype if w16 else mx.float32
        self.activation_dtype = self.float16_dtype if a16 else mx.float32
        mmdit_ckpt = MMDIT_CKPT[model_version]
        self.low_memory_mode = low_memory_mode
        self.mmdit = load_flux(float16=w16, low_memory_mode=low_memory_mode)
        self.sampler = FluxSampler(shift=shift)
        self.decoder = load_vae_decoder(float16=w16, key=mmdit_ckpt)
        self.encoder = load_vae_encoder(float16=False, key=mmdit_ckpt)
        self.latent_format = FluxLatentFormat()

        # FLUX requires T5 conditioning unconditionally. The original
        # code only assigned self.use_t5 inside the `not use_t5` branch,
        # so the default use_t5=True left the attribute unset and the
        # `if self.use_t5:` check below raised AttributeError.
        if not use_t5:
            logger.warning("FLUX can not be used without T5. Loading T5..")
        self.use_t5 = True

        self.clip_l = load_text_encoder(
            model,
            w16,
            model_key="clip_l",
        )
        self.tokenizer_l = load_tokenizer(
            model,
            merges_key="tokenizer_l_merges",
            vocab_key="tokenizer_l_vocab",
            pad_with_eos=True,
        )
        self.t5_encoder = None
        self.t5_tokenizer = None
        if self.use_t5:
            self.set_up_t5()

    def encode_text(
        self,
        text: str,
        cfg_weight: float = 7.5,
        negative_text: str = "",
    ):
        """Encode a prompt for FLUX conditioning.

        Args:
            text: positive prompt.
            cfg_weight: classifier-free-guidance weight; the negative
                prompt is tokenized only when cfg_weight > 1.
            negative_text: negative prompt (tokenized but then ignored —
                only the positive row is encoded).

        Returns:
            (conditioning, pooled_conditioning): the T5 sequence
            embeddings and the CLIP-L pooled embedding.
        """
        tokens_l = self._tokenize(
            self.tokenizer_l,
            text,
            (negative_text if cfg_weight > 1 else None),
        )
        conditioning_l = self.clip_l(tokens_l[[0], :])  # Ignore negative text
        pooled_conditioning = conditioning_l.pooled_output

        tokens_t5 = self._tokenize(
            self.t5_tokenizer,
            text,
            (negative_text if cfg_weight > 1 else None),
        )
        # Place the prompt tokens in a fixed 256-token T5 context.
        # NOTE(review): assumes the tokenized prompt fits within 256
        # tokens; a longer prompt would make this slice assignment fail.
        # Confirm tokenizer truncation upstream.
        padded_tokens_t5 = mx.zeros((1, 256)).astype(tokens_t5.dtype)
        padded_tokens_t5[:, : tokens_t5.shape[1]] = tokens_t5[
            [0], :
        ]  # Ignore negative text
        t5_conditioning = self.t5_encoder(padded_tokens_t5)
        mx.eval(t5_conditioning)
        conditioning = t5_conditioning

        return conditioning, pooled_conditioning
609
+
610
+
508
611
class CFGDenoiser (nn .Module ):
509
612
"""Helper for applying CFG Scaling to diffusion outputs"""
510
613
@@ -515,9 +618,13 @@ def __init__(self, model: DiffusionPipeline):
515
618
def __call__ (
516
619
self , x_t , t , conditioning , cfg_weight : float = 7.5 , pooled_conditioning = None
517
620
):
518
- x_t_mmdit = mx .concatenate ([x_t ] * 2 , axis = 0 ).astype (
519
- self .model .activation_dtype
520
- )
621
+ if cfg_weight <= 0 :
622
+ logger .debug ("CFG Weight disabled" )
623
+ x_t_mmdit = x_t .astype (self .model .activation_dtype )
624
+ else :
625
+ x_t_mmdit = mx .concatenate ([x_t ] * 2 , axis = 0 ).astype (
626
+ self .model .activation_dtype
627
+ )
521
628
t_mmdit = mx .broadcast_to (t , [len (x_t_mmdit )])
522
629
timestep = self .model .sampler .timestep (t_mmdit ).astype (
523
630
self .model .activation_dtype
@@ -530,21 +637,24 @@ def __call__(
530
637
),
531
638
"timestep" : timestep ,
532
639
}
640
+
533
641
mmdit_output = self .model .mmdit (** mmdit_input )
534
642
eps_pred = self .model .sampler .calculate_denoised (
535
643
t_mmdit , mmdit_output , x_t_mmdit
536
644
)
537
-
538
- eps_text , eps_neg = eps_pred .split (2 )
539
- return eps_neg + cfg_weight * (eps_text - eps_neg )
645
+ if cfg_weight <= 0 :
646
+ return eps_pred
647
+ else :
648
+ eps_text , eps_neg = eps_pred .split (2 )
649
+ return eps_neg + cfg_weight * (eps_text - eps_neg )
540
650
541
651
542
class LatentFormat:
    """Base class for latent format conversion.

    Maps latents between VAE space and the model's internal scale via an
    affine transform. The base class is the identity; subclasses supply
    their own scale/shift constants.
    """

    def __init__(self):
        # Identity transform by default.
        self.scale_factor = 1.0
        self.shift_factor = 0.0

    def process_in(self, latent):
        """Convert a VAE-space latent into model space."""
        centered = latent - self.shift_factor
        return centered * self.scale_factor

    def process_out(self, latent):
        """Convert a model-space latent back into VAE space."""
        rescaled = latent / self.scale_factor
        return rescaled + self.shift_factor
554
664
555
665
666
class SD3LatentFormat(LatentFormat):
    """SD3 latent scaling: latents are slightly shifted from center,
    so this transform must be applied around VAE encode/decode."""

    def __init__(self):
        super().__init__()
        self.scale_factor = 1.5305
        self.shift_factor = 0.0609
671
+
672
+
673
class FluxLatentFormat(LatentFormat):
    """Latent scale/shift constants for FLUX VAE latents."""

    def __init__(self):
        super().__init__()
        self.scale_factor = 0.3611
        self.shift_factor = 0.1159
678
+
679
+
556
680
def append_dims (x , target_dims ):
557
681
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
558
682
dims_to_append = target_dims - x .ndim
0 commit comments