@@ -318,19 +318,21 @@ def set_timesteps(
         self.num_inference_steps = num_inference_steps
 
         if self.config.timestep_discretization == "linear":
-            timesteps = torch.linspace(1.0, self.config.final_timestep, num_inference_steps + 1, device=device)
+            timesteps = torch.linspace(1.0, self.config.final_timestep, num_inference_steps + 1)
         elif self.config.timestep_discretization == "cosine":
             timesteps = torch.linspace(1.0, self.config.final_timestep, num_inference_steps + 1)
-            timesteps = torch.cos((torch.pi / 2) * (1.0 - timesteps)).to(device)
+            timesteps = torch.cos((torch.pi / 2) * (1.0 - timesteps))
         else:
             raise ValueError(
                 f"{self.config.timestep_discretization} is not a supported timestep discretization strategy. Current "
                 f"supported strategies are `linear` and `cosine`."
             )
-        self.timesteps = timesteps
+        # Omit the final timestep so that len(self.timesteps) == num_inference_steps
+        self.timesteps = timesteps[:-1].to(device=device)
 
         # Now calculate the masking or noise schedule (alpha) values at the chosen timestep discretization
-        self.alphas = self.t_to_alpha(timesteps=self.timesteps).to(device=device)
+        # NOTE: the masking/alpha schedule is one element longer than self.timesteps (len num_inference_steps + 1)
+        self.alphas = self.t_to_alpha(timesteps=timesteps).to(device=device)
 
         # Allow overriding of specific sampling parameters (temperature, top_p, etc.)
         if temperature is None:
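
The two new comments above pin down an off-by-one contract: `self.timesteps` drops the final boundary value, while `self.alphas` is computed over all `num_inference_steps + 1` boundaries, presumably so a step at index `i` can read both `alphas[i]` and `alphas[i + 1]`. A minimal sketch of the resulting lengths, using `1.0 - t` as a stand-in for the scheduler's real `t_to_alpha` (that linear alpha schedule is an assumption for illustration):

    import torch

    num_inference_steps = 4
    final_timestep = 1e-3  # assumed config value

    # Mirrors the "linear" branch of the hunk above
    timesteps = torch.linspace(1.0, final_timestep, num_inference_steps + 1)
    scheduler_timesteps = timesteps[:-1]  # len == num_inference_steps
    alphas = 1.0 - timesteps              # stand-in for t_to_alpha(timesteps=timesteps)

    assert len(scheduler_timesteps) == num_inference_steps
    assert len(alphas) == num_inference_steps + 1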
@@ -375,8 +377,6 @@ def step(
         # model_output shape: [B, L, V]
         # sample shape: [B, L] (sequence of discrete tokens)
         step_idx = self.index_for_timestep(timestep)
-        t = self.timesteps[step_idx]  # Current timestep
-        s = self.timesteps[step_idx + 1]  # Previous timestep (next-largest timestep not yet processed)
         temperature = self.temperatures[step_idx] if self.temperatures is not None else 1.0
         top_p = self.top_p_schedule[step_idx] if self.top_p_schedule is not None else None
         top_k = self.top_k_schedule[step_idx] if self.top_k_schedule is not None else None
@@ -433,14 +433,19 @@ def step(
 
         if num_tokens_to_unmask > 0:
             if self.config.alg_temperature is None or self.config.alg_temperature == 0:
-                _, unmask_index = torch.topk(full_confidence, num_tokens_to_unmask)
+                _, unmask_indices = torch.topk(full_confidence, num_tokens_to_unmask)
             else:
                 full_confidence = full_confidence / self.config.alg_temperature
                 full_confidence = F.softmax(full_confidence, dim=-1)
-                unmask_index = torch.multinomial(full_confidence, num_samples=num_tokens_to_unmask)
-
-            prev_sample = torch.zeros_like(sample, device=sample.device)
-            prev_sample = torch.where(unmask_index, pred_original_sample, sample)
+                unmask_indices = torch.multinomial(
+                    full_confidence, num_samples=num_tokens_to_unmask, generator=generator
+                )
+            unmask_indices = unmask_indices.to(sample.device)
+
+            row_indices = torch.arange(sample.size(0), device=sample.device).unsqueeze(1).expand_as(unmask_indices)
+            prev_sample = sample.clone()
+            # Unmask at the chosen indices with values from the pred_original_sample
+            prev_sample[row_indices, unmask_indices] = pred_original_sample[row_indices, unmask_indices]
         else:
             # No tokens to unmask, so the sample should stay the same
             prev_sample = sample
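
The rewritten unmasking fixes a real bug: the old code passed index positions from `torch.topk`/`torch.multinomial` to `torch.where`, which expects a boolean mask. The replacement uses advanced indexing instead: `unmask_indices` holds per-row column positions, so pairing it with a broadcast row index writes exactly `num_tokens_to_unmask` predicted tokens into each sequence while leaving all other positions untouched. It also threads `generator` through `torch.multinomial` for reproducible sampling. A self-contained sketch of the same indexing technique, with toy tensors standing in for the scheduler's state:

    import torch

    sample = torch.zeros(2, 4, dtype=torch.long)  # [B, L], all positions still masked (id 0)
    pred_original_sample = torch.tensor([[5, 6, 7, 8],
                                         [9, 10, 11, 12]])
    full_confidence = torch.tensor([[0.1, 0.9, 0.3, 0.8],
                                    [0.7, 0.2, 0.6, 0.4]])
    num_tokens_to_unmask = 2

    # Greedy branch: take the most confident positions in each row
    _, unmask_indices = torch.topk(full_confidence, num_tokens_to_unmask)

    # Broadcast a matching row index so the write is per-sequence
    row_indices = torch.arange(sample.size(0)).unsqueeze(1).expand_as(unmask_indices)
    prev_sample = sample.clone()
    prev_sample[row_indices, unmask_indices] = pred_original_sample[row_indices, unmask_indices]
    # row 0 unmasks columns 1 and 3; row 1 unmasks columns 0 and 2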