Skip to content

Commit 66e121d

Browse files
committed
Added missing prediction type overrides
1 parent 300aa58 commit 66e121d

File tree

4 files changed

+26
-4
lines changed

4 files changed

+26
-4
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,7 @@ arguments:
333333
-s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0)
334334
-b, --batch-count COUNT number of images to generate
335335
--schedule {discrete, karras, exponential, ays, gits} Denoiser sigma schedule (default: discrete)
336+
--prediction {eps, v, edm_v, sd3_flow, flux_flow} Prediction type override
336337
--clip-skip N ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1)
337338
<= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x
338339
--vae-tiling process vae in tiles to reduce memory usage

examples/cli/main.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ void print_usage(int argc, const char* argv[]) {
226226
printf(" -s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0)\n");
227227
printf(" -b, --batch-count COUNT number of images to generate\n");
228228
printf(" --schedule {discrete, karras, exponential, ays, gits} Denoiser sigma schedule (default: discrete)\n");
229-
printf(" --prediction {eps, v, flow} Prediction type.\n");
229+
printf(" --prediction {eps, v, edm_v, sd3_flow, flux_flow} Prediction type override.\n");
230230
printf(" --clip-skip N ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1)\n");
231231
printf(" <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x\n");
232232
printf(" --vae-tiling process vae in tiles to reduce memory usage\n");

stable-diffusion.cpp

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -570,10 +570,27 @@ class StableDiffusionGGML {
570570
LOG_INFO("running in v-prediction mode");
571571
denoiser = std::make_shared<CompVisVDenoiser>();
572572
break;
573-
case FLOW_PRED:
573+
case EDM_V_PRED:
574+
LOG_INFO("running in v-prediction EDM mode");
575+
denoiser = std::make_shared<EDMVDenoiser>();
576+
break;
577+
case SD3_FLOW_PRED:
574578
LOG_INFO("running in FLOW mode");
575579
denoiser = std::make_shared<DiscreteFlowDenoiser>();
576580
break;
581+
case FLUX_FLOW_PRED:
582+
{
583+
LOG_INFO("running in Flux FLOW mode");
584+
float shift = 1.0f; // TODO: validate
585+
for (auto pair : model_loader.tensor_storages_types) {
586+
if (pair.first.find("model.diffusion_model.guidance_in.in_layer.weight") != std::string::npos) {
587+
shift = 1.15f;
588+
break;
589+
}
590+
}
591+
denoiser = std::make_shared<FluxFlowDenoiser>(shift);
592+
break;
593+
}
577594
default:
578595
LOG_ERROR("Unknown parametrization %i", sd_ctx_params->prediction);
579596
abort();
@@ -1313,7 +1330,9 @@ const char* prediction_to_str[] = {
13131330
"default",
13141331
"eps",
13151332
"v",
1316-
"flow",
1333+
"edm_v",
1334+
"sd3_flow",
1335+
"flux_flow",
13171336
};
13181337

13191338
const char* sd_prediction_name(enum prediction_t prediction) {

stable-diffusion.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,9 @@ enum prediction_t {
6464
DEFAULT_PRED,
6565
EPS_PRED,
6666
V_PRED,
67-
FLOW_PRED,
67+
EDM_V_PRED,
68+
SD3_FLOW_PRED,
69+
FLUX_FLOW_PRED,
6870
PREDICTION_COUNT
6971
};
7072

0 commit comments

Comments
 (0)