Skip to content

Commit 24c062a

Browse files
authored
update check_input for cogview4 (#10966)
fix
1 parent a74f02f commit 24c062a

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

src/diffusers/pipelines/cogview4/pipeline_cogview4.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -360,10 +360,16 @@ def check_inputs(
360360
)
361361

362362
if prompt_embeds is not None and negative_prompt_embeds is not None:
363-
if prompt_embeds.shape != negative_prompt_embeds.shape:
363+
if prompt_embeds.shape[0] != negative_prompt_embeds.shape[0]:
364364
raise ValueError(
365-
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
366-
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
365+
"`prompt_embeds` and `negative_prompt_embeds` must have the same batch size when passed directly, but"
366+
f" got: `prompt_embeds` {prompt_embeds.shape} and `negative_prompt_embeds`"
367+
f" {negative_prompt_embeds.shape}."
368+
)
369+
if prompt_embeds.shape[-1] != negative_prompt_embeds.shape[-1]:
370+
raise ValueError(
371+
"`prompt_embeds` and `negative_prompt_embeds` must have the same dimension when passed directly, but"
372+
f" got: `prompt_embeds` {prompt_embeds.shape} and `negative_prompt_embeds`"
367373
f" {negative_prompt_embeds.shape}."
368374
)
369375

0 commit comments

Comments
 (0)