Skip to content

Commit c78d823

Browse files
committed
Fix bugs in DreamTextPipeline.__call__
1 parent cae9ddf commit c78d823

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/diffusers/pipelines/dream/pipeline_dream.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ def __call__(
272272
device=device,
273273
)
274274
else:
275-
text_ids, attention_mask = None
275+
text_ids, attention_mask = None, None
276276

277277
# 4. Prepare latent variables (e.g. the initial sample) for generation
278278
total_batch_size = batch_size * num_texts_per_prompt
@@ -284,7 +284,7 @@ def __call__(
284284
device=device,
285285
)
286286

287-
if prompt_embeds is not None:
287+
if prompt_embeds is None:
288288
prompt_embeds = self.transformer.embed_tokens(latents)
289289
else:
290290
# If prompt_embeds's seq len is not max_sequence_length, concat with embedding of mask tokens for the

0 commit comments

Comments
 (0)