Skip to content

Commit 4c2bde3

Browse files
authored
Merge pull request #11 from argmaxinc/flux
FLUX-1.schnell
2 parents aacbd8f + eb8ae9a commit 4c2bde3

File tree

11 files changed

+1066
-223
lines changed

11 files changed

+1066
-223
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ pipeline = DiffusionPipeline(
9393
w16=True,
9494
shift=3.0,
9595
use_t5=False,
96-
model_size="2b",
96+
model_version="2b",
9797
low_memory_mode=False,
9898
a16=True,
9999
)

python/src/diffusionkit/mlx/__init__.py

Lines changed: 153 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from .model_io import (
2121
_DEFAULT_MODEL,
22+
load_flux,
2223
load_mmdit,
2324
load_t5_encoder,
2425
load_t5_tokenizer,
@@ -27,13 +28,14 @@
2728
load_vae_decoder,
2829
load_vae_encoder,
2930
)
30-
from .sampler import ModelSamplingDiscreteFlow
31+
from .sampler import FluxSampler, ModelSamplingDiscreteFlow
3132

3233
logger = get_logger(__name__)
3334

3435
MMDIT_CKPT = {
35-
"2b": "mmdit_2b",
36-
"8b": "models/sd3_8b_beta.safetensors",
36+
"stable-diffusion-3-medium": "stabilityai/stable-diffusion-3-medium",
37+
"sd3-8b-unreleased": "models/sd3_8b_beta.safetensors", # unreleased
38+
"FLUX.1-schnell": "argmaxinc/mlx-FLUX.1-schnell",
3739
}
3840

3941

@@ -44,21 +46,29 @@ def __init__(
4446
w16: bool = False,
4547
shift: float = 1.0,
4648
use_t5: bool = True,
47-
model_size: str = "2b",
49+
model_version: str = "stable-diffusion-3-medium",
4850
low_memory_mode: bool = True,
4951
a16: bool = False,
5052
local_ckpt=None,
5153
):
5254
model_io.LOCAl_SD3_CKPT = local_ckpt
53-
self.dtype = mx.float16 if w16 else mx.float32
54-
self.activation_dtype = mx.float16 if a16 else mx.float32
55+
self.float16_dtype = mx.float16
56+
model_io._FLOAT16 = self.float16_dtype
57+
self.dtype = self.float16_dtype if w16 else mx.float32
58+
self.activation_dtype = self.float16_dtype if a16 else mx.float32
5559
self.use_t5 = use_t5
56-
mmdit_ckpt = MMDIT_CKPT[model_size]
60+
mmdit_ckpt = MMDIT_CKPT[model_version]
5761
self.low_memory_mode = low_memory_mode
58-
self.mmdit = load_mmdit(float16=w16, model_key=mmdit_ckpt)
62+
self.mmdit = load_mmdit(
63+
float16=w16,
64+
key=mmdit_ckpt,
65+
model_key=model_version,
66+
low_memory_mode=low_memory_mode,
67+
)
5968
self.sampler = ModelSamplingDiscreteFlow(shift=shift)
60-
self.decoder = load_vae_decoder(float16=w16)
61-
self.encoder = load_vae_encoder(float16=False)
69+
self.decoder = load_vae_decoder(float16=w16, key=mmdit_ckpt)
70+
self.encoder = load_vae_encoder(float16=False, key=mmdit_ckpt)
71+
self.latent_format = SD3LatentFormat()
6272

6373
self.clip_l = load_text_encoder(
6474
model,
@@ -90,7 +100,7 @@ def __init__(
90100
def set_up_t5(self):
91101
if self.t5_encoder is None:
92102
self.t5_encoder = load_t5_encoder(
93-
float16=True if self.dtype == mx.float16 else False,
103+
float16=True if self.dtype == self.float16_dtype else False,
94104
low_memory_mode=self.low_memory_mode,
95105
)
96106
if self.t5_tokenizer is None:
@@ -110,9 +120,10 @@ def unload_t5(self):
110120
def ensure_models_are_loaded(self):
111121
mx.eval(self.mmdit.parameters())
112122
mx.eval(self.clip_l.parameters())
113-
mx.eval(self.clip_g.parameters())
114123
mx.eval(self.decoder.parameters())
115-
if self.use_t5:
124+
if hasattr(self, "clip_g"):
125+
mx.eval(self.clip_g.parameters())
126+
if hasattr(self, "t5_encoder") and self.use_t5:
116127
mx.eval(self.t5_encoder.parameters())
117128

118129
def _tokenize(self, tokenizer, text: str, negative_text: Optional[str] = None):
@@ -213,7 +224,7 @@ def denoise_latents(
213224
denoise = 1.0
214225
else:
215226
x_T = self.encode_image_to_latents(image_path, seed=seed)
216-
x_T = SD3LatentFormat().process_in(x_T)
227+
x_T = self.latent_format.process_in(x_T)
217228
noise = self.get_noise(seed, x_T)
218229
sigmas = self.get_sigmas(self.sampler, num_steps)
219230
sigmas = sigmas[int(num_steps * (1 - denoise)) :]
@@ -228,7 +239,9 @@ def denoise_latents(
228239
latent, iter_time = sample_euler(
229240
CFGDenoiser(self), noise_scaled, sigmas, extra_args=extra_args
230241
)
231-
latent = SD3LatentFormat().process_out(latent)
242+
243+
latent = self.latent_format.process_out(latent)
244+
232245
return latent, iter_time
233246

234247
def generate_image(
@@ -305,9 +318,11 @@ def generate_image(
305318

306319
# unload T5 and CLIP models after obtaining conditioning in low memory mode
307320
if self.low_memory_mode:
308-
del self.clip_g
321+
if hasattr(self, "t5_encoder"):
322+
del self.t5_encoder
323+
if hasattr(self, "clip_g"):
324+
del self.clip_g
309325
del self.clip_l
310-
del self.t5_encoder
311326
gc.collect()
312327

313328
logger.debug(f"Conditioning dtype before casting: {conditioning.dtype}")
@@ -406,7 +421,7 @@ def generate_image(
406421
logger.info(
407422
f"Pre decode active memory: {log['decoding']['pre']['active_memory']}GB"
408423
)
409-
424+
latents = latents.astype(mx.float32)
410425
decoded = self.decode_latents_to_image(latents)
411426
mx.eval(decoded)
412427

@@ -447,6 +462,16 @@ def generate_image(
447462

448463
return Image.fromarray(np.array(x)), log
449464

465+
def generate_ids(self, latent_size: Tuple[int]):
466+
h, w = latent_size
467+
img_ids = mx.zeros((h // 2, w // 2, 3))
468+
img_ids[..., 1] = img_ids[..., 1] + mx.arange(h // 2)[:, None]
469+
img_ids[..., 2] = img_ids[..., 2] + mx.arange(w // 2)[None, :]
470+
img_ids = img_ids.reshape(1, -1, 3)
471+
472+
txt_ids = mx.zeros((1, 256, 3)) # Hardcoded to context length of T5
473+
return img_ids, txt_ids
474+
450475
def read_image(self, image_path: str):
451476
# Read the image
452477
img = Image.open(image_path)
@@ -473,12 +498,15 @@ def get_noise(self, seed, x_T):
473498
def get_sigmas(self, sampler, num_steps: int):
474499
start = sampler.timestep(sampler.sigma_max).item()
475500
end = sampler.timestep(sampler.sigma_min).item()
501+
if isinstance(sampler, FluxSampler):
502+
num_steps += 1
476503
timesteps = mx.linspace(start, end, num_steps)
477504
sigs = []
478505
for x in range(len(timesteps)):
479506
ts = timesteps[x]
480507
sigs.append(sampler.sigma(ts))
481-
sigs += [0.0]
508+
if not isinstance(sampler, FluxSampler):
509+
sigs += [0.0]
482510
return mx.array(sigs)
483511

484512
def get_empty_latent(self, *shape):
@@ -505,6 +533,81 @@ def encode_image_to_latents(self, image_path: str, seed):
505533
return mean + std * noise
506534

507535

536+
class FluxPipeline(DiffusionPipeline):
537+
def __init__(
538+
self,
539+
model: str = _DEFAULT_MODEL,
540+
w16: bool = False,
541+
shift: float = 1.0,
542+
use_t5: bool = True,
543+
model_version: str = "FLUX.1-schnell",
544+
low_memory_mode: bool = True,
545+
a16: bool = False,
546+
local_ckpt=None,
547+
):
548+
model_io.LOCAl_SD3_CKPT = local_ckpt
549+
self.float16_dtype = mx.bfloat16
550+
model_io._FLOAT16 = self.float16_dtype
551+
self.dtype = self.float16_dtype if w16 else mx.float32
552+
self.activation_dtype = self.float16_dtype if a16 else mx.float32
553+
mmdit_ckpt = MMDIT_CKPT[model_version]
554+
self.low_memory_mode = low_memory_mode
555+
self.mmdit = load_flux(float16=w16, low_memory_mode=low_memory_mode)
556+
self.sampler = FluxSampler(shift=shift)
557+
self.decoder = load_vae_decoder(float16=w16, key=mmdit_ckpt)
558+
self.encoder = load_vae_encoder(float16=False, key=mmdit_ckpt)
559+
self.latent_format = FluxLatentFormat()
560+
561+
if not use_t5:
562+
logger.warning("FLUX can not be used without T5. Loading T5..")
563+
self.use_t5 = True
564+
565+
self.clip_l = load_text_encoder(
566+
model,
567+
w16,
568+
model_key="clip_l",
569+
)
570+
self.tokenizer_l = load_tokenizer(
571+
model,
572+
merges_key="tokenizer_l_merges",
573+
vocab_key="tokenizer_l_vocab",
574+
pad_with_eos=True,
575+
)
576+
self.t5_encoder = None
577+
self.t5_tokenizer = None
578+
if self.use_t5:
579+
self.set_up_t5()
580+
581+
def encode_text(
582+
self,
583+
text: str,
584+
cfg_weight: float = 7.5,
585+
negative_text: str = "",
586+
):
587+
tokens_l = self._tokenize(
588+
self.tokenizer_l,
589+
text,
590+
(negative_text if cfg_weight > 1 else None),
591+
)
592+
conditioning_l = self.clip_l(tokens_l[[0], :]) # Ignore negative text
593+
pooled_conditioning = conditioning_l.pooled_output
594+
595+
tokens_t5 = self._tokenize(
596+
self.t5_tokenizer,
597+
text,
598+
(negative_text if cfg_weight > 1 else None),
599+
)
600+
padded_tokens_t5 = mx.zeros((1, 256)).astype(tokens_t5.dtype)
601+
padded_tokens_t5[:, : tokens_t5.shape[1]] = tokens_t5[
602+
[0], :
603+
] # Ignore negative text
604+
t5_conditioning = self.t5_encoder(padded_tokens_t5)
605+
mx.eval(t5_conditioning)
606+
conditioning = t5_conditioning
607+
608+
return conditioning, pooled_conditioning
609+
610+
508611
class CFGDenoiser(nn.Module):
509612
"""Helper for applying CFG Scaling to diffusion outputs"""
510613

@@ -515,9 +618,13 @@ def __init__(self, model: DiffusionPipeline):
515618
def __call__(
516619
self, x_t, t, conditioning, cfg_weight: float = 7.5, pooled_conditioning=None
517620
):
518-
x_t_mmdit = mx.concatenate([x_t] * 2, axis=0).astype(
519-
self.model.activation_dtype
520-
)
621+
if cfg_weight <= 0:
622+
logger.debug("CFG Weight disabled")
623+
x_t_mmdit = x_t.astype(self.model.activation_dtype)
624+
else:
625+
x_t_mmdit = mx.concatenate([x_t] * 2, axis=0).astype(
626+
self.model.activation_dtype
627+
)
521628
t_mmdit = mx.broadcast_to(t, [len(x_t_mmdit)])
522629
timestep = self.model.sampler.timestep(t_mmdit).astype(
523630
self.model.activation_dtype
@@ -530,21 +637,24 @@ def __call__(
530637
),
531638
"timestep": timestep,
532639
}
640+
533641
mmdit_output = self.model.mmdit(**mmdit_input)
534642
eps_pred = self.model.sampler.calculate_denoised(
535643
t_mmdit, mmdit_output, x_t_mmdit
536644
)
537-
538-
eps_text, eps_neg = eps_pred.split(2)
539-
return eps_neg + cfg_weight * (eps_text - eps_neg)
645+
if cfg_weight <= 0:
646+
return eps_pred
647+
else:
648+
eps_text, eps_neg = eps_pred.split(2)
649+
return eps_neg + cfg_weight * (eps_text - eps_neg)
540650

541651

542-
class SD3LatentFormat:
543-
"""Latents are slightly shifted from center - this class must be called after VAE Decode to correct for the shift"""
652+
class LatentFormat:
653+
"""Base class for latent format conversion"""
544654

545655
def __init__(self):
546-
self.scale_factor = 1.5305
547-
self.shift_factor = 0.0609
656+
self.scale_factor = 1.0
657+
self.shift_factor = 0.0
548658

549659
def process_in(self, latent):
550660
return (latent - self.shift_factor) * self.scale_factor
@@ -553,6 +663,20 @@ def process_out(self, latent):
553663
return (latent / self.scale_factor) + self.shift_factor
554664

555665

666+
class SD3LatentFormat(LatentFormat):
667+
def __init__(self):
668+
super().__init__()
669+
self.scale_factor = 1.5305
670+
self.shift_factor = 0.0609
671+
672+
673+
class FluxLatentFormat(LatentFormat):
674+
def __init__(self):
675+
super().__init__()
676+
self.scale_factor = 0.3611
677+
self.shift_factor = 0.1159
678+
679+
556680
def append_dims(x, target_dims):
557681
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
558682
dims_to_append = target_dims - x.ndim

python/src/diffusionkit/mlx/clip.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,9 @@ def __init__(self, config: CLIPTextModelConfig):
8383
def _get_mask(self, N, dtype):
8484
indices = mx.arange(N)
8585
mask = indices[:, None] < indices[None]
86-
mask = mask.astype(dtype) * (-6e4 if dtype == mx.float16 else -1e9)
86+
mask = mask.astype(dtype) * (
87+
-6e4 if (dtype == mx.bfloat16 or dtype == mx.float16) else -1e9
88+
)
8789
return mask
8890

8991
def __call__(self, x):

0 commit comments

Comments
 (0)