Commit 76405ec

Merge pull request #15 from argmaxinc/atila/flux_cache
FLUX memory optimizations
2 parents 730a876 + b332a04 commit 76405ec

File tree

5 files changed: +205 -115 lines

python/src/diffusionkit/__init__.py
Lines changed: 3 additions & 0 deletions

@@ -0,0 +1,3 @@
+import os
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"

python/src/diffusionkit/mlx/__init__.py
Lines changed: 59 additions & 33 deletions
@@ -8,11 +8,13 @@
 import gc
 import math
 import time
+from pprint import pprint
 from typing import Optional, Tuple

 import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
+from argmaxtools.test_utils import AppleSiliconContextMixin, InferenceContextSpec
 from argmaxtools.utils import get_logger
 from diffusionkit.utils import bytes2gigabytes
 from PIL import Image
@@ -39,6 +41,14 @@
 }


+class DiffusionKitInferenceContext(AppleSiliconContextMixin, InferenceContextSpec):
+    def code_spec(self):
+        return {}
+
+    def model_spec(self):
+        return {}
+
+
 class DiffusionPipeline:
     def __init__(
         self,
@@ -292,6 +302,11 @@ def generate_image(
         logger.info(
             f"Pre text encoding active memory: {log['text_encoding']['pre']['active_memory']}GB"
         )
+
+        # FIXME(arda): Need the same for CLIP models (low memory mode will not succeed a second time otherwise)
+        if not hasattr(self, "t5"):
+            self.set_up_t5()
+
         conditioning, pooled_conditioning = self.encode_text(
             text, cfg_weight, negative_text
         )
@@ -442,8 +457,19 @@ def generate_image(
         logger.info(
             f"Post decode active memory: {log['decoding']['post']['active_memory']}GB"
         )
-        logger.info(f"Decoding time: {log['decoding']['time']}s")
-        logger.info(f"Peak memory: {log['peak_memory']}GB")
+
+        logger.info("============= Summary =============")
+        logger.info(f"Text encoder: {log['text_encoding']['time']:.1f}s")
+        logger.info(f"Denoising: {log['denoising']['time']:.1f}s")
+        logger.info(f"Image decoder: {log['decoding']['time']:.1f}s")
+        logger.info(f"Peak memory: {log['peak_memory']:.1f}GB")
+
+        logger.info("============= Inference Context =============")
+        ic = DiffusionKitInferenceContext()
+        logger.info("Operating System:")
+        pprint(ic.os_spec())
+        logger.info("Device:")
+        pprint(ic.device_spec())

         # unload VAE Decoder model after decoding in low memory mode
         if self.low_memory_mode:
@@ -462,24 +488,14 @@ def generate_image(

         return Image.fromarray(np.array(x)), log

-    def generate_ids(self, latent_size: Tuple[int]):
-        h, w = latent_size
-        img_ids = mx.zeros((h // 2, w // 2, 3))
-        img_ids[..., 1] = img_ids[..., 1] + mx.arange(h // 2)[:, None]
-        img_ids[..., 2] = img_ids[..., 2] + mx.arange(w // 2)[None, :]
-        img_ids = img_ids.reshape(1, -1, 3)
-
-        txt_ids = mx.zeros((1, 256, 3))  # Hardcoded to context length of T5
-        return img_ids, txt_ids
-
     def read_image(self, image_path: str):
         # Read the image
         img = Image.open(image_path)

         # Make sure image shape is divisible by 64
         W, H = (dim - dim % 64 for dim in (img.width, img.height))
         if W != img.width or H != img.height:
-            print(
+            logger.warning(
                 f"Warning: image shape is not divisible by 64, downsampling to {W}x{H}"
             )
             img = img.resize((W, H), Image.LANCZOS)  # use desired downsampling filter
@@ -557,9 +573,6 @@ def __init__(
         self.decoder = load_vae_decoder(float16=w16, key=mmdit_ckpt)
         self.encoder = load_vae_encoder(float16=False, key=mmdit_ckpt)
         self.latent_format = FluxLatentFormat()
-
-        if not use_t5:
-            logger.warning("FLUX can not be used without T5. Loading T5..")
         self.use_t5 = True

         self.clip_l = load_text_encoder(
@@ -615,8 +628,22 @@ def __init__(self, model: DiffusionPipeline):
         super().__init__()
         self.model = model

+    def cache_modulation_params(self, pooled_text_embeddings, sigmas):
+        self.model.mmdit.cache_modulation_params(
+            pooled_text_embeddings, sigmas.astype(self.model.activation_dtype)
+        )
+
+    def clear_cache(self):
+        self.model.mmdit.clear_modulation_params_cache()
+
     def __call__(
-        self, x_t, t, conditioning, cfg_weight: float = 7.5, pooled_conditioning=None
+        self,
+        x_t,
+        timestep,
+        sigma,
+        conditioning,
+        cfg_weight: float = 7.5,
+        pooled_conditioning=None,
     ):
         if cfg_weight <= 0:
             logger.debug("CFG Weight disabled")
@@ -625,23 +652,14 @@ def __call__(
         x_t_mmdit = mx.concatenate([x_t] * 2, axis=0).astype(
             self.model.activation_dtype
         )
-        t_mmdit = mx.broadcast_to(t, [len(x_t_mmdit)])
-        timestep = self.model.sampler.timestep(t_mmdit).astype(
-            self.model.activation_dtype
-        )
         mmdit_input = {
             "latent_image_embeddings": x_t_mmdit,
             "token_level_text_embeddings": mx.expand_dims(conditioning, 2),
-            "pooled_text_embeddings": mx.expand_dims(
-                mx.expand_dims(pooled_conditioning, 1), 1
-            ),
-            "timestep": timestep,
+            "timestep": mx.broadcast_to(timestep, [len(x_t_mmdit)]),
         }

         mmdit_output = self.model.mmdit(**mmdit_input)
-        eps_pred = self.model.sampler.calculate_denoised(
-            t_mmdit, mmdit_output, x_t_mmdit
-        )
+        eps_pred = self.model.sampler.calculate_denoised(sigma, mmdit_output, x_t_mmdit)
         if cfg_weight <= 0:
             return eps_pred
         else:
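For reference, the `else:` branch that follows is truncated by this hunk; it performs the usual classifier-free guidance blend. A sketch of that conventional combination (an assumption about the omitted code, not a quote of it), with the duplicated batch ordered [conditional, unconditional]:

    # Conventional CFG blend (sketch): push the conditional prediction away
    # from the unconditional one by the guidance weight.
    batch = x_t.shape[0]
    eps_text, eps_uncond = eps_pred[:batch], eps_pred[batch:]
    return eps_uncond + cfg_weight * (eps_text - eps_uncond)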
@@ -691,20 +709,28 @@ def to_d(x, sigma, denoised):
 def sample_euler(model: CFGDenoiser, x, sigmas, extra_args=None):
     """Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
     extra_args = {} if extra_args is None else extra_args
-    s_in = mx.ones([x.shape[0]])
+
     from tqdm import trange

     t = trange(len(sigmas) - 1)
+
+    timesteps = model.model.sampler.timestep(sigmas).astype(
+        model.model.activation_dtype
+    )
+    model.cache_modulation_params(extra_args.pop("pooled_conditioning"), timesteps)
+
     iter_time = []
     for i in t:
         start_time = t.format_dict["elapsed"]
-        sigma_hat = sigmas[i]
-        denoised = model(x, sigma_hat * s_in, **extra_args)
-        d = to_d(x, sigma_hat, denoised)
-        dt = sigmas[i + 1] - sigma_hat
+        denoised = model(x, timesteps[i], sigmas[i], **extra_args)
+        d = to_d(x, sigmas[i], denoised)
+        dt = sigmas[i + 1] - sigmas[i]
         # Euler method
         x = x + d * dt
         mx.eval(x)
         end_time = t.format_dict["elapsed"]
         iter_time.append(round((end_time - start_time), 3))
+
+    # model.clear_cache()
+
     return x, iter_time