Commit 879b41e

Change Dream scheduler so that self.timesteps has length num_inference_steps and fix shape errors
1 parent c78d823 commit 879b41e

1 file changed (+16 −11 lines)

src/diffusers/schedulers/scheduling_dream.py

Lines changed: 16 additions & 11 deletions
```diff
@@ -318,19 +318,21 @@ def set_timesteps(
         self.num_inference_steps = num_inference_steps
 
         if self.config.timestep_discretization == "linear":
-            timesteps = torch.linspace(1.0, self.config.final_timestep, num_inference_steps + 1, device=device)
+            timesteps = torch.linspace(1.0, self.config.final_timestep, num_inference_steps + 1)
         elif self.config.timestep_discretization == "cosine":
             timesteps = torch.linspace(1.0, self.config.final_timestep, num_inference_steps + 1)
-            timesteps = torch.cos((torch.pi / 2) * (1.0 - timesteps)).to(device)
+            timesteps = torch.cos((torch.pi / 2) * (1.0 - timesteps))
         else:
             raise ValueError(
                 f"{self.config.timestep_discretization} is not a supported timestep discretization strategy. Current "
                 f"supported strategies are `linear` and `cosine`."
             )
-        self.timesteps = timesteps
+        # Omit the final timestep so that len(self.timesteps) == num_inference_steps
+        self.timesteps = timesteps[:-1].to(device=device)
 
         # Now calculate the masking or noise schedule (alpha) values at the chosen timestep discretization
-        self.alphas = self.t_to_alpha(timesteps=self.timesteps).to(device=device)
+        # NOTE: the masking/alpha schedule is one element longer than self.timesteps (len num_inference_steps + 1)
+        self.alphas = self.t_to_alpha(timesteps=timesteps).to(device=device)
 
         # Allow overriding of specific sampling parameters (temperature, top_p, etc.)
         if temperature is None:
```
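The two new comments pin down the length contract this hunk establishes: `self.timesteps` now holds exactly `num_inference_steps` entries, while `self.alphas` is evaluated on the full grid and keeps `num_inference_steps + 1`. A minimal sketch of that invariant, using the "linear" branch only and an assumed stand-in value for `config.final_timestep` (not code from this commit):

```python
import torch

num_inference_steps = 10
final_timestep = 0.001  # assumed stand-in for self.config.final_timestep

# Mirror of the "linear" branch in the hunk above
timesteps = torch.linspace(1.0, final_timestep, num_inference_steps + 1)
scheduler_timesteps = timesteps[:-1]  # what set_timesteps now stores as self.timesteps

assert len(scheduler_timesteps) == num_inference_steps
assert len(timesteps) == num_inference_steps + 1  # the grid self.alphas is computed on
# step() can therefore read alphas[step_idx] and alphas[step_idx + 1] at every step.
```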
```diff
@@ -375,8 +377,6 @@ def step(
         # model_output shape: [B, L, V]
         # sample shape: [B, L] (sequence of discrete tokens)
         step_idx = self.index_for_timestep(timestep)
-        t = self.timesteps[step_idx]  # Current timestep
-        s = self.timesteps[step_idx + 1]  # Previous timestep (next-largest timestep not yet processed)
         temperature = self.temperatures[step_idx] if self.temperatures is not None else 1.0
         top_p = self.top_p_schedule[step_idx] if self.top_p_schedule is not None else None
         top_k = self.top_k_schedule[step_idx] if self.top_k_schedule is not None else None
```
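Dropping the `t`/`s` locals follows from the new contract: on the final step, `self.timesteps[step_idx + 1]` would index past the end of the shortened tensor. The commit does not show what `step()` uses instead, but per the NOTE in the first hunk, the one-element-longer `self.alphas` still supports `alphas[step_idx + 1]`. A hedged illustration with assumed values:

```python
import torch

timesteps = torch.linspace(1.0, 0.001, 11)[:-1]  # len 10, as set_timesteps now builds it
step_idx = 9  # final denoising step
t = timesteps[step_idx]  # fine: the current timestep
# timesteps[step_idx + 1]  # IndexError: index 10 is out of bounds for dimension 0 with size 10
```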
```diff
@@ -433,14 +433,19 @@ def step(
 
         if num_tokens_to_unmask > 0:
             if self.config.alg_temperature is None or self.config.alg_temperature == 0:
-                _, unmask_index = torch.topk(full_confidence, num_tokens_to_unmask)
+                _, unmask_indices = torch.topk(full_confidence, num_tokens_to_unmask)
             else:
                 full_confidence = full_confidence / self.config.alg_temperature
                 full_confidence = F.softmax(full_confidence, dim=-1)
-                unmask_index = torch.multinomial(full_confidence, num_samples=num_tokens_to_unmask)
-
-            prev_sample = torch.zeros_like(sample, device=sample.device)
-            prev_sample = torch.where(unmask_index, pred_original_sample, sample)
+                unmask_indices = torch.multinomial(
+                    full_confidence, num_samples=num_tokens_to_unmask, generator=generator
+                )
+            unmask_indices = unmask_indices.to(sample.device)
+
+            row_indices = torch.arange(sample.size(0), device=sample.device).unsqueeze(1).expand_as(unmask_indices)
+            prev_sample = sample.clone()
+            # Unmask at the chosen indices with values from the pred_original_sample
+            prev_sample[row_indices, unmask_indices] = pred_original_sample[row_indices, unmask_indices]
         else:
             # No tokens to unmask, so the sample should stay the same
             prev_sample = sample
```
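The rewritten unmasking block is the shape fix named in the commit message: `torch.topk` and `torch.multinomial` return integer index tensors of shape `[B, k]`, not a boolean `[B, L]` mask, so the old `torch.where(unmask_index, pred_original_sample, sample)` could not select the intended positions (and the preceding `torch.zeros_like` was immediately overwritten). Advanced indexing with an explicit row index writes only the chosen token positions. A self-contained sketch with made-up shapes (not from the commit):

```python
import torch

B, L, k = 2, 6, 2
sample = torch.full((B, L), -1)  # -1 stands in for the mask token id
pred_original_sample = torch.arange(B * L).reshape(B, L)
full_confidence = torch.rand(B, L)

_, unmask_indices = torch.topk(full_confidence, k)  # integer indices, shape [B, k]
row_indices = torch.arange(B).unsqueeze(1).expand_as(unmask_indices)

prev_sample = sample.clone()
prev_sample[row_indices, unmask_indices] = pred_original_sample[row_indices, unmask_indices]
assert (prev_sample != -1).sum() == B * k  # exactly k tokens unmasked per row
```

The two smaller changes in this hunk follow the same direction: passing `generator=generator` to `torch.multinomial` makes the stochastic unmasking reproducible, and moving `unmask_indices` to `sample.device` keeps the indexing tensors on the same device as the tensors they index.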
