Commit b61025a

Get Dream scheduler tests working
1 parent 1abd436 commit b61025a

1 file changed: +62 -37

src/diffusers/schedulers/scheduling_dream.py

@@ -21,7 +21,7 @@ def create_schedule(
         schedule = None
     elif isinstance(schedule_params, float):
         # Interpret as a constant schedule for all timesteps
-        schedule = torch.full(num_inference_steps, schedule_params)
+        schedule = torch.full((num_inference_steps,), schedule_params)
     elif isinstance(schedule_params, (tuple, list)):
        # Interpret first and second elems as start and end points of a linear schedule
        schedule = torch.linspace(schedule_params[0], schedule_params[1], num_inference_steps)
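
A note on the fix (not part of the commit; the 0.9 fill value is an arbitrary example): torch.full takes the target shape as a sequence of ints, so passing a bare int raises a TypeError:

    import torch

    num_inference_steps = 4
    # torch.full(num_inference_steps, 0.9) would raise a TypeError: size must be a sequence
    schedule = torch.full((num_inference_steps,), 0.9)
    print(schedule)  # tensor([0.9000, 0.9000, 0.9000, 0.9000])
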
@@ -79,7 +79,8 @@ def sample_tokens(
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Samples from a sequence of logits of shape [..., vocab_size] and returns both the sampled sequence (as the second
-    return elem) and the model probabilities for the chosen tokens (as the first return elem).
+    return elem) and the model probabilities for the chosen tokens (as the first return elem) with the same shape as
+    the leading (non-vocab-size) dims of logits.
     """
     # logits shape: [B, L, V]
     if temperature > 0:
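
For intuition about the documented contract, here is a stand-alone sketch (not the actual function body; shapes are illustrative):

    import torch

    B, L, V = 2, 8, 5
    logits = torch.randn(B, L, V)

    probs = torch.softmax(logits, dim=-1)                                # [B, L, V]
    tokens = torch.distributions.Categorical(probs=probs).sample()       # [B, L]
    chosen = torch.gather(probs, -1, tokens.unsqueeze(-1)).squeeze(-1)   # [B, L]
    # Both returned tensors match the leading (non-vocab) dims of logits.
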
@@ -112,8 +113,8 @@ def sample_tokens(
     if margin_confidence:
         sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
         # Extract top1 and top2 probabilities
-        top1_probs = sorted_probs[:, 0]
-        top2_probs = sorted_probs[:, 1]
+        top1_probs = sorted_probs[..., 0]
+        top2_probs = sorted_probs[..., 1]
         # Calculate confidence as top1 - top2
         confidence = top1_probs - top2_probs
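
Why the ellipsis indexing matters (toy numbers, not from the commit): with probabilities of shape [B, L, V], `sorted_probs[:, 0]` slices the sequence dimension, while `sorted_probs[..., 0]` slices the vocab dimension as intended:

    import torch

    # [B, L, V] = [1, 2, 3], already sorted descending along the vocab dim
    probs = torch.tensor([[[0.7, 0.2, 0.1],
                           [0.5, 0.3, 0.2]]])
    sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)

    top1 = sorted_probs[..., 0]   # shape [1, 2]: tensor([[0.7, 0.5]])
    top2 = sorted_probs[..., 1]   # shape [1, 2]: tensor([[0.2, 0.3]])
    confidence = top1 - top2      # tensor([[0.5, 0.2]])

    wrong = sorted_probs[:, 0]    # shape [1, 3]: the first sequence position, not per-token top-1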

@@ -225,6 +226,37 @@ def __init__(
 
         self.init_noise_sigma = 1.0
 
+    def t_to_alpha(
+        self,
+        timesteps: Optional[torch.Tensor] = None,
+        masking_schedule: Optional[str] = None,
+    ) -> torch.Tensor:
+        """
+        Calculates the masking (alpha) schedule as a function of the provided timesteps. The timesteps do not
+        necessarily have to match those of `self.timesteps`.
+        """
+        if timesteps is None:
+            if self.timesteps is not None:
+                timesteps = self.timesteps
+            else:
+                raise ValueError("Since `self.timesteps` is not set, `timesteps` cannot also be `None`")
+        if masking_schedule is None:
+            masking_schedule = self.config.masking_schedule
+
+        if masking_schedule == "linear":
+            alphas = 1.0 - timesteps
+        elif masking_schedule == "cosine":
+            alphas = 1.0 - torch.cos((torch.pi / 2) * (1.0 - timesteps))
+        elif masking_schedule == "polynomial":
+            alphas = 1.0 - torch.pow(timesteps, self.config.polynomial_exp)
+        else:
+            raise ValueError(
+                f"{masking_schedule} is not a supported masking schedule. Currently supported schedules "
+                f"are `linear`, `cosine`, and `polynomial`."
+            )
+
+        return alphas
+
     def index_for_timestep(self, timestep, schedule_timesteps=None):
         if schedule_timesteps is None:
             schedule_timesteps = self.timesteps
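
To see what the three schedules produce, a hand-worked example (polynomial_exp=2 is an assumed config value), evaluating alpha at t = 1.0, 0.5, 0.0, where t=1 is fully masked and t=0 fully unmasked:

    import torch

    t = torch.tensor([1.0, 0.5, 0.0])

    linear = 1.0 - t                                      # tensor([0.0000, 0.5000, 1.0000])
    cosine = 1.0 - torch.cos((torch.pi / 2) * (1.0 - t))  # tensor([0.0000, 0.2929, 1.0000])
    poly = 1.0 - torch.pow(t, 2.0)                        # tensor([0.0000, 0.7500, 1.0000])

All three agree at the endpoints (alpha=0 at t=1, alpha=1 at t=0) and differ only in how quickly tokens unmask in between.
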
@@ -283,10 +315,12 @@ def set_timesteps(
                 The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
         """
 
+        self.num_inference_steps = num_inference_steps
+
         if self.config.timestep_discretization == "linear":
             timesteps = torch.linspace(1.0, self.config.final_timestep, num_inference_steps + 1, device=device)
         elif self.config.timestep_discretization == "cosine":
-            timesteps = torch.linspace(self.config.final_timestep, 1.0, num_inference_steps + 1)
+            timesteps = torch.linspace(1.0, self.config.final_timestep, num_inference_steps + 1)
             timesteps = torch.cos((torch.pi / 2) * (1.0 - timesteps)).to(device)
         else:
             raise ValueError(
@@ -296,18 +330,7 @@ def set_timesteps(
         self.timesteps = timesteps
 
         # Now calculate the masking or noise schedule (alpha) values at the chosen timestep discretization
-        if self.config.masking_schedule == "linear":
-            alphas = 1.0 - self.timesteps
-        elif self.config.masking_schedule == "cosine":
-            alphas = 1.0 - torch.cos((torch.pi / 2) * (1.0 - self.timesteps))
-        elif self.config.masking_schedule == "polynomial":
-            alphas = 1.0 - torch.pow(self.timesteps, self.config.polynomial_exp)
-        else:
-            raise ValueError(
-                f"{self.config.masking_schedule} is not a supported masking schedule. Currently supported schedules "
-                f"are `linear`, `cosine`, and `polynomial`."
-            )
-        self.alphas = alphas.to(device=device)
+        self.alphas = self.t_to_alpha(timesteps=self.timesteps).to(device=device)
 
         # Allow overriding of specific sampling parameters (temperature, top_p, etc.)
         if temperature is None:
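
The direction of the linspace is the substance of the cosine fix above (toy values; final_timestep=1e-3 is an assumed config value): both branches should produce timesteps that decrease from 1.0, so the reverse process runs toward t=0:

    import torch

    num_inference_steps, final_timestep = 4, 1e-3

    # Post-fix cosine discretization: start from the same decreasing grid as the linear branch
    t = torch.linspace(1.0, final_timestep, num_inference_steps + 1)
    cosine_t = torch.cos((torch.pi / 2) * (1.0 - t))
    # cosine_t decreases from 1.0 toward ~0; with the pre-fix ascending linspace it
    # would have increased, stepping the sampler in the wrong direction.
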
@@ -364,7 +387,7 @@ def step(
         # Right shift the logits from the model
         # Dream models are trained to predict at right-shifted positions, analogous to an autoregressive model,
         # so we also need to shift the inputs at inference time
-        model_output = torch.cat(model_output[:, :1], model_output[:, :-1], dim=1)
+        model_output = torch.cat([model_output[:, :1], model_output[:, :-1]], dim=1)
 
         # Probability of unmasking each token at time t
         unmask_prob = (self.alphas[step_idx + 1] - self.alphas[step_idx]) / (1 - self.alphas[step_idx])
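
A hand-worked sanity check of the unmask probability (toy schedule, not from the commit): with alphas giving the fraction of tokens revealed at each step, the conditional probability that a still-masked token unmasks between steps i and i+1 is (alphas[i+1] - alphas[i]) / (1 - alphas[i]):

    import torch

    alphas = torch.tensor([0.0, 0.5, 1.0])  # toy two-step linear schedule

    for i in range(2):
        p = (alphas[i + 1] - alphas[i]) / (1 - alphas[i])
        print(p.item())
    # Step 0 -> 1: 0.5  (half of the masked tokens are revealed)
    # Step 1 -> 2: 1.0  (every remaining masked token is revealed at the final step)
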
@@ -418,6 +441,9 @@ def step(
 
             prev_sample = torch.zeros_like(sample, device=sample.device)
             prev_sample = torch.where(unmask_index, pred_original_sample, sample)
+        else:
+            # No tokens to unmask, so the sample should stay the same
+            prev_sample = sample
 
         # TODO: do we need to shift the tokens again at the end???
         if not return_dict:
@@ -428,35 +454,34 @@ def step(
     def add_noise(
         self,
         original_samples: torch.Tensor,
+        noise: Optional[torch.Tensor],
         timesteps: torch.Tensor,
         generator: Optional[torch.Generator] = None,
     ) -> torch.Tensor:
-        # For each batch instance i with timestep t_i, mask each position independently with prob 1 - alphas[t_i]
-        # original_samples shape: [B, L]
-        # Make sure alphas and timesteps have the same device and dtype as original_samples
-        alphas = self.alphas.to(device=original_samples.device, dtype=original_samples.dtype)
-        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
-            # mps does not support float64
-            schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
-            timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
-        else:
-            schedule_timesteps = self.timesteps.to(original_samples.device)
-            timesteps = timesteps.to(original_samples.device)
+        """
+        Applies a masked (discrete) diffusion forward process where batch instance i with timestep t_i is masked at
+        each position independently with probability 1 - alpha(t_i). Any (continuous) time in [0, 1] can be used, not
+        just those timesteps that would be in `self.timesteps`, as masked diffusion models are usually trained in
+        continuous time.
 
-        step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
+        Note that `noise` here should be drawn from a uniform distribution over [0, 1], rather than a Gaussian
+        distribution as is normal for continuous-space diffusion models.
+        """
+        # original_samples shape: [B, L]
+        alphas = self.t_to_alpha(timesteps=timesteps)
+        alphas = alphas.to(device=original_samples.device, dtype=original_samples.dtype)
 
-        mask_probs = 1.0 - alphas[step_indices].flatten()
-        while len(mask_probs).shape < len(original_samples.shape):
-            mask_probs.unsqueeze(-1)
+        mask_probs = 1.0 - alphas
+        while len(mask_probs.shape) < len(original_samples.shape):
+            mask_probs = mask_probs.unsqueeze(-1)
 
-        mask_indices = (
-            torch.rand(
+        if noise is None:
+            noise = torch.rand(
                 original_samples.shape,
                 device=generator.device if generator is not None else original_samples.device,
                 generator=generator,
             ).to(original_samples.device)
-            < mask_probs
-        )
+        mask_indices = noise < mask_probs
 
         masked_samples = original_samples.clone()
         masked_samples[mask_indices] = self.config.mask_token_id
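
A hypothetical usage of the new signature (class and config names assumed, not confirmed by this commit): supplying external uniform noise makes the masking reproducible across calls:

    import torch

    scheduler = DreamScheduler(mask_token_id=0)  # hypothetical class name and config
    tokens = torch.randint(1, 1000, (2, 16))     # [B, L] token ids
    timesteps = torch.tensor([0.25, 0.75])       # one continuous time per batch instance
    noise = torch.rand(tokens.shape)             # uniform over [0, 1), one draw per position

    masked = scheduler.add_noise(tokens, noise=noise, timesteps=timesteps)
    # Row 0 is masked with prob 1 - alpha(0.25); row 1 with prob 1 - alpha(0.75).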
