@@ -461,8 +461,9 @@ def on_feedback(
461
461
if modelinfo ["mtype" ] == ModelType .Multimodal :
462
462
if running_model == "stable-video-diffusion-img2vid" or running_model == "stable-video-diffusion-img2vid-xt" :
463
463
return_video = True
464
+
464
465
if return_video :
465
- chat_box .ai_say ("Video generation in progress.... " )
466
+ chat_box .ai_say ("" )
466
467
else :
467
468
chat_box .ai_say (["Thinking..." , "" ])
468
469
text = ""
@@ -474,45 +475,59 @@ def on_feedback(
474
475
prompt = generate_prompt_for_imagegen (imagegeneration_model , prompt , imageprompt )
475
476
imagesprompt = []
476
477
history = []
477
- r = api .chat_chat (prompt ,
478
- imagesdata = imagesdata ,
479
- audiosdata = audiosdata ,
480
- videosdata = videosdata ,
481
- imagesprompt = imagesprompt ,
482
- history = history ,
483
- model = running_model ,
484
- speechmodel = speechmodel ,
485
- prompt_name = prompt_template_name ,
486
- temperature = temperature )
487
- for t in r :
488
- if error_msg := check_error_msg (t ): # check whether error occured
489
- st .error (error_msg )
490
- break
491
- text += t .get ("text" , "" )
492
- if return_video is False :
493
- chat_box .update_msg (text , element_index = 0 )
494
- chat_history_id = t .get ("chat_history_id" , "" )
495
-
496
- metadata = {
497
- "chat_history_id" : chat_history_id ,
498
- }
499
- if return_video is False :
500
- chat_box .update_msg (text , element_index = 0 , streaming = False , metadata = metadata )
501
-
502
- if imagegeneration_model and modelinfo ["mtype" ] != ModelType .Code :
503
- with st .spinner (f"Image generation in progress...." ):
504
- gen_image = api .get_image_generation_data (text )
505
- if gen_image :
506
- decoded_data = base64 .b64decode (gen_image )
507
- gen_image = Image (BytesIO (decoded_data ))
508
- chat_box .update_msg (gen_image , element_index = 1 , streaming = False )
509
- elif modelinfo ["mtype" ] == ModelType .Multimodal :
510
- if running_model == "stable-video-diffusion-img2vid" or running_model == "stable-video-diffusion-img2vid-xt" :
478
+ if return_video :
479
+ with st .spinner (f"Video generation in progress...." ):
480
+ r = api .chat_chat (prompt ,
481
+ imagesdata = imagesdata ,
482
+ audiosdata = audiosdata ,
483
+ videosdata = videosdata ,
484
+ imagesprompt = imagesprompt ,
485
+ history = history ,
486
+ model = running_model ,
487
+ speechmodel = speechmodel ,
488
+ prompt_name = prompt_template_name ,
489
+ temperature = temperature )
490
+ for t in r :
491
+ if error_msg := check_error_msg (t ): # check whether error occured
492
+ st .error (error_msg )
493
+ break
494
+ text += t .get ("text" , "" )
495
+ chat_history_id = t .get ("chat_history_id" , "" )
511
496
print ("video_path: " , text )
512
497
with open (text , "rb" ) as f :
513
498
video_bytes = f .read ()
514
499
gen_video = Video (BytesIO (video_bytes ))
515
500
chat_box .update_msg (gen_video , streaming = False )
501
+ else :
502
+ r = api .chat_chat (prompt ,
503
+ imagesdata = imagesdata ,
504
+ audiosdata = audiosdata ,
505
+ videosdata = videosdata ,
506
+ imagesprompt = imagesprompt ,
507
+ history = history ,
508
+ model = running_model ,
509
+ speechmodel = speechmodel ,
510
+ prompt_name = prompt_template_name ,
511
+ temperature = temperature )
512
+ for t in r :
513
+ if error_msg := check_error_msg (t ): # check whether error occured
514
+ st .error (error_msg )
515
+ break
516
+ text += t .get ("text" , "" )
517
+ chat_box .update_msg (text , element_index = 0 )
518
+ chat_history_id = t .get ("chat_history_id" , "" )
519
+
520
+ metadata = {
521
+ "chat_history_id" : chat_history_id ,
522
+ }
523
+ chat_box .update_msg (text , element_index = 0 , streaming = False , metadata = metadata )
524
+ if imagegeneration_model and modelinfo ["mtype" ] != ModelType .Code :
525
+ with st .spinner (f"Image generation in progress...." ):
526
+ gen_image = api .get_image_generation_data (text )
527
+ if gen_image :
528
+ decoded_data = base64 .b64decode (gen_image )
529
+ gen_image = Image (BytesIO (decoded_data ))
530
+ chat_box .update_msg (gen_image , element_index = 1 , streaming = False )
516
531
#print("chat_box.history: ", len(chat_box.history))
517
532
chat_box .show_feedback (** feedback_kwargs ,
518
533
key = chat_history_id ,
0 commit comments