@@ -21,7 +21,7 @@ def create_schedule(
         schedule = None
     elif isinstance(schedule_params, float):
         # Interpret as a constant schedule for all timesteps
-        schedule = torch.full(num_inference_steps, schedule_params)
+        schedule = torch.full((num_inference_steps,), schedule_params)
     elif isinstance(schedule_params, (tuple, list)):
         # Interpret first and second elems as start and end points of a linear schedule
         schedule = torch.linspace(schedule_params[0], schedule_params[1], num_inference_steps)
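For context on the `torch.full` fix above: `torch.full` expects its size argument as a sequence of ints, so passing a bare `num_inference_steps` raises a `TypeError`. A minimal sketch (the constant value here is hypothetical):

```python
import torch

num_inference_steps = 8
constant_value = 0.9  # hypothetical constant schedule value

# torch.full takes the size as a tuple/list; torch.full(8, 0.9) would raise a TypeError
schedule = torch.full((num_inference_steps,), constant_value)
print(schedule.shape)  # torch.Size([8])
```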
@@ -79,7 +79,8 @@ def sample_tokens(
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Samples from a sequence of logits of shape [..., vocab_size] and returns both the sampled sequence (as the second
-    return elem) and the model probabilities for the chosen tokens (as the first return elem).
+    return elem) and the model probabilities for the chosen tokens (as the first return elem) with the same shape as
+    the leading (non-vocab-size) dims of logits.
     """
     # logits shape: [B, L, V]
     if temperature > 0:
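As a hedged illustration of the clarified return contract (not the actual `sample_tokens` body), a greedy-sampling sketch showing that both return values share the leading `[B, L]` dims of the logits:

```python
import torch

logits = torch.randn(2, 5, 32)  # hypothetical [B, L, V] logits

probs = torch.softmax(logits, dim=-1)
tokens = probs.argmax(dim=-1)  # sampled sequence, shape [B, L]
token_probs = torch.gather(probs, -1, tokens.unsqueeze(-1)).squeeze(-1)  # chosen-token probs, shape [B, L]

print(token_probs.shape, tokens.shape)  # torch.Size([2, 5]) torch.Size([2, 5])
```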
@@ -112,8 +113,8 @@ def sample_tokens(
     if margin_confidence:
         sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
         # Extract top1 and top2 probabilities
-        top1_probs = sorted_probs[:, 0]
-        top2_probs = sorted_probs[:, 1]
+        top1_probs = sorted_probs[..., 0]
+        top2_probs = sorted_probs[..., 1]
         # Calculate confidence as top1 - top2
         confidence = top1_probs - top2_probs

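The switch from `[:, k]` to `[..., k]` matters once `probs` is 3-D: with `[B, L, V]` inputs, `[:, 0]` slices the sequence axis rather than the vocab axis. A quick shape check:

```python
import torch

probs = torch.softmax(torch.randn(2, 5, 32), dim=-1)         # [B, L, V]
sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)

print(sorted_probs[:, 0].shape)    # torch.Size([2, 32]) -- slices the sequence dim, not what we want
print(sorted_probs[..., 0].shape)  # torch.Size([2, 5])  -- top-1 probability per position
```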
@@ -225,6 +226,37 @@ def __init__(

         self.init_noise_sigma = 1.0

+    def t_to_alpha(
+        self,
+        timesteps: Optional[torch.Tensor] = None,
+        masking_schedule: Optional[str] = None,
+    ) -> torch.Tensor:
+        """
+        Calculates the masking (alpha) schedule as a function of the provided timesteps. The timesteps do not
+        necessarily have to match those of self.timesteps.
+        """
+        if timesteps is None:
+            if self.timesteps is not None:
+                timesteps = self.timesteps
+            else:
+                raise ValueError("Since `self.timesteps` is not set, `timesteps` cannot also be `None`")
+        if masking_schedule is None:
+            masking_schedule = self.config.masking_schedule
+
+        if masking_schedule == "linear":
+            alphas = 1.0 - timesteps
+        elif masking_schedule == "cosine":
+            alphas = 1.0 - torch.cos((torch.pi / 2) * (1.0 - timesteps))
+        elif masking_schedule == "polynomial":
+            alphas = 1.0 - torch.pow(timesteps, self.config.polynomial_exp)
+        else:
+            raise ValueError(
+                f"{masking_schedule} is not a supported masking schedule. Currently supported schedules "
+                f"are `linear`, `cosine`, and `polynomial`."
+            )
+
+        return alphas
+
     def index_for_timestep(self, timestep, schedule_timesteps=None):
         if schedule_timesteps is None:
             schedule_timesteps = self.timesteps
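For reference, the three masking schedules implemented by `t_to_alpha` can be sanity-checked standalone; all of them map t=1 (fully masked) to alpha=0 and t=0 (clean data) to alpha=1. A small sketch, assuming a polynomial exponent of 2.0:

```python
import torch

t = torch.linspace(1.0, 0.0, 5)  # example timesteps from fully masked (t=1) to clean data (t=0)

alpha_linear = 1.0 - t
alpha_cosine = 1.0 - torch.cos((torch.pi / 2) * (1.0 - t))
alpha_poly = 1.0 - torch.pow(t, 2.0)  # polynomial_exp assumed to be 2.0 here

# All three satisfy alpha(t=1) = 0 (everything masked) and alpha(t=0) = 1 (nothing masked)
print(alpha_linear[0].item(), alpha_linear[-1].item())  # 0.0 1.0
print(alpha_cosine[0].item(), alpha_cosine[-1].item())  # ~0.0 ~1.0
```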
@@ -283,10 +315,12 @@ def set_timesteps(
                 The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
         """

+        self.num_inference_steps = num_inference_steps
+
         if self.config.timestep_discretization == "linear":
             timesteps = torch.linspace(1.0, self.config.final_timestep, num_inference_steps + 1, device=device)
         elif self.config.timestep_discretization == "cosine":
-            timesteps = torch.linspace(self.config.final_timestep, 1.0, num_inference_steps + 1)
+            timesteps = torch.linspace(1.0, self.config.final_timestep, num_inference_steps + 1)
             timesteps = torch.cos((torch.pi / 2) * (1.0 - timesteps)).to(device)
         else:
             raise ValueError(
@@ -296,18 +330,7 @@ def set_timesteps(
         self.timesteps = timesteps

         # Now calculate the masking or noise schedule (alpha) values at the chosen timestep discretization
-        if self.config.masking_schedule == "linear":
-            alphas = 1.0 - self.timesteps
-        elif self.config.masking_schedule == "cosine":
-            alphas = 1.0 - torch.cos((torch.pi / 2) * (1.0 - self.timesteps))
-        elif self.config.masking_schedule == "polynomial":
-            alphas = 1.0 - torch.pow(self.timesteps, self.config.polynomial_exp)
-        else:
-            raise ValueError(
-                f"{self.config.masking_schedule} is not a supported masking schedule. Currently supported schedules "
-                f"are `linear`, `cosine`, and `polynomial`."
-            )
-        self.alphas = alphas.to(device=device)
+        self.alphas = self.t_to_alpha(timesteps=self.timesteps).to(device=device)

         # Allow overriding of specific sampling parameters (temperature, top_p, etc.)
         if temperature is None:
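A note on the cosine-discretization fix: both branches now build the raw grid from 1.0 down to `final_timestep`, so after the cosine warp the timesteps still decrease monotonically from ~1 toward ~0, matching the linear branch. A quick sketch with hypothetical values:

```python
import torch

num_inference_steps = 4
final_timestep = 1e-3  # hypothetical config value

t = torch.linspace(1.0, final_timestep, num_inference_steps + 1)  # decreasing raw grid
t_cosine = torch.cos((torch.pi / 2) * (1.0 - t))                  # warped grid, still decreasing

print(t[0].item(), t[-1].item())                # 1.0 0.001
print(t_cosine[0].item(), t_cosine[-1].item())  # 1.0 ~0.0016
```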
@@ -364,7 +387,7 @@ def step(
         # Right shift the logits from the model
         # Dream models are trained to predict at right-shifted positions, analogous to an autoregressive model,
         # so we also need to shift the inputs at inference time
-        model_output = torch.cat(model_output[:, :1], model_output[:, :-1], dim=1)
+        model_output = torch.cat([model_output[:, :1], model_output[:, :-1]], dim=1)

         # Probability of unmasking each token at time t
         unmask_prob = (self.alphas[step_idx + 1] - self.alphas[step_idx]) / (1 - self.alphas[step_idx])
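The `unmask_prob` expression is the conditional probability of revealing a still-masked token at this step, computed from consecutive alpha values. Worked through with a linear schedule over four steps (values below are illustrative):

```python
import torch

# alphas evaluated at 5 decreasing timesteps for a linear schedule (num_inference_steps = 4)
alphas = torch.tensor([0.0, 0.25, 0.5, 0.75, 1.0])

for step_idx in range(4):
    unmask_prob = (alphas[step_idx + 1] - alphas[step_idx]) / (1 - alphas[step_idx])
    print(step_idx, round(unmask_prob.item(), 3))
# 0 0.25 -> 1 0.333 -> 2 0.5 -> 3 1.0: by the last step every remaining masked token is revealed
```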
@@ -418,6 +441,9 @@ def step(

             prev_sample = torch.zeros_like(sample, device=sample.device)
             prev_sample = torch.where(unmask_index, pred_original_sample, sample)
+        else:
+            # No tokens to unmask, so the sample should stay the same
+            prev_sample = sample

         # TODO: do we need to shift the tokens again at the end???
         if not return_dict:
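On the new `else` branch: when no positions are selected for unmasking at a step, `prev_sample` would otherwise be left undefined, so the sample is carried through unchanged. The `torch.where` path itself only fills in the selected masked positions, e.g. (token ids below are hypothetical):

```python
import torch

mask_token_id = 7  # hypothetical mask token id
sample = torch.tensor([[7, 7, 5, 7]])                # current partially masked sample
pred_original_sample = torch.tensor([[3, 9, 2, 4]])  # model's predicted clean tokens
unmask_index = torch.tensor([[True, False, False, True]])

prev_sample = torch.where(unmask_index, pred_original_sample, sample)
print(prev_sample)  # tensor([[3, 7, 5, 4]]) -- only the selected positions are filled in
```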
@@ -428,35 +454,34 @@ def add_noise(
     def add_noise(
         self,
         original_samples: torch.Tensor,
+        noise: Optional[torch.Tensor],
         timesteps: torch.Tensor,
         generator: Optional[torch.Generator] = None,
     ) -> torch.Tensor:
-        # For each batch instance i with timestep t_i, mask each position independently with prob 1 - alphas[t_i]
-        # original_samples shape: [B, L]
-        # Make sure alphas and timesteps have the same device and dtype as original_samples
-        alphas = self.alphas.to(device=original_samples.device, dtype=original_samples.dtype)
-        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
-            # mps does not support float64
-            schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
-            timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
-        else:
-            schedule_timesteps = self.timesteps.to(original_samples.device)
-            timesteps = timesteps.to(original_samples.device)
+        """
+        Applies a masked (discrete) diffusion forward process where batch instance i with timestep t_i is masked at
+        each position independently with probability 1 - alpha(t_i). Any (continuous) time in [0, 1] can be used, not
+        just those timesteps that would be in self.timesteps, as masked diffusion models are usually trained in
+        continuous time.

-        step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
+        Note that `noise` here should be drawn from a uniform distribution over [0, 1], rather than a Gaussian
+        distribution as is normal for continuous-space diffusion models.
+        """
+        # original_samples shape: [B, L]
+        alphas = self.t_to_alpha(timesteps=timesteps)
+        alphas = alphas.to(device=original_samples.device, dtype=original_samples.dtype)

-        mask_probs = 1.0 - alphas[step_indices].flatten()
-        while len(mask_probs).shape < len(original_samples.shape):
-            mask_probs.unsqueeze(-1)
+        mask_probs = 1.0 - alphas
+        while len(mask_probs.shape) < len(original_samples.shape):
+            mask_probs = mask_probs.unsqueeze(-1)

-        mask_indices = (
-            torch.rand(
+        if noise is None:
+            noise = torch.rand(
                 original_samples.shape,
                 device=generator.device if generator is not None else original_samples.device,
                 generator=generator,
             ).to(original_samples.device)
-            < mask_probs
-        )
+        mask_indices = noise < mask_probs

         masked_samples = original_samples.clone()
         masked_samples[mask_indices] = self.config.mask_token_id
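Finally, a standalone sketch of the forward masking process that `add_noise` now implements with the optional uniform `noise` argument (mask token id and timesteps are hypothetical, and a linear schedule is assumed for `t_to_alpha`):

```python
import torch

mask_token_id = 103                               # hypothetical mask token id
original_samples = torch.randint(5, 100, (2, 6))  # [B, L] token ids
timesteps = torch.tensor([0.9, 0.3])              # per-instance continuous times in [0, 1]

alphas = 1.0 - timesteps                          # linear schedule, as in t_to_alpha
mask_probs = (1.0 - alphas).unsqueeze(-1)         # [B, 1], broadcast over the sequence length

noise = torch.rand(original_samples.shape)        # uniform noise in [0, 1), not Gaussian
masked_samples = original_samples.clone()
masked_samples[noise < mask_probs] = mask_token_id

# Instance 0 (t=0.9) ends up mostly masked; instance 1 (t=0.3) keeps most of its tokens
print(masked_samples)
```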