3 files changed: +8 −3 lines changed
lines changed Original file line number Diff line number Diff line change @@ -906,7 +906,7 @@ def __call__(
906
906
)
907
907
908
908
# 6. Denoising loop
909
- with self.progress_bar(total=num_inference_steps) as progress_bar:
909
+ with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
910
910
for i , t in enumerate (timesteps ):
911
911
if self .interrupt :
912
912
continue
@@ -917,6 +917,7 @@ def __call__(
917
917
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
918
918
timestep = t .expand (latents .shape [0 ]).to (latents .dtype )
919
919
920
+ cc.mark_state("cond")
920
921
noise_pred = self .transformer (
921
922
hidden_states = latents ,
922
923
timestep = timestep / 1000 ,
@@ -932,6 +933,8 @@ def __call__(
932
933
if do_true_cfg :
933
934
if negative_image_embeds is not None :
934
935
self ._joint_attention_kwargs ["ip_adapter_image_embeds" ] = negative_image_embeds
936
+
937
+ cc.mark_state("uncond")
935
938
neg_noise_pred = self .transformer (
936
939
hidden_states = latents ,
937
940
timestep = timestep / 1000 ,
Original file line number Diff line number Diff line change @@ -1061,7 +1061,7 @@ def __call__(
1061
1061
self ._num_timesteps = len (timesteps )
1062
1062
1063
1063
# 7. Denoising loop
1064
- with self.progress_bar(total=num_inference_steps) as progress_bar:
1064
+ with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
1065
1065
for i , t in enumerate (timesteps ):
1066
1066
if self .interrupt :
1067
1067
continue
@@ -1090,6 +1090,7 @@ def __call__(
1090
1090
timestep = t .expand (latent_model_input .shape [0 ]).unsqueeze (- 1 ).float ()
1091
1091
timestep = torch .min (timestep , (1 - conditioning_mask_model_input ) * 1000.0 )
1092
1092
1093
+ cc.mark_state("cond_uncond")
1093
1094
noise_pred = self .transformer (
1094
1095
hidden_states = latent_model_input ,
1095
1096
encoder_hidden_states = prompt_embeds ,
Original file line number Diff line number Diff line change @@ -771,7 +771,7 @@ def __call__(
771
771
)
772
772
773
773
# 7. Denoising loop
774
- with self.progress_bar(total=num_inference_steps) as progress_bar:
774
+ with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
775
775
for i , t in enumerate (timesteps ):
776
776
if self .interrupt :
777
777
continue
@@ -783,6 +783,7 @@ def __call__(
783
783
timestep = t .expand (latent_model_input .shape [0 ])
784
784
timestep = timestep .unsqueeze (- 1 ) * (1 - conditioning_mask )
785
785
786
+ cc.mark_state("cond_uncond")
786
787
noise_pred = self .transformer (
787
788
hidden_states = latent_model_input ,
788
789
encoder_hidden_states = prompt_embeds ,
You can’t perform that action at this time.
0 commit comments