@@ -561,25 +561,44 @@ class StableDiffusionGGML {
561
561
int64_t t1 = ggml_time_ms ();
562
562
LOG_INFO (" loading model from '%s' completed, taking %.2fs" , SAFE_STR (sd_ctx_params->model_path ), (t1 - t0) * 1 .0f / 1000 );
563
563
564
- // check is_using_v_parameterization_for_sd2
565
-
566
- if (sd_version_is_sd2 (version)) {
567
- if (is_using_v_parameterization_for_sd2 (ctx, sd_version_is_inpaint (version))) {
568
- is_using_v_parameterization = true ;
569
- }
570
- } else if (sd_version_is_sdxl (version)) {
571
- if (model_loader.tensor_storages_types .find (" edm_vpred.sigma_max" ) != model_loader.tensor_storages_types .end ()) {
572
- // CosXL models
573
- // TODO: get sigma_min and sigma_max values from file
574
- is_using_edm_v_parameterization = true ;
564
+ if (sd_ctx_params->prediction != DEFAULT_PRED) {
565
+ switch (sd_ctx_params->prediction ) {
566
+ case EPS_PRED:
567
+ LOG_INFO (" running in eps-prediction mode" );
568
+ break ;
569
+ case V_PRED:
570
+ LOG_INFO (" running in v-prediction mode" );
571
+ denoiser = std::make_shared<CompVisVDenoiser>();
572
+ break ;
573
+ case FLOW_PRED:
574
+ LOG_INFO (" running in FLOW mode" );
575
+ denoiser = std::make_shared<DiscreteFlowDenoiser>();
576
+ break ;
577
+ default :
578
+ LOG_ERROR (" Unknown parametrization %i" , sd_ctx_params->prediction );
579
+ abort ();
575
580
}
576
- if (model_loader.tensor_storages_types .find (" v_pred" ) != model_loader.tensor_storages_types .end ()) {
581
+ } else {
582
+ // check is_using_v_parameterization_for_sd2
583
+
584
+ if (sd_version_is_sd2 (version)) {
585
+ if (is_using_v_parameterization_for_sd2 (ctx, sd_version_is_inpaint (version))) {
586
+ is_using_v_parameterization = true ;
587
+ }
588
+ } else if (sd_version_is_sdxl (version)) {
589
+ if (model_loader.tensor_storages_types .find (" edm_vpred.sigma_max" ) != model_loader.tensor_storages_types .end ()) {
590
+ // CosXL models
591
+ // TODO: get sigma_min and sigma_max values from file
592
+ is_using_edm_v_parameterization = true ;
593
+ }
594
+ if (model_loader.tensor_storages_types .find (" v_pred" ) != model_loader.tensor_storages_types .end ()) {
595
+ is_using_v_parameterization = true ;
596
+ }
597
+ } else if (version == VERSION_SVD) {
598
+ // TODO: V_PREDICTION_EDM
577
599
is_using_v_parameterization = true ;
578
600
}
579
- } else if (version == VERSION_SVD) {
580
- // TODO: V_PREDICTION_EDM
581
- is_using_v_parameterization = true ;
582
- }
601
+ }
583
602
584
603
if (sd_version_is_sd3 (version)) {
585
604
LOG_INFO (" running in FLOW mode" );
@@ -1290,6 +1309,29 @@ enum schedule_t str_to_schedule(const char* str) {
1290
1309
return SCHEDULE_COUNT;
1291
1310
}
1292
1311
1312
+ const char * prediction_to_str[] = {
1313
+ " default" ,
1314
+ " eps" ,
1315
+ " v" ,
1316
+ " flow" ,
1317
+ };
1318
+
1319
+ const char * sd_prediction_name (enum prediction_t prediction) {
1320
+ if (prediction < PREDICTION_COUNT) {
1321
+ return prediction_to_str[prediction];
1322
+ }
1323
+ return NONE_STR;
1324
+ }
1325
+
1326
+ enum prediction_t str_to_prediction (const char * str) {
1327
+ for (int i = 0 ; i < PREDICTION_COUNT; i++) {
1328
+ if (!strcmp (str, prediction_to_str[i])) {
1329
+ return (enum prediction_t )i;
1330
+ }
1331
+ }
1332
+ return PREDICTION_COUNT;
1333
+ }
1334
+
1293
1335
void sd_ctx_params_init (sd_ctx_params_t * sd_ctx_params) {
1294
1336
memset ((void *)sd_ctx_params, 0 , sizeof (sd_ctx_params_t ));
1295
1337
sd_ctx_params->vae_decode_only = true ;
@@ -1299,6 +1341,7 @@ void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) {
1299
1341
sd_ctx_params->wtype = SD_TYPE_COUNT;
1300
1342
sd_ctx_params->rng_type = CUDA_RNG;
1301
1343
sd_ctx_params->schedule = DEFAULT;
1344
+ sd_ctx_params->prediction = DEFAULT_PRED;
1302
1345
sd_ctx_params->keep_clip_on_cpu = false ;
1303
1346
sd_ctx_params->keep_control_net_on_cpu = false ;
1304
1347
sd_ctx_params->keep_vae_on_cpu = false ;
@@ -1333,6 +1376,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
1333
1376
" wtype: %s\n "
1334
1377
" rng_type: %s\n "
1335
1378
" schedule: %s\n "
1379
+ " prediction: %s\n "
1336
1380
" keep_clip_on_cpu: %s\n "
1337
1381
" keep_control_net_on_cpu: %s\n "
1338
1382
" keep_vae_on_cpu: %s\n "
@@ -1358,6 +1402,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
1358
1402
sd_type_name (sd_ctx_params->wtype ),
1359
1403
sd_rng_type_name (sd_ctx_params->rng_type ),
1360
1404
sd_schedule_name (sd_ctx_params->schedule ),
1405
+ sd_prediction_name (sd_ctx_params->prediction ),
1361
1406
BOOL_STR (sd_ctx_params->keep_clip_on_cpu ),
1362
1407
BOOL_STR (sd_ctx_params->keep_control_net_on_cpu ),
1363
1408
BOOL_STR (sd_ctx_params->keep_vae_on_cpu ),
0 commit comments