36
36
37
37
#include " frontend.cpp"
38
38
39
- const char * rng_type_to_str[] = {
40
- " std_default" ,
41
- " cuda" ,
42
- };
43
-
44
- // Names of the sampler method, same order as enum sample_method in stable-diffusion.h
45
- const char * sample_method_str[] = {
46
- " euler_a" ,
47
- " euler" ,
48
- " heun" ,
49
- " dpm2" ,
50
- " dpm++2s_a" ,
51
- " dpm++2m" ,
52
- " dpm++2mv2" ,
53
- " ipndm" ,
54
- " ipndm_v" ,
55
- " lcm" ,
56
- };
57
-
58
- // Names of the sigma schedule overrides, same order as sample_schedule in stable-diffusion.h
59
- const char * schedule_str[] = {
60
- " default" ,
61
- " discrete" ,
62
- " karras" ,
63
- " exponential" ,
64
- " ays" ,
65
- " gits" ,
66
- };
67
-
68
- enum SDMode {
69
- TXT2IMG,
70
- IMG2IMG,
71
- MODE_COUNT
72
- };
73
-
74
39
struct SDCtxParams {
75
40
std::string model_path;
76
41
std::string clip_l_path;
@@ -105,8 +70,6 @@ struct SDRequestParams {
105
70
// TODO set to true if esrgan_path is specified in args
106
71
bool upscale = false ;
107
72
108
- SDMode mode = TXT2IMG;
109
-
110
73
std::string prompt;
111
74
std::string negative_prompt;
112
75
@@ -195,11 +158,11 @@ void print_params(SDParams params) {
195
158
printf (" clip_skip: %d\n " , params.lastRequest .clip_skip );
196
159
printf (" width: %d\n " , params.lastRequest .width );
197
160
printf (" height: %d\n " , params.lastRequest .height );
198
- printf (" sample_method: %s\n " , sample_method_str[ params.lastRequest .sample_method ] );
199
- printf (" schedule: %s\n " , schedule_str[ params.ctxParams .schedule ] );
161
+ printf (" sample_method: %s\n " , sd_sample_method_name ( params.lastRequest .sample_method ) );
162
+ printf (" schedule: %s\n " , sd_schedule_name ( params.ctxParams .schedule ) );
200
163
printf (" sample_steps: %d\n " , params.lastRequest .sample_steps );
201
164
printf (" strength(img2img): %.2f\n " , params.lastRequest .strength );
202
- printf (" rng: %s\n " , rng_type_to_str[ params.ctxParams .rng_type ] );
165
+ printf (" rng: %s\n " , sd_rng_type_name ( params.ctxParams .rng_type ) );
203
166
printf (" seed: %ld\n " , params.lastRequest .seed );
204
167
printf (" batch_count: %d\n " , params.lastRequest .batch_count );
205
168
printf (" vae_tiling: %s\n " , params.ctxParams .vae_tiling ? " true" : " false" );
@@ -512,17 +475,12 @@ void parse_args(int argc, const char** argv, SDParams& params) {
512
475
break ;
513
476
}
514
477
const char * schedule_selected = argv[i];
515
- int schedule_found = -1 ;
516
- for (int d = 0 ; d < N_SCHEDULES; d++) {
517
- if (!strcmp (schedule_selected, schedule_str[d])) {
518
- schedule_found = d;
519
- }
520
- }
521
- if (schedule_found == -1 ) {
478
+ schedule_t schedule_found = str_to_schedule (schedule_selected);
479
+ if (schedule_found == SCHEDULE_COUNT) {
522
480
invalid_arg = true ;
523
481
break ;
524
482
}
525
- params.ctxParams .schedule = ( schedule_t ) schedule_found;
483
+ params.ctxParams .schedule = schedule_found;
526
484
} else if (arg == " -s" || arg == " --seed" ) {
527
485
if (++i >= argc) {
528
486
invalid_arg = true ;
@@ -535,13 +493,8 @@ void parse_args(int argc, const char** argv, SDParams& params) {
535
493
break ;
536
494
}
537
495
const char * sample_method_selected = argv[i];
538
- int sample_method_found = -1 ;
539
- for (int m = 0 ; m < N_SAMPLE_METHODS; m++) {
540
- if (!strcmp (sample_method_selected, sample_method_str[m])) {
541
- sample_method_found = m;
542
- }
543
- }
544
- if (sample_method_found == -1 ) {
496
+ int sample_method_found = str_to_sample_method (sample_method_selected);
497
+ if (sample_method_found == SAMPLE_METHOD_COUNT) {
545
498
invalid_arg = true ;
546
499
break ;
547
500
}
@@ -689,8 +642,8 @@ std::string get_image_params(SDParams params, int64_t seed) {
689
642
parameter_string += " Seed: " + std::to_string (seed) + " , " ;
690
643
parameter_string += " Size: " + std::to_string (params.lastRequest .width ) + " x" + std::to_string (params.lastRequest .height ) + " , " ;
691
644
parameter_string += " Model: " + sd_basename (params.ctxParams .model_path ) + " , " ;
692
- parameter_string += " RNG: " + std::string (rng_type_to_str[ params.ctxParams .rng_type ] ) + " , " ;
693
- parameter_string += " Sampler: " + std::string (sample_method_str[ params.lastRequest .sample_method ] );
645
+ parameter_string += " RNG: " + std::string (sd_rng_type_name ( params.ctxParams .rng_type ) ) + " , " ;
646
+ parameter_string += " Sampler: " + std::string (sd_sample_method_name ( params.lastRequest .sample_method ) );
694
647
if (params.ctxParams .schedule == KARRAS) {
695
648
parameter_string += " karras" ;
696
649
}
@@ -807,14 +760,9 @@ bool parseJsonPrompt(std::string json_str, SDParams* params) {
807
760
try {
808
761
std::string sample_method = payload[" sample_method" ];
809
762
810
- int sample_method_found = -1 ;
811
- for (int m = 0 ; m < N_SAMPLE_METHODS; m++) {
812
- if (!strcmp (sample_method.c_str (), sample_method_str[m])) {
813
- sample_method_found = m;
814
- }
815
- }
816
- if (sample_method_found >= 0 ) {
817
- params->lastRequest .sample_method = (sample_method_t )sample_method_found;
763
+ sample_method_t sample_method_found = str_to_sample_method (sample_method.c_str ());
764
+ if (sample_method_found != SAMPLE_METHOD_COUNT) {
765
+ params->lastRequest .sample_method = sample_method_found;
818
766
} else {
819
767
sd_log (sd_log_level_t ::SD_LOG_WARN, " Unknown sampling method: %s\n " , sample_method.c_str ());
820
768
}
@@ -1011,16 +959,11 @@ bool parseJsonPrompt(std::string json_str, SDParams* params) {
1011
959
}
1012
960
1013
961
try {
1014
- std::string schedule = payload[" schedule" ];
1015
- int schedule_found = -1 ;
1016
- for (int m = 0 ; m < N_SCHEDULES; m++) {
1017
- if (!strcmp (schedule.c_str (), schedule_str[m])) {
1018
- schedule_found = m;
1019
- }
1020
- }
1021
- if (schedule_found >= 0 ) {
1022
- if (params->ctxParams .schedule != (schedule_t )schedule_found) {
1023
- params->ctxParams .schedule = (schedule_t )schedule_found;
962
+ std::string schedule = payload[" schedule" ];
963
+ schedule_t schedule_found = str_to_schedule (schedule.c_str ());
964
+ if (schedule_found != SCHEDULE_COUNT) {
965
+ if (params->ctxParams .schedule != schedule_found) {
966
+ params->ctxParams .schedule = schedule_found;
1024
967
updatectx = true ;
1025
968
}
1026
969
} else {
@@ -1189,30 +1132,31 @@ void start_server(SDParams params) {
1189
1132
std::lock_guard<std::mutex> results_lock (results_mutex);
1190
1133
task_results[task_id] = task_json;
1191
1134
}
1192
-
1193
- sd_ctx = new_sd_ctx (params.ctxParams .model_path .c_str (),
1194
- params.ctxParams .clip_l_path .c_str (),
1195
- params.ctxParams .clip_g_path .c_str (),
1196
- params.ctxParams .t5xxl_path .c_str (),
1197
- params.ctxParams .diffusion_model_path .c_str (),
1198
- params.ctxParams .vae_path .c_str (),
1199
- params.ctxParams .taesd_path .c_str (),
1200
- params.ctxParams .controlnet_path .c_str (),
1201
- params.ctxParams .lora_model_dir .c_str (),
1202
- params.ctxParams .embeddings_path .c_str (),
1203
- params.ctxParams .stacked_id_embeddings_path .c_str (),
1204
- params.ctxParams .vae_decode_only ,
1205
- params.ctxParams .vae_tiling ,
1206
- false ,
1207
- params.ctxParams .n_threads ,
1208
- params.ctxParams .wtype ,
1209
- params.ctxParams .rng_type ,
1210
- params.ctxParams .schedule ,
1211
- params.ctxParams .clip_on_cpu ,
1212
- params.ctxParams .control_net_cpu ,
1213
- params.ctxParams .vae_on_cpu ,
1214
- params.ctxParams .diffusion_flash_attn ,
1215
- true , false , 1 );
1135
+ sd_ctx_params_t sd_ctx_params = {
1136
+ params.ctxParams .model_path .c_str (),
1137
+ params.ctxParams .clip_l_path .c_str (),
1138
+ params.ctxParams .clip_g_path .c_str (),
1139
+ params.ctxParams .t5xxl_path .c_str (),
1140
+ params.ctxParams .diffusion_model_path .c_str (),
1141
+ params.ctxParams .vae_path .c_str (),
1142
+ params.ctxParams .taesd_path .c_str (),
1143
+ params.ctxParams .controlnet_path .c_str (),
1144
+ params.ctxParams .lora_model_dir .c_str (),
1145
+ params.ctxParams .embeddings_path .c_str (),
1146
+ params.ctxParams .stacked_id_embeddings_path .c_str (),
1147
+ params.ctxParams .vae_decode_only ,
1148
+ params.ctxParams .vae_tiling ,
1149
+ false ,
1150
+ params.ctxParams .n_threads ,
1151
+ params.ctxParams .wtype ,
1152
+ params.ctxParams .rng_type ,
1153
+ params.ctxParams .schedule ,
1154
+ params.ctxParams .clip_on_cpu ,
1155
+ params.ctxParams .control_net_cpu ,
1156
+ params.ctxParams .vae_on_cpu ,
1157
+ params.ctxParams .diffusion_flash_attn ,
1158
+ true , false , 1 };
1159
+ sd_ctx = new_sd_ctx (&sd_ctx_params);
1216
1160
if (sd_ctx == NULL ) {
1217
1161
printf (" new_sd_ctx_t failed\n " );
1218
1162
std::lock_guard<std::mutex> results_lock (results_mutex);
@@ -1235,29 +1179,47 @@ void start_server(SDParams params) {
1235
1179
1236
1180
{
1237
1181
sd_image_t * results;
1238
- results = txt2img (sd_ctx,
1239
- params.lastRequest .prompt .c_str (),
1240
- params.lastRequest .negative_prompt .c_str (),
1241
- params.lastRequest .clip_skip ,
1242
- params.lastRequest .cfg_scale ,
1243
- params.lastRequest .guidance ,
1244
- 0 .f , // eta
1245
- params.lastRequest .width ,
1246
- params.lastRequest .height ,
1247
- params.lastRequest .sample_method ,
1248
- params.lastRequest .sample_steps ,
1249
- params.lastRequest .seed ,
1250
- params.lastRequest .batch_count ,
1251
- NULL ,
1252
- 1 ,
1253
- params.lastRequest .style_ratio ,
1254
- params.lastRequest .normalize_input ,
1255
- params.input_id_images_path .c_str (),
1256
- params.lastRequest .skip_layers .data (),
1257
- params.lastRequest .skip_layers .size (),
1258
- params.lastRequest .slg_scale ,
1259
- params.lastRequest .skip_layer_start ,
1260
- params.lastRequest .skip_layer_end );
1182
+ sd_slg_params_t slg = {
1183
+ params.lastRequest .skip_layers .data (),
1184
+ params.lastRequest .skip_layers .size (),
1185
+ params.lastRequest .skip_layer_start ,
1186
+ params.lastRequest .skip_layer_end ,
1187
+ params.lastRequest .slg_scale };
1188
+ sd_guidance_params_t guidance = {
1189
+ params.lastRequest .cfg_scale ,
1190
+ params.lastRequest .cfg_scale ,
1191
+ params.lastRequest .cfg_scale ,
1192
+ params.lastRequest .guidance ,
1193
+ slg};
1194
+ sd_image_t input_image = {
1195
+ (uint32_t )params.lastRequest .width ,
1196
+ (uint32_t )params.lastRequest .height ,
1197
+ 3 ,
1198
+ NULL };
1199
+ sd_image_t mask_img = input_image;
1200
+ sd_img_gen_params_t gen_params = {
1201
+ params.lastRequest .prompt .c_str (),
1202
+ params.lastRequest .negative_prompt .c_str (),
1203
+ params.lastRequest .clip_skip ,
1204
+ guidance,
1205
+ input_image,
1206
+ NULL , // ref images
1207
+ 0 , // ref images count
1208
+ mask_img,
1209
+ params.lastRequest .width ,
1210
+ params.lastRequest .height ,
1211
+ params.lastRequest .sample_method ,
1212
+ params.lastRequest .sample_steps ,
1213
+ 0 .f , // eta
1214
+ 1 .f , // denoise strength
1215
+ params.lastRequest .seed ,
1216
+ params.lastRequest .batch_count ,
1217
+ NULL , // control image ptr
1218
+ 1 .f , // control strength
1219
+ params.lastRequest .style_ratio ,
1220
+ params.lastRequest .normalize_input ,
1221
+ params.input_id_images_path .c_str ()};
1222
+ results = generate_image (sd_ctx, &gen_params);
1261
1223
1262
1224
if (results == NULL ) {
1263
1225
printf (" generate failed\n " );
@@ -1328,7 +1290,7 @@ void start_server(SDParams params) {
1328
1290
params_json[" guidance" ] = params.lastRequest .guidance ;
1329
1291
params_json[" width" ] = params.lastRequest .width ;
1330
1292
params_json[" height" ] = params.lastRequest .height ;
1331
- params_json[" sample_method" ] = sample_method_str[ params.lastRequest .sample_method ] ;
1293
+ params_json[" sample_method" ] = sd_sample_method_name ( params.lastRequest .sample_method ) ;
1332
1294
params_json[" sample_steps" ] = params.lastRequest .sample_steps ;
1333
1295
params_json[" seed" ] = params.lastRequest .seed ;
1334
1296
params_json[" batch_count" ] = params.lastRequest .batch_count ;
@@ -1352,7 +1314,7 @@ void start_server(SDParams params) {
1352
1314
context_params[" n_threads" ] = params.ctxParams .n_threads ;
1353
1315
context_params[" wtype" ] = params.ctxParams .wtype ;
1354
1316
context_params[" rng_type" ] = params.ctxParams .rng_type ;
1355
- context_params[" schedule" ] = schedule_str[ params.ctxParams .schedule ] ;
1317
+ context_params[" schedule" ] = sd_schedule_name ( params.ctxParams .schedule ) ;
1356
1318
context_params[" clip_on_cpu" ] = params.ctxParams .clip_on_cpu ;
1357
1319
context_params[" control_net_cpu" ] = params.ctxParams .control_net_cpu ;
1358
1320
context_params[" vae_on_cpu" ] = params.ctxParams .vae_on_cpu ;
@@ -1390,17 +1352,17 @@ void start_server(SDParams params) {
1390
1352
svr->Get (" /sample_methods" , [](const httplib::Request& req, httplib::Response& res) {
1391
1353
using json = nlohmann::json;
1392
1354
json response;
1393
- for (int m = 0 ; m < N_SAMPLE_METHODS ; m++) {
1394
- response.push_back (sample_method_str[m] );
1355
+ for (int m = 0 ; m < SAMPLE_METHOD_COUNT ; m++) {
1356
+ response.push_back (sd_sample_method_name (( sample_method_t )m) );
1395
1357
}
1396
1358
res.set_content (response.dump (), " application/json" );
1397
1359
});
1398
1360
1399
1361
svr->Get (" /schedules" , [](const httplib::Request& req, httplib::Response& res) {
1400
1362
using json = nlohmann::json;
1401
1363
json response;
1402
- for (int s = 0 ; s < N_SCHEDULES ; s++) {
1403
- response.push_back (schedule_str[s] );
1364
+ for (int s = 0 ; s < SCHEDULE_COUNT ; s++) {
1365
+ response.push_back (sd_schedule_name (( schedule_t )s) );
1404
1366
}
1405
1367
res.set_content (response.dump (), " application/json" );
1406
1368
});
0 commit comments