Commit 45465d4

committed: more
1 parent cd71035 commit 45465d4

Lines changed: 215 additions & 0 deletions
@@ -0,0 +1,215 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, List, Tuple

import torch

from ...models import FluxTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import logging
from ..modular_pipeline import (
    BlockState,
    LoopSequentialPipelineBlocks,
    PipelineBlock,
    PipelineState,
)
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import FluxModularPipeline


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class FluxLoopDenoiser(PipelineBlock):
    model_name = "flux"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [ComponentSpec("transformer", FluxTransformer2DModel)]

    @property
    def description(self) -> str:
        return (
            "Step within the denoising loop that denoises the latents. "
            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
            "object (e.g. `FluxDenoiseLoopWrapper`)."
        )

    @property
    def inputs(self) -> List[Tuple[str, Any]]:
        return [
            InputParam("attention_kwargs"),
        ]

    @property
    def intermediate_inputs(self) -> List[str]:
        return [
            InputParam(
                "latents",
                required=True,
                type_hint=torch.Tensor,
                description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
            ),
            InputParam(
                "num_inference_steps",
                required=True,
                type_hint=int,
                description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
            ),
            # TODO: guidance
        ]

    @torch.no_grad()
    def __call__(
        self, components: FluxModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
    ) -> PipelineState:
        noise_pred = components.transformer(
            hidden_states=block_state.latents,
            # Scale the scheduler timestep from [0, 1000] to the [0, 1] range expected by the Flux transformer.
            timestep=t.flatten() / 1000,
            encoder_hidden_states=block_state.prompt_embeds,
            pooled_projections=block_state.pooled_prompt_embeds,
            attention_kwargs=block_state.attention_kwargs,
            txt_ids=block_state.text_ids,
            img_ids=block_state.latent_image_ids,
            return_dict=False,
        )[0]
        block_state.noise_pred = noise_pred

        return components, block_state


class FluxLoopAfterDenoiser(PipelineBlock):
    model_name = "flux"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
        ]

    @property
    def description(self) -> str:
        return (
            "Step within the denoising loop that updates the latents. "
            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
            "object (e.g. `FluxDenoiseLoopWrapper`)."
        )

    @property
    def inputs(self) -> List[Tuple[str, Any]]:
        return []

    @property
    def intermediate_inputs(self) -> List[str]:
        return [
            InputParam("generator"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")]

    @torch.no_grad()
    def __call__(self, components: FluxModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
        # Perform scheduler step using the predicted output
        latents_dtype = block_state.latents.dtype
        block_state.latents = components.scheduler.step(
            block_state.noise_pred,
            t,
            block_state.latents,
            **block_state.scheduler_step_kwargs,
            return_dict=False,
        )[0]

        # Some schedulers can upcast the latents; cast back so downstream blocks see a consistent dtype.
        if block_state.latents.dtype != latents_dtype:
            block_state.latents = block_state.latents.to(latents_dtype)

        return components, block_state


class FluxDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
    model_name = "flux"

    @property
    def description(self) -> str:
        return (
            "Pipeline block that iteratively denoises the latents over `timesteps`. "
            "The specific steps within each iteration can be customized with the `sub_blocks` attribute."
        )

    @property
    def loop_expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
            ComponentSpec("transformer", FluxTransformer2DModel),
        ]

    @property
    def loop_intermediate_inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "timesteps",
                required=True,
                type_hint=torch.Tensor,
                description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
            ),
            InputParam(
                "num_inference_steps",
                required=True,
                type_hint=int,
                description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
            ),
        ]

    @torch.no_grad()
    def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        block_state.num_warmup_steps = max(
            len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0
        )
        # We set the index here to remove DtoH sync, helpful especially during compilation.
        # Check out more details here: https://github.com/huggingface/diffusers/pull/11696
        components.scheduler.set_begin_index(0)
        with self.progress_bar(total=block_state.num_inference_steps) as progress_bar:
            for i, t in enumerate(block_state.timesteps):
                components, block_state = self.loop_step(components, block_state, i=i, t=t)
                if i == len(block_state.timesteps) - 1 or (
                    (i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0
                ):
                    progress_bar.update()

        self.set_block_state(state, block_state)

        return components, state


class FluxDenoiseStep(FluxDenoiseLoopWrapper):
    block_classes = [
        FluxLoopDenoiser,
        FluxLoopAfterDenoiser,
    ]
    # One name per entry in `block_classes`.
    block_names = ["denoiser", "after_denoiser"]

    @property
    def description(self) -> str:
        return (
            "Denoise step that iteratively denoises the latents. \n"
            "Its loop logic is defined in the `FluxDenoiseLoopWrapper.__call__` method. \n"
            "At each iteration, it runs the blocks defined in `sub_blocks` sequentially:\n"
            " - `FluxLoopDenoiser`\n"
            " - `FluxLoopAfterDenoiser`\n"
            "This block supports text2image tasks."
        )
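
For orientation, here is a minimal sketch of how the composed `FluxDenoiseStep` block could be inspected once this module ships. The import path is an assumption (the module's file path is not shown in this commit), and no-argument construction mirrors how other modular pipeline blocks are built up; treat it as illustrative rather than the library's documented API.

# Illustrative sketch only: the import path below is assumed, not taken from this commit.
from diffusers.modular_pipelines import FluxDenoiseStep  # hypothetical location

denoise_step = FluxDenoiseStep()

# FluxDenoiseLoopWrapper.__call__ drives the timestep loop; each iteration runs the
# composed sub_blocks in order: FluxLoopDenoiser (noise prediction), then
# FluxLoopAfterDenoiser (scheduler step that updates the latents).
print(denoise_step.description)
print(denoise_step.sub_blocks)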
