Skip to content

Commit 300aa58

Browse files
committed
Added prediction argument
1 parent 1896b28 commit 300aa58

File tree

3 files changed

+91
-16
lines changed

3 files changed

+91
-16
lines changed

examples/cli/main.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ struct SDParams {
8585

8686
sample_method_t sample_method = EULER_A;
8787
schedule_t schedule = DEFAULT;
88+
prediction_t prediction = DEFAULT_PRED;
8889
int sample_steps = 20;
8990
float strength = 0.75f;
9091
float control_strength = 0.9f;
@@ -156,6 +157,7 @@ void print_params(SDParams params) {
156157
printf(" height: %d\n", params.height);
157158
printf(" sample_method: %s\n", sd_sample_method_name(params.sample_method));
158159
printf(" schedule: %s\n", sd_schedule_name(params.schedule));
160+
printf(" prediction: %s\n", sd_prediction_name(params.prediction));
159161
printf(" sample_steps: %d\n", params.sample_steps);
160162
printf(" strength(img2img): %.2f\n", params.strength);
161163
printf(" rng: %s\n", sd_rng_type_name(params.rng_type));
@@ -224,6 +226,7 @@ void print_usage(int argc, const char* argv[]) {
224226
printf(" -s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0)\n");
225227
printf(" -b, --batch-count COUNT number of images to generate\n");
226228
printf(" --schedule {discrete, karras, exponential, ays, gits} Denoiser sigma schedule (default: discrete)\n");
229+
printf(" --prediction {eps, v, flow} Prediction type.\n");
227230
printf(" --clip-skip N ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1)\n");
228231
printf(" <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x\n");
229232
printf(" --vae-tiling process vae in tiles to reduce memory usage\n");
@@ -494,6 +497,20 @@ void parse_args(int argc, const char** argv, SDParams& params) {
494497
return 1;
495498
};
496499

500+
auto on_prediction_arg = [&](int argc, const char** argv, int index) {
501+
if (++index >= argc) {
502+
return -1;
503+
}
504+
const char* arg = argv[index];
505+
params.prediction = str_to_prediction(arg);
506+
if (params.prediction == PREDICTION_COUNT) {
507+
fprintf(stderr, "error: invalid prediction type %s\n",
508+
arg);
509+
return -1;
510+
}
511+
return 1;
512+
};
513+
497514
auto on_sample_method_arg = [&](int argc, const char** argv, int index) {
498515
if (++index >= argc) {
499516
return -1;
@@ -564,6 +581,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
564581
{"-s", "--seed", "", on_seed_arg},
565582
{"", "--sampling-method", "", on_sample_method_arg},
566583
{"", "--schedule", "", on_schedule_arg},
584+
{"", "--prediction", "", on_prediction_arg},
567585
{"", "--skip-layers", "", on_skip_layers_arg},
568586
{"-r", "--ref-image", "", on_ref_image_arg},
569587
{"-h", "--help", "", on_help_arg},
@@ -883,6 +901,7 @@ int main(int argc, const char* argv[]) {
883901
params.wtype,
884902
params.rng_type,
885903
params.schedule,
904+
params.prediction,
886905
params.clip_on_cpu,
887906
params.control_net_cpu,
888907
params.vae_on_cpu,

stable-diffusion.cpp

Lines changed: 61 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -561,25 +561,44 @@ class StableDiffusionGGML {
561561
int64_t t1 = ggml_time_ms();
562562
LOG_INFO("loading model from '%s' completed, taking %.2fs", SAFE_STR(sd_ctx_params->model_path), (t1 - t0) * 1.0f / 1000);
563563

564-
// check is_using_v_parameterization_for_sd2
565-
566-
if (sd_version_is_sd2(version)) {
567-
if (is_using_v_parameterization_for_sd2(ctx, sd_version_is_inpaint(version))) {
568-
is_using_v_parameterization = true;
569-
}
570-
} else if (sd_version_is_sdxl(version)) {
571-
if (model_loader.tensor_storages_types.find("edm_vpred.sigma_max") != model_loader.tensor_storages_types.end()) {
572-
// CosXL models
573-
// TODO: get sigma_min and sigma_max values from file
574-
is_using_edm_v_parameterization = true;
564+
if (sd_ctx_params->prediction != DEFAULT_PRED) {
565+
switch (sd_ctx_params->prediction) {
566+
case EPS_PRED:
567+
LOG_INFO("running in eps-prediction mode");
568+
break;
569+
case V_PRED:
570+
LOG_INFO("running in v-prediction mode");
571+
denoiser = std::make_shared<CompVisVDenoiser>();
572+
break;
573+
case FLOW_PRED:
574+
LOG_INFO("running in FLOW mode");
575+
denoiser = std::make_shared<DiscreteFlowDenoiser>();
576+
break;
577+
default:
578+
LOG_ERROR("Unknown parametrization %i", sd_ctx_params->prediction);
579+
abort();
575580
}
576-
if (model_loader.tensor_storages_types.find("v_pred") != model_loader.tensor_storages_types.end()) {
581+
} else {
582+
// check is_using_v_parameterization_for_sd2
583+
584+
if (sd_version_is_sd2(version)) {
585+
if (is_using_v_parameterization_for_sd2(ctx, sd_version_is_inpaint(version))) {
586+
is_using_v_parameterization = true;
587+
}
588+
} else if (sd_version_is_sdxl(version)) {
589+
if (model_loader.tensor_storages_types.find("edm_vpred.sigma_max") != model_loader.tensor_storages_types.end()) {
590+
// CosXL models
591+
// TODO: get sigma_min and sigma_max values from file
592+
is_using_edm_v_parameterization = true;
593+
}
594+
if (model_loader.tensor_storages_types.find("v_pred") != model_loader.tensor_storages_types.end()) {
595+
is_using_v_parameterization = true;
596+
}
597+
} else if (version == VERSION_SVD) {
598+
// TODO: V_PREDICTION_EDM
577599
is_using_v_parameterization = true;
578600
}
579-
} else if (version == VERSION_SVD) {
580-
// TODO: V_PREDICTION_EDM
581-
is_using_v_parameterization = true;
582-
}
601+
}
583602

584603
if (sd_version_is_sd3(version)) {
585604
LOG_INFO("running in FLOW mode");
@@ -1290,6 +1309,29 @@ enum schedule_t str_to_schedule(const char* str) {
12901309
return SCHEDULE_COUNT;
12911310
}
12921311

1312+
const char* prediction_to_str[] = {
1313+
"default",
1314+
"eps",
1315+
"v",
1316+
"flow",
1317+
};
1318+
1319+
const char* sd_prediction_name(enum prediction_t prediction) {
1320+
if (prediction < PREDICTION_COUNT) {
1321+
return prediction_to_str[prediction];
1322+
}
1323+
return NONE_STR;
1324+
}
1325+
1326+
enum prediction_t str_to_prediction(const char* str) {
1327+
for (int i = 0; i < PREDICTION_COUNT; i++) {
1328+
if (!strcmp(str, prediction_to_str[i])) {
1329+
return (enum prediction_t)i;
1330+
}
1331+
}
1332+
return PREDICTION_COUNT;
1333+
}
1334+
12931335
void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) {
12941336
memset((void*)sd_ctx_params, 0, sizeof(sd_ctx_params_t));
12951337
sd_ctx_params->vae_decode_only = true;
@@ -1299,6 +1341,7 @@ void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) {
12991341
sd_ctx_params->wtype = SD_TYPE_COUNT;
13001342
sd_ctx_params->rng_type = CUDA_RNG;
13011343
sd_ctx_params->schedule = DEFAULT;
1344+
sd_ctx_params->prediction = DEFAULT_PRED;
13021345
sd_ctx_params->keep_clip_on_cpu = false;
13031346
sd_ctx_params->keep_control_net_on_cpu = false;
13041347
sd_ctx_params->keep_vae_on_cpu = false;
@@ -1333,6 +1376,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
13331376
"wtype: %s\n"
13341377
"rng_type: %s\n"
13351378
"schedule: %s\n"
1379+
"prediction: %s\n"
13361380
"keep_clip_on_cpu: %s\n"
13371381
"keep_control_net_on_cpu: %s\n"
13381382
"keep_vae_on_cpu: %s\n"
@@ -1358,6 +1402,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
13581402
sd_type_name(sd_ctx_params->wtype),
13591403
sd_rng_type_name(sd_ctx_params->rng_type),
13601404
sd_schedule_name(sd_ctx_params->schedule),
1405+
sd_prediction_name(sd_ctx_params->prediction),
13611406
BOOL_STR(sd_ctx_params->keep_clip_on_cpu),
13621407
BOOL_STR(sd_ctx_params->keep_control_net_on_cpu),
13631408
BOOL_STR(sd_ctx_params->keep_vae_on_cpu),

stable-diffusion.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,14 @@ enum schedule_t {
6060
SCHEDULE_COUNT
6161
};
6262

63+
enum prediction_t {
64+
DEFAULT_PRED,
65+
EPS_PRED,
66+
V_PRED,
67+
FLOW_PRED,
68+
PREDICTION_COUNT
69+
};
70+
6371
// same as enum ggml_type
6472
enum sd_type_t {
6573
SD_TYPE_F32 = 0,
@@ -130,6 +138,7 @@ typedef struct {
130138
enum sd_type_t wtype;
131139
enum rng_type_t rng_type;
132140
enum schedule_t schedule;
141+
enum prediction_t prediction;
133142
bool keep_clip_on_cpu;
134143
bool keep_control_net_on_cpu;
135144
bool keep_vae_on_cpu;
@@ -219,6 +228,8 @@ SD_API const char* sd_sample_method_name(enum sample_method_t sample_method);
219228
SD_API enum sample_method_t str_to_sample_method(const char* str);
220229
SD_API const char* sd_schedule_name(enum schedule_t schedule);
221230
SD_API enum schedule_t str_to_schedule(const char* str);
231+
SD_API const char* sd_prediction_name(enum prediction_t prediction);
232+
SD_API enum prediction_t str_to_prediction(const char* str);
222233

223234
SD_API void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params);
224235
SD_API char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params);

0 commit comments

Comments
 (0)