12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
import inspect
15
- from typing import List , Optional , Tuple , Union
15
+ from typing import Callable , Dict , List , Optional , Tuple , Union
16
16
17
17
import torch
18
18
from transformers import T5Tokenizer , UMT5EncoderModel
19
19
20
+ from ...callbacks import MultiPipelineCallbacks , PipelineCallback
20
21
from ...image_processor import VaeImageProcessor
21
22
from ...models import AuraFlowTransformer2DModel , AutoencoderKL
22
23
from ...models .attention_processor import AttnProcessor2_0 , FusedAttnProcessor2_0 , XFormersAttnProcessor
@@ -131,6 +132,10 @@ class AuraFlowPipeline(DiffusionPipeline):
131
132
132
133
_optional_components = []
133
134
model_cpu_offload_seq = "text_encoder->transformer->vae"
135
+ _callback_tensor_inputs = [
136
+ "latents" ,
137
+ "prompt_embeds" ,
138
+ ]
134
139
135
140
def __init__ (
136
141
self ,
@@ -159,12 +164,19 @@ def check_inputs(
159
164
negative_prompt_embeds = None ,
160
165
prompt_attention_mask = None ,
161
166
negative_prompt_attention_mask = None ,
167
+ callback_on_step_end_tensor_inputs = None ,
162
168
):
163
169
if height % (self .vae_scale_factor * 2 ) != 0 or width % (self .vae_scale_factor * 2 ) != 0 :
164
170
raise ValueError (
165
171
f"`height` and `width` have to be divisible by { self .vae_scale_factor * 2 } but are { height } and { width } ."
166
172
)
167
173
174
+ if callback_on_step_end_tensor_inputs is not None and not all (
175
+ k in self ._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
176
+ ):
177
+ raise ValueError (
178
+ f"`callback_on_step_end_tensor_inputs` has to be in { self ._callback_tensor_inputs } , but found { [k for k in callback_on_step_end_tensor_inputs if k not in self ._callback_tensor_inputs ]} "
179
+ )
168
180
if prompt is not None and prompt_embeds is not None :
169
181
raise ValueError (
170
182
f"Cannot forward both `prompt`: { prompt } and `prompt_embeds`: { prompt_embeds } . Please make sure to"
@@ -387,6 +399,14 @@ def upcast_vae(self):
387
399
self .vae .decoder .conv_in .to (dtype )
388
400
self .vae .decoder .mid_block .to (dtype )
389
401
402
+ @property
403
+ def guidance_scale (self ):
404
+ return self ._guidance_scale
405
+
406
+ @property
407
+ def num_timesteps (self ):
408
+ return self ._num_timesteps
409
+
390
410
@torch .no_grad ()
391
411
@replace_example_docstring (EXAMPLE_DOC_STRING )
392
412
def __call__ (
@@ -408,6 +428,10 @@ def __call__(
408
428
max_sequence_length : int = 256 ,
409
429
output_type : Optional [str ] = "pil" ,
410
430
return_dict : bool = True ,
431
+ callback_on_step_end : Optional [
432
+ Union [Callable [[int , int , Dict ], None ], PipelineCallback , MultiPipelineCallbacks ]
433
+ ] = None ,
434
+ callback_on_step_end_tensor_inputs : List [str ] = ["latents" ],
411
435
) -> Union [ImagePipelineOutput , Tuple ]:
412
436
r"""
413
437
Function invoked when calling the pipeline for generation.
@@ -462,6 +486,15 @@ def __call__(
462
486
return_dict (`bool`, *optional*, defaults to `True`):
463
487
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
464
488
of a plain tuple.
489
+ callback_on_step_end (`Callable`, *optional*):
490
+ A function that calls at the end of each denoising steps during the inference. The function is called
491
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
492
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
493
+ `callback_on_step_end_tensor_inputs`.
494
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
495
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
496
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
497
+ `._callback_tensor_inputs` attribute of your pipeline class.
465
498
max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
466
499
467
500
Examples:
@@ -483,8 +516,11 @@ def __call__(
483
516
negative_prompt_embeds ,
484
517
prompt_attention_mask ,
485
518
negative_prompt_attention_mask ,
519
+ callback_on_step_end_tensor_inputs = callback_on_step_end_tensor_inputs ,
486
520
)
487
521
522
+ self ._guidance_scale = guidance_scale
523
+
488
524
# 2. Determine batch size.
489
525
if prompt is not None and isinstance (prompt , str ):
490
526
batch_size = 1
@@ -541,6 +577,7 @@ def __call__(
541
577
542
578
# 6. Denoising loop
543
579
num_warmup_steps = max (len (timesteps ) - num_inference_steps * self .scheduler .order , 0 )
580
+ self ._num_timesteps = len (timesteps )
544
581
with self .progress_bar (total = num_inference_steps ) as progress_bar :
545
582
for i , t in enumerate (timesteps ):
546
583
# expand the latents if we are doing classifier free guidance
@@ -567,6 +604,15 @@ def __call__(
567
604
# compute the previous noisy sample x_t -> x_t-1
568
605
latents = self .scheduler .step (noise_pred , t , latents , return_dict = False )[0 ]
569
606
607
+ if callback_on_step_end is not None :
608
+ callback_kwargs = {}
609
+ for k in callback_on_step_end_tensor_inputs :
610
+ callback_kwargs [k ] = locals ()[k ]
611
+ callback_outputs = callback_on_step_end (self , i , t , callback_kwargs )
612
+
613
+ latents = callback_outputs .pop ("latents" , latents )
614
+ prompt_embeds = callback_outputs .pop ("prompt_embeds" , prompt_embeds )
615
+
570
616
# call the callback, if provided
571
617
if i == len (timesteps ) - 1 or ((i + 1 ) > num_warmup_steps and (i + 1 ) % self .scheduler .order == 0 ):
572
618
progress_bar .update ()
0 commit comments