Skip to content

Commit febe7b5

Browse files
committed
server: update api
1 parent bcba77c commit febe7b5

File tree

1 file changed

+90
-128
lines changed

1 file changed

+90
-128
lines changed

examples/server/main.cpp

Lines changed: 90 additions & 128 deletions
Original file line numberDiff line numberDiff line change
@@ -36,41 +36,6 @@
3636

3737
#include "frontend.cpp"
3838

39-
const char* rng_type_to_str[] = {
40-
"std_default",
41-
"cuda",
42-
};
43-
44-
// Names of the sampler method, same order as enum sample_method in stable-diffusion.h
45-
const char* sample_method_str[] = {
46-
"euler_a",
47-
"euler",
48-
"heun",
49-
"dpm2",
50-
"dpm++2s_a",
51-
"dpm++2m",
52-
"dpm++2mv2",
53-
"ipndm",
54-
"ipndm_v",
55-
"lcm",
56-
};
57-
58-
// Names of the sigma schedule overrides, same order as sample_schedule in stable-diffusion.h
59-
const char* schedule_str[] = {
60-
"default",
61-
"discrete",
62-
"karras",
63-
"exponential",
64-
"ays",
65-
"gits",
66-
};
67-
68-
enum SDMode {
69-
TXT2IMG,
70-
IMG2IMG,
71-
MODE_COUNT
72-
};
73-
7439
struct SDCtxParams {
7540
std::string model_path;
7641
std::string clip_l_path;
@@ -105,8 +70,6 @@ struct SDRequestParams {
10570
// TODO set to true if esrgan_path is specified in args
10671
bool upscale = false;
10772

108-
SDMode mode = TXT2IMG;
109-
11073
std::string prompt;
11174
std::string negative_prompt;
11275

@@ -195,11 +158,11 @@ void print_params(SDParams params) {
195158
printf(" clip_skip: %d\n", params.lastRequest.clip_skip);
196159
printf(" width: %d\n", params.lastRequest.width);
197160
printf(" height: %d\n", params.lastRequest.height);
198-
printf(" sample_method: %s\n", sample_method_str[params.lastRequest.sample_method]);
199-
printf(" schedule: %s\n", schedule_str[params.ctxParams.schedule]);
161+
printf(" sample_method: %s\n", sd_sample_method_name(params.lastRequest.sample_method));
162+
printf(" schedule: %s\n", sd_schedule_name(params.ctxParams.schedule));
200163
printf(" sample_steps: %d\n", params.lastRequest.sample_steps);
201164
printf(" strength(img2img): %.2f\n", params.lastRequest.strength);
202-
printf(" rng: %s\n", rng_type_to_str[params.ctxParams.rng_type]);
165+
printf(" rng: %s\n", sd_rng_type_name(params.ctxParams.rng_type));
203166
printf(" seed: %ld\n", params.lastRequest.seed);
204167
printf(" batch_count: %d\n", params.lastRequest.batch_count);
205168
printf(" vae_tiling: %s\n", params.ctxParams.vae_tiling ? "true" : "false");
@@ -512,17 +475,12 @@ void parse_args(int argc, const char** argv, SDParams& params) {
512475
break;
513476
}
514477
const char* schedule_selected = argv[i];
515-
int schedule_found = -1;
516-
for (int d = 0; d < N_SCHEDULES; d++) {
517-
if (!strcmp(schedule_selected, schedule_str[d])) {
518-
schedule_found = d;
519-
}
520-
}
521-
if (schedule_found == -1) {
478+
schedule_t schedule_found = str_to_schedule(schedule_selected);
479+
if (schedule_found == SCHEDULE_COUNT) {
522480
invalid_arg = true;
523481
break;
524482
}
525-
params.ctxParams.schedule = (schedule_t)schedule_found;
483+
params.ctxParams.schedule = schedule_found;
526484
} else if (arg == "-s" || arg == "--seed") {
527485
if (++i >= argc) {
528486
invalid_arg = true;
@@ -535,13 +493,8 @@ void parse_args(int argc, const char** argv, SDParams& params) {
535493
break;
536494
}
537495
const char* sample_method_selected = argv[i];
538-
int sample_method_found = -1;
539-
for (int m = 0; m < N_SAMPLE_METHODS; m++) {
540-
if (!strcmp(sample_method_selected, sample_method_str[m])) {
541-
sample_method_found = m;
542-
}
543-
}
544-
if (sample_method_found == -1) {
496+
int sample_method_found = str_to_sample_method(sample_method_selected);
497+
if (sample_method_found == SAMPLE_METHOD_COUNT) {
545498
invalid_arg = true;
546499
break;
547500
}
@@ -689,8 +642,8 @@ std::string get_image_params(SDParams params, int64_t seed) {
689642
parameter_string += "Seed: " + std::to_string(seed) + ", ";
690643
parameter_string += "Size: " + std::to_string(params.lastRequest.width) + "x" + std::to_string(params.lastRequest.height) + ", ";
691644
parameter_string += "Model: " + sd_basename(params.ctxParams.model_path) + ", ";
692-
parameter_string += "RNG: " + std::string(rng_type_to_str[params.ctxParams.rng_type]) + ", ";
693-
parameter_string += "Sampler: " + std::string(sample_method_str[params.lastRequest.sample_method]);
645+
parameter_string += "RNG: " + std::string(sd_rng_type_name(params.ctxParams.rng_type)) + ", ";
646+
parameter_string += "Sampler: " + std::string(sd_sample_method_name(params.lastRequest.sample_method));
694647
if (params.ctxParams.schedule == KARRAS) {
695648
parameter_string += " karras";
696649
}
@@ -807,14 +760,9 @@ bool parseJsonPrompt(std::string json_str, SDParams* params) {
807760
try {
808761
std::string sample_method = payload["sample_method"];
809762

810-
int sample_method_found = -1;
811-
for (int m = 0; m < N_SAMPLE_METHODS; m++) {
812-
if (!strcmp(sample_method.c_str(), sample_method_str[m])) {
813-
sample_method_found = m;
814-
}
815-
}
816-
if (sample_method_found >= 0) {
817-
params->lastRequest.sample_method = (sample_method_t)sample_method_found;
763+
sample_method_t sample_method_found = str_to_sample_method(sample_method.c_str());
764+
if (sample_method_found != SAMPLE_METHOD_COUNT) {
765+
params->lastRequest.sample_method = sample_method_found;
818766
} else {
819767
sd_log(sd_log_level_t::SD_LOG_WARN, "Unknown sampling method: %s\n", sample_method.c_str());
820768
}
@@ -1011,16 +959,11 @@ bool parseJsonPrompt(std::string json_str, SDParams* params) {
1011959
}
1012960

1013961
try {
1014-
std::string schedule = payload["schedule"];
1015-
int schedule_found = -1;
1016-
for (int m = 0; m < N_SCHEDULES; m++) {
1017-
if (!strcmp(schedule.c_str(), schedule_str[m])) {
1018-
schedule_found = m;
1019-
}
1020-
}
1021-
if (schedule_found >= 0) {
1022-
if (params->ctxParams.schedule != (schedule_t)schedule_found) {
1023-
params->ctxParams.schedule = (schedule_t)schedule_found;
962+
std::string schedule = payload["schedule"];
963+
schedule_t schedule_found = str_to_schedule(schedule.c_str());
964+
if (schedule_found != SCHEDULE_COUNT) {
965+
if (params->ctxParams.schedule != schedule_found) {
966+
params->ctxParams.schedule = schedule_found;
1024967
updatectx = true;
1025968
}
1026969
} else {
@@ -1189,30 +1132,31 @@ void start_server(SDParams params) {
11891132
std::lock_guard<std::mutex> results_lock(results_mutex);
11901133
task_results[task_id] = task_json;
11911134
}
1192-
1193-
sd_ctx = new_sd_ctx(params.ctxParams.model_path.c_str(),
1194-
params.ctxParams.clip_l_path.c_str(),
1195-
params.ctxParams.clip_g_path.c_str(),
1196-
params.ctxParams.t5xxl_path.c_str(),
1197-
params.ctxParams.diffusion_model_path.c_str(),
1198-
params.ctxParams.vae_path.c_str(),
1199-
params.ctxParams.taesd_path.c_str(),
1200-
params.ctxParams.controlnet_path.c_str(),
1201-
params.ctxParams.lora_model_dir.c_str(),
1202-
params.ctxParams.embeddings_path.c_str(),
1203-
params.ctxParams.stacked_id_embeddings_path.c_str(),
1204-
params.ctxParams.vae_decode_only,
1205-
params.ctxParams.vae_tiling,
1206-
false,
1207-
params.ctxParams.n_threads,
1208-
params.ctxParams.wtype,
1209-
params.ctxParams.rng_type,
1210-
params.ctxParams.schedule,
1211-
params.ctxParams.clip_on_cpu,
1212-
params.ctxParams.control_net_cpu,
1213-
params.ctxParams.vae_on_cpu,
1214-
params.ctxParams.diffusion_flash_attn,
1215-
true, false, 1);
1135+
sd_ctx_params_t sd_ctx_params = {
1136+
params.ctxParams.model_path.c_str(),
1137+
params.ctxParams.clip_l_path.c_str(),
1138+
params.ctxParams.clip_g_path.c_str(),
1139+
params.ctxParams.t5xxl_path.c_str(),
1140+
params.ctxParams.diffusion_model_path.c_str(),
1141+
params.ctxParams.vae_path.c_str(),
1142+
params.ctxParams.taesd_path.c_str(),
1143+
params.ctxParams.controlnet_path.c_str(),
1144+
params.ctxParams.lora_model_dir.c_str(),
1145+
params.ctxParams.embeddings_path.c_str(),
1146+
params.ctxParams.stacked_id_embeddings_path.c_str(),
1147+
params.ctxParams.vae_decode_only,
1148+
params.ctxParams.vae_tiling,
1149+
false,
1150+
params.ctxParams.n_threads,
1151+
params.ctxParams.wtype,
1152+
params.ctxParams.rng_type,
1153+
params.ctxParams.schedule,
1154+
params.ctxParams.clip_on_cpu,
1155+
params.ctxParams.control_net_cpu,
1156+
params.ctxParams.vae_on_cpu,
1157+
params.ctxParams.diffusion_flash_attn,
1158+
true, false, 1};
1159+
sd_ctx = new_sd_ctx(&sd_ctx_params);
12161160
if (sd_ctx == NULL) {
12171161
printf("new_sd_ctx_t failed\n");
12181162
std::lock_guard<std::mutex> results_lock(results_mutex);
@@ -1235,29 +1179,47 @@ void start_server(SDParams params) {
12351179

12361180
{
12371181
sd_image_t* results;
1238-
results = txt2img(sd_ctx,
1239-
params.lastRequest.prompt.c_str(),
1240-
params.lastRequest.negative_prompt.c_str(),
1241-
params.lastRequest.clip_skip,
1242-
params.lastRequest.cfg_scale,
1243-
params.lastRequest.guidance,
1244-
0.f, // eta
1245-
params.lastRequest.width,
1246-
params.lastRequest.height,
1247-
params.lastRequest.sample_method,
1248-
params.lastRequest.sample_steps,
1249-
params.lastRequest.seed,
1250-
params.lastRequest.batch_count,
1251-
NULL,
1252-
1,
1253-
params.lastRequest.style_ratio,
1254-
params.lastRequest.normalize_input,
1255-
params.input_id_images_path.c_str(),
1256-
params.lastRequest.skip_layers.data(),
1257-
params.lastRequest.skip_layers.size(),
1258-
params.lastRequest.slg_scale,
1259-
params.lastRequest.skip_layer_start,
1260-
params.lastRequest.skip_layer_end);
1182+
sd_slg_params_t slg = {
1183+
params.lastRequest.skip_layers.data(),
1184+
params.lastRequest.skip_layers.size(),
1185+
params.lastRequest.skip_layer_start,
1186+
params.lastRequest.skip_layer_end,
1187+
params.lastRequest.slg_scale};
1188+
sd_guidance_params_t guidance = {
1189+
params.lastRequest.cfg_scale,
1190+
params.lastRequest.cfg_scale,
1191+
params.lastRequest.cfg_scale,
1192+
params.lastRequest.guidance,
1193+
slg};
1194+
sd_image_t input_image = {
1195+
(uint32_t)params.lastRequest.width,
1196+
(uint32_t)params.lastRequest.height,
1197+
3,
1198+
NULL};
1199+
sd_image_t mask_img = input_image;
1200+
sd_img_gen_params_t gen_params = {
1201+
params.lastRequest.prompt.c_str(),
1202+
params.lastRequest.negative_prompt.c_str(),
1203+
params.lastRequest.clip_skip,
1204+
guidance,
1205+
input_image,
1206+
NULL, // ref images
1207+
0, // ref images count
1208+
mask_img,
1209+
params.lastRequest.width,
1210+
params.lastRequest.height,
1211+
params.lastRequest.sample_method,
1212+
params.lastRequest.sample_steps,
1213+
0.f, // eta
1214+
1.f, // denoise strength
1215+
params.lastRequest.seed,
1216+
params.lastRequest.batch_count,
1217+
NULL, // control image ptr
1218+
1.f, // control strength
1219+
params.lastRequest.style_ratio,
1220+
params.lastRequest.normalize_input,
1221+
params.input_id_images_path.c_str()};
1222+
results = generate_image(sd_ctx, &gen_params);
12611223

12621224
if (results == NULL) {
12631225
printf("generate failed\n");
@@ -1328,7 +1290,7 @@ void start_server(SDParams params) {
13281290
params_json["guidance"] = params.lastRequest.guidance;
13291291
params_json["width"] = params.lastRequest.width;
13301292
params_json["height"] = params.lastRequest.height;
1331-
params_json["sample_method"] = sample_method_str[params.lastRequest.sample_method];
1293+
params_json["sample_method"] = sd_sample_method_name(params.lastRequest.sample_method);
13321294
params_json["sample_steps"] = params.lastRequest.sample_steps;
13331295
params_json["seed"] = params.lastRequest.seed;
13341296
params_json["batch_count"] = params.lastRequest.batch_count;
@@ -1352,7 +1314,7 @@ void start_server(SDParams params) {
13521314
context_params["n_threads"] = params.ctxParams.n_threads;
13531315
context_params["wtype"] = params.ctxParams.wtype;
13541316
context_params["rng_type"] = params.ctxParams.rng_type;
1355-
context_params["schedule"] = schedule_str[params.ctxParams.schedule];
1317+
context_params["schedule"] = sd_schedule_name(params.ctxParams.schedule);
13561318
context_params["clip_on_cpu"] = params.ctxParams.clip_on_cpu;
13571319
context_params["control_net_cpu"] = params.ctxParams.control_net_cpu;
13581320
context_params["vae_on_cpu"] = params.ctxParams.vae_on_cpu;
@@ -1390,17 +1352,17 @@ void start_server(SDParams params) {
13901352
svr->Get("/sample_methods", [](const httplib::Request& req, httplib::Response& res) {
13911353
using json = nlohmann::json;
13921354
json response;
1393-
for (int m = 0; m < N_SAMPLE_METHODS; m++) {
1394-
response.push_back(sample_method_str[m]);
1355+
for (int m = 0; m < SAMPLE_METHOD_COUNT; m++) {
1356+
response.push_back(sd_sample_method_name((sample_method_t)m));
13951357
}
13961358
res.set_content(response.dump(), "application/json");
13971359
});
13981360

13991361
svr->Get("/schedules", [](const httplib::Request& req, httplib::Response& res) {
14001362
using json = nlohmann::json;
14011363
json response;
1402-
for (int s = 0; s < N_SCHEDULES; s++) {
1403-
response.push_back(schedule_str[s]);
1364+
for (int s = 0; s < SCHEDULE_COUNT; s++) {
1365+
response.push_back(sd_schedule_name((schedule_t)s));
14041366
}
14051367
res.set_content(response.dump(), "application/json");
14061368
});

0 commit comments

Comments
 (0)