Skip to content

Commit 9ec5377

Browse files
authored
support flux fbcache (#117)
* support flux fbcache * fix fb cache dit init
1 parent baf5345 commit 9ec5377

File tree

13 files changed

+590
-208
lines changed

13 files changed

+590
-208
lines changed

diffsynth_engine/models/flux/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from .flux_controlnet import FluxControlNet
55
from .flux_ipadapter import FluxIPAdapter
66
from .flux_redux import FluxRedux
7+
from .flux_dit_fbcache import FluxDiTFBCache
78

89
__all__ = [
910
"FluxRedux",
@@ -14,6 +15,7 @@
1415
"FluxTextEncoder2",
1516
"FluxVAEDecoder",
1617
"FluxVAEEncoder",
18+
"FluxDiTFBCache",
1719
"flux_dit_config",
1820
"flux_text_encoder_config",
1921
"flux_vae_config",
Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
import torch
2+
import numpy as np
3+
from typing import Dict, Optional
4+
5+
from diffsynth_engine.models.utils import no_init_weights
6+
from diffsynth_engine.utils.gguf import gguf_inference
7+
from diffsynth_engine.utils.fp8_linear import fp8_inference
8+
from diffsynth_engine.utils.parallel import (
9+
cfg_parallel,
10+
cfg_parallel_unshard,
11+
sequence_parallel,
12+
sequence_parallel_unshard,
13+
)
14+
from diffsynth_engine.utils import logging
15+
from diffsynth_engine.models.flux.flux_dit import FluxDiT
16+
17+
# Module-level logger obtained from the project's logging helper, keyed by this module's name.
logger = logging.get_logger(__name__)
18+
19+
20+
class FluxDiTFBCache(FluxDiT):
    """FluxDiT variant with a first-block residual cache.

    Every forward pass runs ``self.blocks[0]`` and measures the residual it adds
    to the hidden states. If that residual is close (relative L1 difference
    below ``relative_l1_threshold``) to the one recorded at the last fully
    computed step, the remaining double/single blocks are skipped and the
    residual they produced at that step is re-applied instead.

    Call :meth:`refresh_cache_status` before each sampling run so that the
    first and last steps are always fully computed.
    """

    def __init__(
        self,
        in_channel: int = 64,
        attn_impl: Optional[str] = None,
        device: str = "cuda:0",
        dtype: torch.dtype = torch.bfloat16,
        relative_l1_threshold: float = 0.05,
    ):
        """Build the underlying FluxDiT and initialise the cache state.

        Args:
            in_channel: Forwarded to ``FluxDiT.__init__``.
            attn_impl: Forwarded to ``FluxDiT.__init__``.
            device: Forwarded to ``FluxDiT.__init__``.
            dtype: Forwarded to ``FluxDiT.__init__``.
            relative_l1_threshold: Skip threshold for the first-block residual
                comparison; a value ``<= 0`` disables skipping entirely.
        """
        super().__init__(in_channel=in_channel, attn_impl=attn_impl, device=device, dtype=dtype)
        self.relative_l1_threshold = relative_l1_threshold
        self.step_count = 0
        self.num_inference_steps = 0
        # Cache slots; both are filled by the first fully computed forward pass.
        self.prev_first_hidden_states_residual = None
        self.previous_residual = None

    def is_relative_l1_below_threshold(self, prev_residual, residual, threshold):
        """Return True when ``residual`` is close enough to ``prev_residual``.

        The metric is ``mean(|prev - cur|) / mean(|prev|)``.

        Args:
            prev_residual: Reference residual from the last computed step.
            residual: Residual of the current step.
            threshold: Upper bound (exclusive) for the relative difference.

        Returns:
            ``False`` if the threshold is non-positive, if either residual is
            missing, or if the shapes differ; otherwise whether the relative
            difference is strictly below ``threshold``.
        """
        if threshold <= 0.0:
            return False
        if prev_residual is None or residual is None:
            return False
        if prev_residual.shape != residual.shape:
            return False

        mean_diff = (prev_residual - residual).abs().mean()
        mean_prev_residual = prev_residual.abs().mean()
        # A zero denominator yields inf/nan, which compares False below,
        # so the caller falls back to a full computation.
        diff = mean_diff / mean_prev_residual
        return diff.item() < threshold

    def refresh_cache_status(self, num_inference_steps):
        """Reset the step counter and cached residuals for a new sampling run.

        Args:
            num_inference_steps: Total number of forward steps in the run; the
                step with index ``num_inference_steps - 1`` is never skipped.
        """
        self.step_count = 0
        self.num_inference_steps = num_inference_steps
        # Step 0 is always recomputed, so stale tensors can be released here.
        self.prev_first_hidden_states_residual = None
        self.previous_residual = None

    def forward(
        self,
        hidden_states,
        timestep,
        prompt_emb,
        pooled_prompt_emb,
        image_emb,
        guidance,
        text_ids,
        image_ids=None,
        controlnet_double_block_output=None,
        controlnet_single_block_output=None,
        **kwargs,
    ):
        """Run the transformer, skipping all but the first block on cache hits.

        Args:
            hidden_states: Latent input; its last two dimensions are read as
                ``(h, w)`` and a leading dimension ``> 1`` enables CFG parallel.
            timestep: Input to ``self.time_embedder``.
            prompt_emb: Text embedding fed to ``self.context_embedder``.
            pooled_prompt_emb: Input to ``self.pooled_text_embedder``.
            image_emb: Passed through unchanged to every block.
            guidance: Scaled by 1000 and embedded when ``self.guidance_embedder``
                is not ``None``.
            text_ids: Position ids for the text tokens.
            image_ids: Position ids for the image tokens; derived with
                ``self.prepare_image_ids`` when ``None``.
            controlnet_double_block_output: Optional sequence of residuals added
                after the double blocks.
            controlnet_single_block_output: Optional sequence of residuals added
                after the single blocks.
            **kwargs: Accepted and ignored.

        Returns:
            The unpatchified output after sequence- and CFG-parallel unsharding.
        """
        h, w = hidden_states.shape[-2:]
        if image_ids is None:
            image_ids = self.prepare_image_ids(hidden_states)
        controlnet_double_block_output = (
            controlnet_double_block_output if controlnet_double_block_output is not None else ()
        )
        controlnet_single_block_output = (
            controlnet_single_block_output if controlnet_single_block_output is not None else ()
        )

        fp8_linear_enabled = getattr(self, "fp8_linear_enabled", False)
        use_cfg = hidden_states.shape[0] > 1
        with (
            fp8_inference(fp8_linear_enabled),
            gguf_inference(),
            cfg_parallel(
                (
                    hidden_states,
                    timestep,
                    prompt_emb,
                    pooled_prompt_emb,
                    image_emb,
                    guidance,
                    text_ids,
                    image_ids,
                    *controlnet_double_block_output,
                    *controlnet_single_block_output,
                ),
                use_cfg=use_cfg,
            ),
        ):
            # warning: keep the order time_embedding + guidance_embedding + pooled_text_embedding;
            # floating-point addition is order-sensitive, so reordering changes the result
            conditioning = self.time_embedder(timestep, hidden_states.dtype)
            if self.guidance_embedder is not None:
                guidance = guidance * 1000
                conditioning += self.guidance_embedder(guidance, hidden_states.dtype)
            conditioning += self.pooled_text_embedder(pooled_prompt_emb)
            rope_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
            text_rope_emb = rope_emb[:, :, : text_ids.size(1)]
            image_rope_emb = rope_emb[:, :, text_ids.size(1) :]
            hidden_states = self.patchify(hidden_states)

            with sequence_parallel(
                (
                    hidden_states,
                    prompt_emb,
                    text_rope_emb,
                    image_rope_emb,
                    *controlnet_double_block_output,
                    *controlnet_single_block_output,
                ),
                seq_dims=(
                    1,
                    1,
                    2,
                    2,
                    *(1 for _ in controlnet_double_block_output),
                    *(1 for _ in controlnet_single_block_output),
                ),
            ):
                hidden_states = self.x_embedder(hidden_states)
                prompt_emb = self.context_embedder(prompt_emb)
                rope_emb = torch.cat((text_rope_emb, image_rope_emb), dim=2)

                # first block: always executed, its residual drives the cache decision
                original_hidden_states = hidden_states
                hidden_states, prompt_emb = self.blocks[0](hidden_states, prompt_emb, conditioning, rope_emb, image_emb)
                first_hidden_states_residual = hidden_states - original_hidden_states

                # Compare on the unsharded residual rather than the local shard.
                (first_hidden_states_residual,) = sequence_parallel_unshard(
                    (first_hidden_states_residual,), seq_dims=(1,), seq_lens=(h * w // 4,)
                )

                # The first and the last step of a run are never served from cache.
                if self.step_count == 0 or self.step_count == (self.num_inference_steps - 1):
                    should_calc = True
                else:
                    # Keyword arguments keep the previous residual as the
                    # normalisation reference, as the helper's signature expects.
                    skip = self.is_relative_l1_below_threshold(
                        prev_residual=self.prev_first_hidden_states_residual,
                        residual=first_hidden_states_residual,
                        threshold=self.relative_l1_threshold,
                    )
                    should_calc = not skip
                self.step_count += 1

                if not should_calc:
                    # Cache hit: re-apply the residual of the remaining blocks.
                    hidden_states += self.previous_residual
                else:
                    self.prev_first_hidden_states_residual = first_hidden_states_residual

                    first_hidden_states = hidden_states.clone()

                    num_double_controls = len(controlnet_double_block_output)
                    if num_double_controls > 0:
                        double_interval = int(np.ceil(len(self.blocks) / num_double_controls))
                    # NOTE(review): `i` enumerates self.blocks[1:] from 0, so block k uses
                    # controlnet output (k - 1) // interval and blocks[0] receives no
                    # ControlNet residual -- confirm this matches the base FluxDiT.
                    for i, block in enumerate(self.blocks[1:]):
                        hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, rope_emb, image_emb)
                        if num_double_controls > 0:
                            hidden_states = hidden_states + controlnet_double_block_output[i // double_interval]

                    hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)

                    num_single_controls = len(controlnet_single_block_output)
                    if num_single_controls > 0:
                        # Interval is derived from the single-block outputs (the original
                        # divided by the double-block count here).
                        single_interval = int(np.ceil(len(self.single_blocks) / num_single_controls))
                    for i, block in enumerate(self.single_blocks):
                        hidden_states = block(hidden_states, conditioning, rope_emb, image_emb)
                        if num_single_controls > 0:
                            hidden_states = hidden_states + controlnet_single_block_output[i // single_interval]

                    # Drop the text tokens that were concatenated in front.
                    hidden_states = hidden_states[:, prompt_emb.shape[1] :]

                    # Remember what the skipped blocks contribute for later cache hits.
                    self.previous_residual = hidden_states - first_hidden_states

                hidden_states = self.final_norm_out(hidden_states, conditioning)
                hidden_states = self.final_proj_out(hidden_states)
                (hidden_states,) = sequence_parallel_unshard((hidden_states,), seq_dims=(1,), seq_lens=(h * w // 4,))

            hidden_states = self.unpatchify(hidden_states, h, w)
            (hidden_states,) = cfg_parallel_unshard((hidden_states,), use_cfg=use_cfg)

        return hidden_states

    @classmethod
    def from_state_dict(
        cls,
        state_dict: Dict[str, torch.Tensor],
        device: str,
        dtype: torch.dtype,
        in_channel: int = 64,
        attn_impl: Optional[str] = None,
        fb_cache_relative_l1_threshold: float = 0.05,
    ):
        """Construct the model without weight initialisation and load ``state_dict``.

        Args:
            state_dict: Parameter tensors, assigned directly to the module.
            device: Target device.
            dtype: Target dtype.
            in_channel: Forwarded to the constructor.
            attn_impl: Forwarded to the constructor.
            fb_cache_relative_l1_threshold: Forwarded to the constructor as
                ``relative_l1_threshold``.

        Returns:
            The loaded model with gradients disabled.
        """
        with no_init_weights():
            model = torch.nn.utils.skip_init(
                cls,
                device=device,
                dtype=dtype,
                in_channel=in_channel,
                attn_impl=attn_impl,
                # skip_init forwards kwargs to __init__, which names this
                # parameter `relative_l1_threshold`.
                relative_l1_threshold=fb_cache_relative_l1_threshold,
            )
            model = model.requires_grad_(False)  # for loading gguf
        model.load_state_dict(state_dict, assign=True)
        model.to(device=device, dtype=dtype, non_blocking=True)
        return model

0 commit comments

Comments
 (0)