@@ -99,10 +99,10 @@ def tokenize_prompt(
             return_overflowing_tokens=False,
             return_tensors="pt",
         )
-
+
         text_input_ids = text_inputs.input_ids.to(device=device)
         attention_mask = text_inputs.attention_mask.to(device=device)
-
+
         # duplicate text tokens and attention mask for each generation per prompt, using mps friendly method
         # TODO: this follows e.g. the Flux pipeline's encode_prompts, why do we repeat in the sequence length dim
         # rather than the batch length dim...?
@@ -113,7 +113,7 @@ def tokenize_prompt(
         attention_mask = attention_mask.view(batch_size * num_texts_per_prompt, -1)
 
         return text_input_ids, attention_mask
-
+
     def prepare_latents(
         self,
         batch_size: int,
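
The TODO in the first hunk asks why the tokens are repeated along the sequence-length dimension rather than the batch dimension. A minimal sketch with toy, hypothetical shapes shows that `repeat(1, n)` followed by `view(batch_size * n, -1)` produces the same interleaved result as `repeat_interleave(n, dim=0)`; the code comment refers to the repeat-and-view form as the "mps friendly method".

```python
# Sketch only: illustrates the repeat-then-view pattern used in tokenize_prompt
# with toy shapes; the names and sizes here are hypothetical.
import torch

batch_size, seq_len, num_texts_per_prompt = 2, 3, 2
text_input_ids = torch.arange(batch_size * seq_len).view(batch_size, seq_len)

# repeat along the sequence-length dim, then fold the copies back into the batch dim
repeated = text_input_ids.repeat(1, num_texts_per_prompt)       # shape (2, 6)
folded = repeated.view(batch_size * num_texts_per_prompt, -1)   # shape (4, 3)

# same result as interleaving copies directly along the batch dim
interleaved = text_input_ids.repeat_interleave(num_texts_per_prompt, dim=0)
assert torch.equal(folded, interleaved)
```
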
@@ -291,7 +291,7 @@ def __call__(
         padding_length = max_sequence_length - prompt_embeds.shape[1]
         if padding_length > 0:
             padding_mask_tokens = torch.full(
-                (total_batch_size, padding_length), self.scheduler.config.mask_token_id, device=device
+                (total_batch_size, padding_length), self.scheduler.config.mask_token_id, device=device
             )
             padding_mask_embedding = self.transformer.embed_tokens(padding_mask_tokens)
             prompt_embeds = torch.cat([prompt_embeds, padding_mask_embedding], dim=1)
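
For context, the hunk above pads `prompt_embeds` out to `max_sequence_length` by embedding extra mask tokens and concatenating them along the sequence dimension. Below is a self-contained sketch of that step with placeholder values; `embed_tokens`, `mask_token_id`, and the shapes are stand-ins, whereas the real pipeline uses `self.transformer.embed_tokens` and `self.scheduler.config.mask_token_id`.

```python
# Sketch only: mimics the padding step above with placeholder values.
import torch
import torch.nn as nn

total_batch_size, seq_len, hidden_size = 2, 5, 16
max_sequence_length = 8
mask_token_id = 0                              # placeholder; the pipeline reads this from the scheduler config
embed_tokens = nn.Embedding(10, hidden_size)   # stand-in for self.transformer.embed_tokens

prompt_embeds = torch.randn(total_batch_size, seq_len, hidden_size)

padding_length = max_sequence_length - prompt_embeds.shape[1]
if padding_length > 0:
    padding_mask_tokens = torch.full((total_batch_size, padding_length), mask_token_id, dtype=torch.long)
    padding_mask_embedding = embed_tokens(padding_mask_tokens)
    prompt_embeds = torch.cat([prompt_embeds, padding_mask_embedding], dim=1)

assert prompt_embeds.shape == (total_batch_size, max_sequence_length, hidden_size)
```
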