Skip to content

Commit 35d7f88

Browse files
committed
feat: Added model option to handle_conditional
1 parent f543155 commit 35d7f88

File tree

1 file changed

+23
-1
lines changed

1 file changed

+23
-1
lines changed

examples/server/server.cpp

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -666,7 +666,15 @@ int main(int argc, const char ** argv) {
666666
res_ok_audio(res, audio, mime_type);
667667
};
668668

669-
const auto handle_conditional = [&args, &tqueue, &rmap, &res_error, &res_ok_json](const httplib::Request & req, httplib::Response & res) {
669+
const auto handle_conditional = [
670+
&args,
671+
&tqueue,
672+
&rmap,
673+
&res_error,
674+
&res_ok_json,
675+
&model_map,
676+
&default_model
677+
](const httplib::Request & req, httplib::Response & res) {
670678
if (args.get_string_param("--text-encoder-path").size() == 0) {
671679
json formatted_error = format_error_response("A '--text-encoder-path' must be specified for conditional generation.", ERROR_TYPE_NOT_SUPPORTED);
672680
res_error(res, formatted_error);
@@ -685,6 +693,20 @@ int main(int argc, const char ** argv) {
685693
}
686694
std::string prompt = data.at("input").get<std::string>();
687695
struct simple_text_prompt_task * task = new simple_text_prompt_task(CONDITIONAL_PROMPT, prompt);
696+
697+
if (data.contains("model") && data.at("model").is_string()) {
698+
const std::string model = data.at("model");
699+
if (!model_map.contains(model)) {
700+
const std::string message = std::format("Invalid Model: {0}", model);
701+
json formatted_error = format_error_response(message, ERROR_TYPE_INVALID_REQUEST);
702+
res_error(res, formatted_error);
703+
return;
704+
}
705+
task->model = data.at("model").get<std::string>();
706+
} else {
707+
task->model = default_model;
708+
}
709+
688710
int id = task->id;
689711
tqueue->push(task);
690712
struct simple_text_prompt_task * rtask = rmap->get(id);

0 commit comments

Comments
 (0)