
Commit 6ab62c7

Authored by apolinario, github-actions[bot], and yiyixuxu
Add stochastic sampling to FlowMatchEulerDiscreteScheduler (#11369)
* Add stochastic sampling to FlowMatchEulerDiscreteScheduler

  This PR adds stochastic sampling to FlowMatchEulerDiscreteScheduler based on
  Lightricks/LTX-Video@b1aeddd ltx_video/schedulers/rf.py

* Apply style fixes

* Use config value directly

* Apply style fixes

* Swap order

* Update src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py

  Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py

  Co-authored-by: YiYi Xu <yixu310@gmail.com>

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
1 parent f59df3b commit 6ab62c7
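
The commit introduces a `stochastic_sampling` config flag on `FlowMatchEulerDiscreteScheduler` (see the diff below). As a minimal sketch of how the flag might be enabled, here the direct constructor call uses only the new kwarg, and the commented `from_config` override is the standard diffusers pattern rather than anything added by this commit; `pipe` is a hypothetical existing flow-matching pipeline:

```python
from diffusers import FlowMatchEulerDiscreteScheduler

# Instantiate directly with the new flag; all other config values keep their defaults.
scheduler = FlowMatchEulerDiscreteScheduler(stochastic_sampling=True)

# Or, assuming `pipe` is an existing pipeline that already uses this scheduler,
# rebuild the scheduler from its config with the flag switched on:
# scheduler = FlowMatchEulerDiscreteScheduler.from_config(
#     pipe.scheduler.config, stochastic_sampling=True
# )
# pipe.scheduler = scheduler
```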

1 file changed (+19, −4 lines)


src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py

Lines changed: 19 additions & 4 deletions
```diff
@@ -80,6 +80,8 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
             Whether to use beta sigmas for step sizes in the noise schedule during sampling.
         time_shift_type (`str`, defaults to "exponential"):
             The type of dynamic resolution-dependent timestep shifting to apply. Either "exponential" or "linear".
+        stochastic_sampling (`bool`, defaults to False):
+            Whether to use stochastic sampling.
     """
 
     _compatibles = []
@@ -101,6 +103,7 @@ def __init__(
         use_exponential_sigmas: Optional[bool] = False,
         use_beta_sigmas: Optional[bool] = False,
         time_shift_type: str = "exponential",
+        stochastic_sampling: bool = False,
     ):
         if self.config.use_beta_sigmas and not is_scipy_available():
             raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
@@ -437,13 +440,25 @@ def step(
             lower_mask = sigmas < per_token_sigmas[None] - 1e-6
             lower_sigmas = lower_mask * sigmas
             lower_sigmas, _ = lower_sigmas.max(dim=0)
-            dt = (per_token_sigmas - lower_sigmas)[..., None]
+
+            current_sigma = per_token_sigmas[..., None]
+            next_sigma = lower_sigmas[..., None]
+            dt = current_sigma - next_sigma
         else:
-            sigma = self.sigmas[self.step_index]
-            sigma_next = self.sigmas[self.step_index + 1]
+            sigma_idx = self.step_index
+            sigma = self.sigmas[sigma_idx]
+            sigma_next = self.sigmas[sigma_idx + 1]
+
+            current_sigma = sigma
+            next_sigma = sigma_next
             dt = sigma_next - sigma
 
-        prev_sample = sample + dt * model_output
+        if self.config.stochastic_sampling:
+            x0 = sample - current_sigma * model_output
+            noise = torch.randn_like(sample)
+            prev_sample = (1.0 - next_sigma) * x0 + next_sigma * noise
+        else:
+            prev_sample = sample + dt * model_output
 
         # upon completion increase step index by one
         self._step_index += 1
```
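
To make the two update rules concrete, here is a self-contained toy sketch of the branch added to `step`. The shapes and sigma values are made up for illustration; the formulas mirror the diff above:

```python
import torch

torch.manual_seed(0)
sample = torch.randn(1, 4)        # current noisy latents (toy shape)
model_output = torch.randn(1, 4)  # predicted flow-matching velocity
current_sigma, next_sigma = 0.8, 0.6
dt = next_sigma - current_sigma   # negative, since sigmas decrease during sampling

# Previous behavior: deterministic Euler step along the predicted velocity.
prev_deterministic = sample + dt * model_output

# New stochastic branch: estimate the clean sample from the velocity
# prediction, then re-noise it to the next sigma level (LTX-Video rf.py style).
x0 = sample - current_sigma * model_output
noise = torch.randn_like(sample)
prev_stochastic = (1.0 - next_sigma) * x0 + next_sigma * noise
```

Note that the stochastic branch draws fresh Gaussian noise at every step via `torch.randn_like`, so sampling is no longer deterministic for a fixed initial latent unless the global RNG state is controlled.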
