Skip to content

Commit 9a8a65b

Browse files
committed
Optimize video display.
1 parent 4586f1e commit 9a8a65b

File tree

1 file changed

+50
-35
lines changed

1 file changed

+50
-35
lines changed

WebUI/webui_pages/dialogue/dialogue.py

Lines changed: 50 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -461,8 +461,9 @@ def on_feedback(
461461
if modelinfo["mtype"] == ModelType.Multimodal:
462462
if running_model == "stable-video-diffusion-img2vid" or running_model == "stable-video-diffusion-img2vid-xt":
463463
return_video = True
464+
464465
if return_video:
465-
chat_box.ai_say("Video generation in progress....")
466+
chat_box.ai_say("")
466467
else:
467468
chat_box.ai_say(["Thinking...", ""])
468469
text = ""
@@ -474,45 +475,59 @@ def on_feedback(
474475
prompt = generate_prompt_for_imagegen(imagegeneration_model, prompt, imageprompt)
475476
imagesprompt = []
476477
history = []
477-
r = api.chat_chat(prompt,
478-
imagesdata=imagesdata,
479-
audiosdata=audiosdata,
480-
videosdata=videosdata,
481-
imagesprompt=imagesprompt,
482-
history=history,
483-
model=running_model,
484-
speechmodel=speechmodel,
485-
prompt_name=prompt_template_name,
486-
temperature=temperature)
487-
for t in r:
488-
if error_msg := check_error_msg(t): # check whether error occured
489-
st.error(error_msg)
490-
break
491-
text += t.get("text", "")
492-
if return_video is False:
493-
chat_box.update_msg(text, element_index=0)
494-
chat_history_id = t.get("chat_history_id", "")
495-
496-
metadata = {
497-
"chat_history_id": chat_history_id,
498-
}
499-
if return_video is False:
500-
chat_box.update_msg(text, element_index=0, streaming=False, metadata=metadata)
501-
502-
if imagegeneration_model and modelinfo["mtype"] != ModelType.Code:
503-
with st.spinner(f"Image generation in progress...."):
504-
gen_image = api.get_image_generation_data(text)
505-
if gen_image:
506-
decoded_data = base64.b64decode(gen_image)
507-
gen_image=Image(BytesIO(decoded_data))
508-
chat_box.update_msg(gen_image, element_index=1, streaming=False)
509-
elif modelinfo["mtype"] == ModelType.Multimodal:
510-
if running_model == "stable-video-diffusion-img2vid" or running_model == "stable-video-diffusion-img2vid-xt":
478+
if return_video:
479+
with st.spinner(f"Video generation in progress...."):
480+
r = api.chat_chat(prompt,
481+
imagesdata=imagesdata,
482+
audiosdata=audiosdata,
483+
videosdata=videosdata,
484+
imagesprompt=imagesprompt,
485+
history=history,
486+
model=running_model,
487+
speechmodel=speechmodel,
488+
prompt_name=prompt_template_name,
489+
temperature=temperature)
490+
for t in r:
491+
if error_msg := check_error_msg(t): # check whether error occured
492+
st.error(error_msg)
493+
break
494+
text += t.get("text", "")
495+
chat_history_id = t.get("chat_history_id", "")
511496
print("video_path: ", text)
512497
with open(text, "rb") as f:
513498
video_bytes = f.read()
514499
gen_video=Video(BytesIO(video_bytes))
515500
chat_box.update_msg(gen_video, streaming=False)
501+
else:
502+
r = api.chat_chat(prompt,
503+
imagesdata=imagesdata,
504+
audiosdata=audiosdata,
505+
videosdata=videosdata,
506+
imagesprompt=imagesprompt,
507+
history=history,
508+
model=running_model,
509+
speechmodel=speechmodel,
510+
prompt_name=prompt_template_name,
511+
temperature=temperature)
512+
for t in r:
513+
if error_msg := check_error_msg(t): # check whether error occured
514+
st.error(error_msg)
515+
break
516+
text += t.get("text", "")
517+
chat_box.update_msg(text, element_index=0)
518+
chat_history_id = t.get("chat_history_id", "")
519+
520+
metadata = {
521+
"chat_history_id": chat_history_id,
522+
}
523+
chat_box.update_msg(text, element_index=0, streaming=False, metadata=metadata)
524+
if imagegeneration_model and modelinfo["mtype"] != ModelType.Code:
525+
with st.spinner(f"Image generation in progress...."):
526+
gen_image = api.get_image_generation_data(text)
527+
if gen_image:
528+
decoded_data = base64.b64decode(gen_image)
529+
gen_image=Image(BytesIO(decoded_data))
530+
chat_box.update_msg(gen_image, element_index=1, streaming=False)
516531
#print("chat_box.history: ", len(chat_box.history))
517532
chat_box.show_feedback(**feedback_kwargs,
518533
key=chat_history_id,

0 commit comments

Comments
 (0)