Skip to content

Commit 3e99b56

Browse files
ParagEkbotehlky
andauthored
Extend Support for callback_on_step_end for AuraFlow and LuminaText2Img Pipelines (#10746)
* Add support for callback_on_step_end for AuraFlowPipeline and LuminaText2ImgPipeline. * Apply the suggestions from code review for lumina and auraflow Co-authored-by: hlky <hlky@hlky.ac> * Update missing inputs and imports. * Add input field. * Apply suggestions from code review-2 Co-authored-by: hlky <hlky@hlky.ac> * Apply the suggestions from review for unused imports. Co-authored-by: hlky <hlky@hlky.ac> * make style. * Update pipeline_aura_flow.py * Update pipeline_lumina.py * Update pipeline_lumina.py * Update pipeline_aura_flow.py * Update pipeline_lumina.py --------- Co-authored-by: hlky <hlky@hlky.ac>
1 parent 952b913 commit 3e99b56

File tree

2 files changed

+80
-2
lines changed

2 files changed

+80
-2
lines changed

src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,12 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import inspect
15-
from typing import List, Optional, Tuple, Union
15+
from typing import Callable, Dict, List, Optional, Tuple, Union
1616

1717
import torch
1818
from transformers import T5Tokenizer, UMT5EncoderModel
1919

20+
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
2021
from ...image_processor import VaeImageProcessor
2122
from ...models import AuraFlowTransformer2DModel, AutoencoderKL
2223
from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor
@@ -131,6 +132,10 @@ class AuraFlowPipeline(DiffusionPipeline):
131132

132133
_optional_components = []
133134
model_cpu_offload_seq = "text_encoder->transformer->vae"
135+
_callback_tensor_inputs = [
136+
"latents",
137+
"prompt_embeds",
138+
]
134139

135140
def __init__(
136141
self,
@@ -159,12 +164,19 @@ def check_inputs(
159164
negative_prompt_embeds=None,
160165
prompt_attention_mask=None,
161166
negative_prompt_attention_mask=None,
167+
callback_on_step_end_tensor_inputs=None,
162168
):
163169
if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
164170
raise ValueError(
165171
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}."
166172
)
167173

174+
if callback_on_step_end_tensor_inputs is not None and not all(
175+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
176+
):
177+
raise ValueError(
178+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
179+
)
168180
if prompt is not None and prompt_embeds is not None:
169181
raise ValueError(
170182
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
@@ -387,6 +399,14 @@ def upcast_vae(self):
387399
self.vae.decoder.conv_in.to(dtype)
388400
self.vae.decoder.mid_block.to(dtype)
389401

402+
@property
403+
def guidance_scale(self):
404+
return self._guidance_scale
405+
406+
@property
407+
def num_timesteps(self):
408+
return self._num_timesteps
409+
390410
@torch.no_grad()
391411
@replace_example_docstring(EXAMPLE_DOC_STRING)
392412
def __call__(
@@ -408,6 +428,10 @@ def __call__(
408428
max_sequence_length: int = 256,
409429
output_type: Optional[str] = "pil",
410430
return_dict: bool = True,
431+
callback_on_step_end: Optional[
432+
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
433+
] = None,
434+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
411435
) -> Union[ImagePipelineOutput, Tuple]:
412436
r"""
413437
Function invoked when calling the pipeline for generation.
@@ -462,6 +486,15 @@ def __call__(
462486
return_dict (`bool`, *optional*, defaults to `True`):
463487
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
464488
of a plain tuple.
489+
callback_on_step_end (`Callable`, *optional*):
490+
A function that calls at the end of each denoising steps during the inference. The function is called
491+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
492+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
493+
`callback_on_step_end_tensor_inputs`.
494+
callback_on_step_end_tensor_inputs (`List`, *optional*):
495+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
496+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
497+
`._callback_tensor_inputs` attribute of your pipeline class.
465498
max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
466499
467500
Examples:
@@ -483,8 +516,11 @@ def __call__(
483516
negative_prompt_embeds,
484517
prompt_attention_mask,
485518
negative_prompt_attention_mask,
519+
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
486520
)
487521

522+
self._guidance_scale = guidance_scale
523+
488524
# 2. Determine batch size.
489525
if prompt is not None and isinstance(prompt, str):
490526
batch_size = 1
@@ -541,6 +577,7 @@ def __call__(
541577

542578
# 6. Denoising loop
543579
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
580+
self._num_timesteps = len(timesteps)
544581
with self.progress_bar(total=num_inference_steps) as progress_bar:
545582
for i, t in enumerate(timesteps):
546583
# expand the latents if we are doing classifier free guidance
@@ -567,6 +604,15 @@ def __call__(
567604
# compute the previous noisy sample x_t -> x_t-1
568605
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
569606

607+
if callback_on_step_end is not None:
608+
callback_kwargs = {}
609+
for k in callback_on_step_end_tensor_inputs:
610+
callback_kwargs[k] = locals()[k]
611+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
612+
613+
latents = callback_outputs.pop("latents", latents)
614+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
615+
570616
# call the callback, if provided
571617
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
572618
progress_bar.update()

src/diffusers/pipelines/lumina/pipeline_lumina.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,12 @@
1717
import math
1818
import re
1919
import urllib.parse as ul
20-
from typing import List, Optional, Tuple, Union
20+
from typing import Callable, Dict, List, Optional, Tuple, Union
2121

2222
import torch
2323
from transformers import AutoModel, AutoTokenizer
2424

25+
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
2526
from ...image_processor import VaeImageProcessor
2627
from ...models import AutoencoderKL
2728
from ...models.embeddings import get_2d_rotary_pos_embed_lumina
@@ -174,6 +175,10 @@ class LuminaText2ImgPipeline(DiffusionPipeline):
174175

175176
_optional_components = []
176177
model_cpu_offload_seq = "text_encoder->transformer->vae"
178+
_callback_tensor_inputs = [
179+
"latents",
180+
"prompt_embeds",
181+
]
177182

178183
def __init__(
179184
self,
@@ -395,12 +400,20 @@ def check_inputs(
395400
negative_prompt_embeds=None,
396401
prompt_attention_mask=None,
397402
negative_prompt_attention_mask=None,
403+
callback_on_step_end_tensor_inputs=None,
398404
):
399405
if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
400406
raise ValueError(
401407
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}."
402408
)
403409

410+
if callback_on_step_end_tensor_inputs is not None and not all(
411+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
412+
):
413+
raise ValueError(
414+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
415+
)
416+
404417
if prompt is not None and prompt_embeds is not None:
405418
raise ValueError(
406419
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
@@ -644,6 +657,10 @@ def __call__(
644657
max_sequence_length: int = 256,
645658
scaling_watershed: Optional[float] = 1.0,
646659
proportional_attn: Optional[bool] = True,
660+
callback_on_step_end: Optional[
661+
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
662+
] = None,
663+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
647664
) -> Union[ImagePipelineOutput, Tuple]:
648665
"""
649666
Function invoked when calling the pipeline for generation.
@@ -735,7 +752,11 @@ def __call__(
735752
negative_prompt_embeds=negative_prompt_embeds,
736753
prompt_attention_mask=prompt_attention_mask,
737754
negative_prompt_attention_mask=negative_prompt_attention_mask,
755+
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
738756
)
757+
758+
self._guidance_scale = guidance_scale
759+
739760
cross_attention_kwargs = {}
740761

741762
# 2. Define call parameters
@@ -797,6 +818,8 @@ def __call__(
797818
latents,
798819
)
799820

821+
self._num_timesteps = len(timesteps)
822+
800823
# 6. Denoising loop
801824
with self.progress_bar(total=num_inference_steps) as progress_bar:
802825
for i, t in enumerate(timesteps):
@@ -886,6 +909,15 @@ def __call__(
886909

887910
progress_bar.update()
888911

912+
if callback_on_step_end is not None:
913+
callback_kwargs = {}
914+
for k in callback_on_step_end_tensor_inputs:
915+
callback_kwargs[k] = locals()[k]
916+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
917+
918+
latents = callback_outputs.pop("latents", latents)
919+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
920+
889921
if XLA_AVAILABLE:
890922
xm.mark_step()
891923

0 commit comments

Comments
 (0)