@@ -666,7 +666,15 @@ int main(int argc, const char ** argv) {
666
666
res_ok_audio (res, audio, mime_type);
667
667
};
668
668
669
- const auto handle_conditional = [&args, &tqueue, &rmap, &res_error, &res_ok_json](const httplib::Request & req, httplib::Response & res) {
669
+ const auto handle_conditional = [
670
+ &args,
671
+ &tqueue,
672
+ &rmap,
673
+ &res_error,
674
+ &res_ok_json,
675
+ &model_map,
676
+ &default_model
677
+ ](const httplib::Request & req, httplib::Response & res) {
670
678
if (args.get_string_param (" --text-encoder-path" ).size () == 0 ) {
671
679
json formatted_error = format_error_response (" A '--text-encoder-path' must be specified for conditional generation." , ERROR_TYPE_NOT_SUPPORTED);
672
680
res_error (res, formatted_error);
@@ -685,6 +693,20 @@ int main(int argc, const char ** argv) {
685
693
}
686
694
std::string prompt = data.at (" input" ).get <std::string>();
687
695
struct simple_text_prompt_task * task = new simple_text_prompt_task (CONDITIONAL_PROMPT, prompt);
696
+
697
+ if (data.contains (" model" ) && data.at (" model" ).is_string ()) {
698
+ const std::string model = data.at (" model" );
699
+ if (!model_map.contains (model)) {
700
+ const std::string message = std::format (" Invalid Model: {0}" , model);
701
+ json formatted_error = format_error_response (message, ERROR_TYPE_INVALID_REQUEST);
702
+ res_error (res, formatted_error);
703
+ return ;
704
+ }
705
+ task->model = data.at (" model" ).get <std::string>();
706
+ } else {
707
+ task->model = default_model;
708
+ }
709
+
688
710
int id = task->id ;
689
711
tqueue->push (task);
690
712
struct simple_text_prompt_task * rtask = rmap->get (id);
0 commit comments