
Commit cd71035: start flux.
Parent: f36ba9f
File tree

10 files changed: +520 −4 lines

src/diffusers/hooks/_helpers.py

Lines changed: 8 additions & 0 deletions
@@ -107,6 +107,7 @@ def _register(cls):
 def _register_attention_processors_metadata():
     from ..models.attention_processor import AttnProcessor2_0
     from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor
+    from ..models.transformers.transformer_flux import FluxAttnProcessor
     from ..models.transformers.transformer_wan import WanAttnProcessor2_0

     # AttnProcessor2_0
@@ -132,6 +133,11 @@ def _register_attention_processors_metadata():
             skip_processor_output_fn=_skip_proc_output_fn_Attention_WanAttnProcessor2_0,
         ),
     )
+    # FluxAttnProcessor
+    AttentionProcessorRegistry.register(
+        model_class=FluxAttnProcessor,
+        metadata=AttentionProcessorMetadata(skip_processor_output_fn=_skip_proc_output_fn_Attention_FluxAttnProcessor),
+    )


 def _register_transformer_blocks_metadata():
@@ -271,4 +277,6 @@ def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs):
 _skip_proc_output_fn_Attention_AttnProcessor2_0 = _skip_attention___ret___hidden_states
 _skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states
 _skip_proc_output_fn_Attention_WanAttnProcessor2_0 = _skip_attention___ret___hidden_states
+# not sure what this is yet.
+_skip_proc_output_fn_Attention_FluxAttnProcessor = _skip_attention___ret___hidden_states
 # fmt: on
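
For context, the registry this hunk extends pairs each attention processor class with metadata whose skip function can stand in for the real processor and pass hidden states through untouched. A minimal self-contained sketch of that pattern (the names mirror the diff; the bodies are illustrative assumptions, not the actual diffusers implementation):

from dataclasses import dataclass
from typing import Callable, Dict, Type


@dataclass
class AttentionProcessorMetadata:
    # Called instead of the real processor when a layer is skipped; it must
    # return outputs shaped like the processor's own.
    skip_processor_output_fn: Callable


class AttentionProcessorRegistry:
    _registry: Dict[Type, AttentionProcessorMetadata] = {}

    @classmethod
    def register(cls, model_class, metadata):
        cls._registry[model_class] = metadata

    @classmethod
    def get(cls, model_class):
        return cls._registry[model_class]


def _skip_attention___ret___hidden_states(self, *args, **kwargs):
    # Identity skip: return the incoming hidden states unchanged.
    return kwargs.get("hidden_states", args[0] if args else None)


class FluxAttnProcessor:  # stand-in for the real processor class
    pass


AttentionProcessorRegistry.register(
    model_class=FluxAttnProcessor,
    metadata=AttentionProcessorMetadata(
        skip_processor_output_fn=_skip_attention___ret___hidden_states
    ),
)
meta = AttentionProcessorRegistry.get(FluxAttnProcessor)
print(meta.skip_processor_output_fn(None, "hidden"))  # -> "hidden"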

src/diffusers/modular_pipelines/flux/__init__.py

Whitespace-only changes.

src/diffusers/modular_pipelines/flux/before_denoise.py

Lines changed: 389 additions & 0 deletions
Large diffs are not rendered by default.

src/diffusers/modular_pipelines/flux/decoders.py

Lines changed: 115 additions & 0 deletions
@@ -0,0 +1,115 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, List, Tuple, Union

import numpy as np
import PIL
import torch

from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL
from ...utils import logging
from ...video_processor import VaeImageProcessor
from ..modular_pipeline import PipelineBlock, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


# Copied from diffusers.pipelines.flux.pipeline_flux._unpack_latents
def _unpack_latents(latents, height, width, vae_scale_factor):
    batch_size, num_patches, channels = latents.shape

    # VAE applies 8x compression on images but we must also account for packing which requires
    # latent height and width to be divisible by 2.
    height = 2 * (int(height) // (vae_scale_factor * 2))
    width = 2 * (int(width) // (vae_scale_factor * 2))

    latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
    latents = latents.permute(0, 3, 1, 4, 2, 5)

    latents = latents.reshape(batch_size, channels // (2 * 2), height, width)

    return latents
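
To make the packing arithmetic above concrete, a small sanity check (illustrative, not part of the commit; it assumes the Flux-typical vae_scale_factor of 8 and 16 latent channels, and reuses the `_unpack_latents` defined above):

import torch

# 1024x1024 pixels -> 128x128 latent grid (8x VAE compression), packed into
# (128 / 2) * (128 / 2) = 4096 tokens of 16 * 2 * 2 = 64 channels each.
packed = torch.randn(1, 4096, 64)  # (batch_size, num_patches, channels)
unpacked = _unpack_latents(packed, height=1024, width=1024, vae_scale_factor=8)
print(unpacked.shape)  # torch.Size([1, 16, 128, 128]), ready for vae.decode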
class FluxDecodeStep(PipelineBlock):
    model_name = "flux"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("vae", AutoencoderKL),
            ComponentSpec(
                "image_processor",
                VaeImageProcessor,
                config=FrozenDict({"vae_scale_factor": 16}),
                default_creation_method="from_config",
            ),
        ]

    @property
    def description(self) -> str:
        return "Step that decodes the denoised latents into images"

    @property
    def inputs(self) -> List[Tuple[str, Any]]:
        return [
            InputParam("output_type", default="pil"),
            InputParam("height", default=1024),
            InputParam("width", default=1024),
        ]

    @property
    def intermediate_inputs(self) -> List[str]:
        return [
            InputParam(
                "latents",
                required=True,
                type_hint=torch.Tensor,
                description="The denoised latents from the denoising step",
            )
        ]

    @property
    def intermediate_outputs(self) -> List[str]:
        return [
            OutputParam(
                "images",
                type_hint=Union[List[PIL.Image.Image], torch.Tensor, np.ndarray],
                description="The generated images, can be a list of PIL.Image.Image, torch.Tensor or a numpy array",
            )
        ]

    @torch.no_grad()
    def __call__(self, components, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        vae = components.vae

        if not block_state.output_type == "latent":
            latents = block_state.latents
            latents = _unpack_latents(latents, block_state.height, block_state.width, components.vae_scale_factor)
            latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
            block_state.images = vae.decode(latents, return_dict=False)[0]
            block_state.images = components.image_processor.postprocess(
                block_state.images, output_type=block_state.output_type
            )
        else:
            block_state.images = block_state.latents

        self.set_block_state(state, block_state)

        return components, state
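
The `(latents / vae.config.scaling_factor) + vae.config.shift_factor` line undoes the affine normalization the Flux VAE applies at encode time. A minimal sketch of that round trip (illustrative only; the constants below match the published FLUX.1 VAE config but are hard-coded here as assumptions):

import torch

scaling_factor = 0.3611  # assumed constant, from the published FLUX.1 VAE config
shift_factor = 0.1159    # assumed constant, from the published FLUX.1 VAE config

latents = torch.randn(1, 16, 128, 128)                   # raw VAE latents
normalized = (latents - shift_factor) * scaling_factor   # encode-side normalization
restored = normalized / scaling_factor + shift_factor    # what FluxDecodeStep undoes
assert torch.allclose(restored, latents, atol=1e-6)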

src/diffusers/modular_pipelines/flux/denoise.py

Whitespace-only changes.

src/diffusers/modular_pipelines/flux/encoders.py

Whitespace-only changes.

src/diffusers/modular_pipelines/flux/modular_blocks.py

Whitespace-only changes.

src/diffusers/modular_pipelines/flux/modular_pipeline.py

Whitespace-only changes.

src/diffusers/modular_pipelines/modular_pipeline.py

Lines changed: 2 additions & 0 deletions
@@ -61,13 +61,15 @@
     [
         ("stable-diffusion-xl", "StableDiffusionXLModularPipeline"),
         ("wan", "WanModularPipeline"),
+        ("flux", "FluxModularPipeline"),
     ]
 )

 MODULAR_PIPELINE_BLOCKS_MAPPING = OrderedDict(
     [
         ("StableDiffusionXLModularPipeline", "StableDiffusionXLAutoBlocks"),
         ("WanModularPipeline", "WanAutoBlocks"),
+        ("FluxModularPipeline", "FluxAutoBlocks"),
     ]
 )
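
These two mappings let the modular loader go from a model-family string to a pipeline class name, and from that to its auto-blocks class. A minimal sketch of the chained lookup (the `resolve_modular_classes` helper is hypothetical; note the Flux classes registered here point at modules that are still empty placeholders in this commit):

from collections import OrderedDict

MODULAR_PIPELINE_MAPPING = OrderedDict(
    [
        ("stable-diffusion-xl", "StableDiffusionXLModularPipeline"),
        ("wan", "WanModularPipeline"),
        ("flux", "FluxModularPipeline"),
    ]
)
MODULAR_PIPELINE_BLOCKS_MAPPING = OrderedDict(
    [
        ("StableDiffusionXLModularPipeline", "StableDiffusionXLAutoBlocks"),
        ("WanModularPipeline", "WanAutoBlocks"),
        ("FluxModularPipeline", "FluxAutoBlocks"),
    ]
)


def resolve_modular_classes(model_family: str):
    # Chain the lookups: family string -> pipeline class name -> auto-blocks class name.
    pipeline_name = MODULAR_PIPELINE_MAPPING[model_family]
    blocks_name = MODULAR_PIPELINE_BLOCKS_MAPPING[pipeline_name]
    return pipeline_name, blocks_name


print(resolve_modular_classes("flux"))  # ('FluxModularPipeline', 'FluxAutoBlocks')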

src/diffusers/pipelines/flux/pipeline_output.py

Lines changed: 6 additions & 4 deletions
@@ -11,12 +11,14 @@
 @dataclass
 class FluxPipelineOutput(BaseOutput):
     """
-    Output class for Stable Diffusion pipelines.
+    Output class for Flux image generation pipelines.

     Args:
-        images (`List[PIL.Image.Image]` or `np.ndarray`)
-            List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
-            num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
+        images (`List[PIL.Image.Image]` or `torch.Tensor` or `np.ndarray`)
+            List of denoised PIL images of length `batch_size` or numpy array or torch tensor of shape `(batch_size,
+            height, width, num_channels)`. PIL images or numpy array present the denoised images of the diffusion
+            pipeline. Torch tensors can represent either the denoised images or the intermediate latents ready to be
+            passed to the decoder.
     """

     images: Union[List[PIL.Image.Image], np.ndarray]
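
A quick illustration of the widened contract (illustrative snippet, not part of the commit):

import numpy as np
import PIL.Image
import torch

from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput

# The same `images` field may now carry any of the three documented payload types.
as_pil = FluxPipelineOutput(images=[PIL.Image.new("RGB", (64, 64))])
as_numpy = FluxPipelineOutput(images=np.zeros((1, 64, 64, 3), dtype=np.float32))
as_tensor = FluxPipelineOutput(images=torch.zeros(1, 64, 64, 3))  # decoded images, or latents when output_type="latent"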
