
Commit 1f33ca2

support flux, ltx i2v, ltx condition
1 parent 41b0c47 commit 1f33ca2

File tree

3 files changed: +8, -3 lines


src/diffusers/pipelines/flux/pipeline_flux.py

Lines changed: 4 additions & 1 deletion
@@ -906,7 +906,7 @@ def __call__(
         )

         # 6. Denoising loop
-        with self.progress_bar(total=num_inference_steps) as progress_bar:
+        with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
@@ -917,6 +917,7 @@ def __call__(
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)

+                cc.mark_state("cond")
                 noise_pred = self.transformer(
                     hidden_states=latents,
                     timestep=timestep / 1000,
@@ -932,6 +933,8 @@ def __call__(
                 if do_true_cfg:
                     if negative_image_embeds is not None:
                         self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
+
+                    cc.mark_state("uncond")
                     neg_noise_pred = self.transformer(
                         hidden_states=latents,
                         timestep=timestep / 1000,
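This pipeline and the two LTX pipelines below follow the same pattern: the denoising loop is wrapped in the transformer's `_cache_context()`, and `cc.mark_state(...)` labels each forward pass so cached intermediate states are kept per branch instead of being shared between the conditional and unconditional passes. The sketch below is a minimal, hypothetical version of such a state-keyed cache context; `CacheContext`, `cache_context`, `get`, and `set` are illustrative names, not the diffusers implementation.

```python
# Minimal sketch of a state-keyed cache context (hypothetical names, not the
# actual diffusers implementation). mark_state() selects the bucket that later
# lookups read from and write to, so the "cond" and "uncond" passes never
# reuse each other's cached activations.
from collections import defaultdict
from contextlib import contextmanager


class CacheContext:
    def __init__(self):
        self._buckets = defaultdict(dict)  # state name -> {key: cached value}
        self._state = "default"

    def mark_state(self, name):
        self._state = name

    def get(self, key):
        return self._buckets[self._state].get(key)

    def set(self, key, value):
        self._buckets[self._state][key] = value


@contextmanager
def cache_context():
    # A real implementation would also expose the context to the transformer
    # blocks (e.g. via a module attribute or a thread-local).
    ctx = CacheContext()
    try:
        yield ctx
    finally:
        ctx._buckets.clear()


# Usage mirroring the loop above: two forward passes per step, two buckets.
with cache_context() as cc:
    for step in range(4):
        cc.mark_state("cond")
        # ... conditional transformer forward reads/writes cc here ...
        cc.mark_state("uncond")
        # ... unconditional forward uses its own bucket ...
```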

src/diffusers/pipelines/ltx/pipeline_ltx_condition.py

Lines changed: 2 additions & 1 deletion
@@ -1061,7 +1061,7 @@ def __call__(
         self._num_timesteps = len(timesteps)

         # 7. Denoising loop
-        with self.progress_bar(total=num_inference_steps) as progress_bar:
+        with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
@@ -1090,6 +1090,7 @@ def __call__(
                 timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1).float()
                 timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0)

+                cc.mark_state("cond_uncond")
                 noise_pred = self.transformer(
                     hidden_states=latent_model_input,
                     encoder_hidden_states=prompt_embeds,
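Here a single state, `"cond_uncond"`, is enough because the LTX pipelines run classifier-free guidance in one batched forward pass (negative and positive prompts concatenated along the batch dimension) rather than two separate transformer calls as in Flux. Below is a minimal sketch of that batched-CFG shape handling; the `transformer` stub, `guidance_scale`, and tensor shapes are stand-ins, not the real pipeline objects.

```python
# Hypothetical illustration of batched classifier-free guidance: one forward
# pass covers both branches, and the prediction is split afterwards.
import torch


def transformer(hidden_states, timestep):
    # Stub for the LTX video transformer.
    return hidden_states * 0.1


guidance_scale = 5.0
latents = torch.randn(1, 128, 64)              # (batch, tokens, channels)
latent_model_input = torch.cat([latents] * 2)  # [uncond; cond] in one batch
timestep = torch.full((latent_model_input.shape[0],), 500.0)

noise_pred = transformer(latent_model_input, timestep)
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
```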

src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py

Lines changed: 2 additions & 1 deletion
@@ -771,7 +771,7 @@ def __call__(
         )

         # 7. Denoising loop
-        with self.progress_bar(total=num_inference_steps) as progress_bar:
+        with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
@@ -783,6 +783,7 @@ def __call__(
                 timestep = t.expand(latent_model_input.shape[0])
                 timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask)

+                cc.mark_state("cond_uncond")
                 noise_pred = self.transformer(
                     hidden_states=latent_model_input,
                     encoder_hidden_states=prompt_embeds,
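The per-token timestep line in the hunk above, `timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask)`, zeroes the timestep for tokens that come from the conditioning image, so those tokens are treated as already clean while the remaining video tokens keep the current diffusion timestep. A small numeric illustration follows; the mask values are made up.

```python
# Illustration of the per-token timestep masking in the diff above. A mask
# value of 1.0 marks tokens taken from the conditioning image; 0.0 marks
# tokens that still need to be denoised (values here are made up).
import torch

t = torch.tensor(800.0)                                   # current timestep
conditioning_mask = torch.tensor([[1.0, 1.0, 0.0, 0.0]])  # (batch, tokens)

timestep = t.expand(conditioning_mask.shape[0])           # (batch,)
timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask)
print(timestep)  # tensor([[  0.,   0., 800., 800.]]) -> image tokens at t=0
```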
