Skip to content

Commit d40c86b

Browse files
committed
fix
1 parent 66bf7ea commit d40c86b

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

src/diffusers/pipelines/cogview4/pipeline_cogview4.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -362,10 +362,16 @@ def check_inputs(
362362
)
363363

364364
if prompt_embeds is not None and negative_prompt_embeds is not None:
365-
if prompt_embeds.shape != negative_prompt_embeds.shape:
365+
if prompt_embeds.shape[0] != negative_prompt_embeds.shape[0]:
366366
raise ValueError(
367-
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
368-
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
367+
"`prompt_embeds` and `negative_prompt_embeds` must have the same batch size when passed directly, but"
368+
f" got: `prompt_embeds` {prompt_embeds.shape} and `negative_prompt_embeds`"
369+
f" {negative_prompt_embeds.shape}."
370+
)
371+
if prompt_embeds.shape[-1] != negative_prompt_embeds.shape[-1]:
372+
raise ValueError(
373+
"`prompt_embeds` and `negative_prompt_embeds` must have the same dimension when passed directly, but"
374+
f" got: `prompt_embeds` {prompt_embeds.shape} and `negative_prompt_embeds`"
369375
f" {negative_prompt_embeds.shape}."
370376
)
371377

0 commit comments

Comments
 (0)