@@ -80,6 +80,8 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
80
80
Whether to use beta sigmas for step sizes in the noise schedule during sampling.
81
81
time_shift_type (`str`, defaults to "exponential"):
82
82
The type of dynamic resolution-dependent timestep shifting to apply. Either "exponential" or "linear".
83
+ stochastic_sampling (`bool`, defaults to False):
84
+ Whether to use stochastic sampling.
83
85
"""
84
86
85
87
_compatibles = []
@@ -101,6 +103,7 @@ def __init__(
101
103
use_exponential_sigmas : Optional [bool ] = False ,
102
104
use_beta_sigmas : Optional [bool ] = False ,
103
105
time_shift_type : str = "exponential" ,
106
+ stochastic_sampling : bool = False ,
104
107
):
105
108
if self .config .use_beta_sigmas and not is_scipy_available ():
106
109
raise ImportError ("Make sure to install scipy if you want to use beta sigmas." )
@@ -437,13 +440,25 @@ def step(
437
440
lower_mask = sigmas < per_token_sigmas [None ] - 1e-6
438
441
lower_sigmas = lower_mask * sigmas
439
442
lower_sigmas , _ = lower_sigmas .max (dim = 0 )
440
- dt = (per_token_sigmas - lower_sigmas )[..., None ]
443
+
444
+ current_sigma = per_token_sigmas [..., None ]
445
+ next_sigma = lower_sigmas [..., None ]
446
+ dt = current_sigma - next_sigma
441
447
else :
442
- sigma = self .sigmas [self .step_index ]
443
- sigma_next = self .sigmas [self .step_index + 1 ]
448
+ sigma_idx = self .step_index
449
+ sigma = self .sigmas [sigma_idx ]
450
+ sigma_next = self .sigmas [sigma_idx + 1 ]
451
+
452
+ current_sigma = sigma
453
+ next_sigma = sigma_next
444
454
dt = sigma_next - sigma
445
455
446
- prev_sample = sample + dt * model_output
456
+ if self .config .stochastic_sampling :
457
+ x0 = sample - current_sigma * model_output
458
+ noise = torch .randn_like (sample )
459
+ prev_sample = (1.0 - next_sigma ) * x0 + next_sigma * noise
460
+ else :
461
+ prev_sample = sample + dt * model_output
447
462
448
463
# upon completion increase step index by one
449
464
self ._step_index += 1
0 commit comments