Skip to content

Commit 60305ce

Browse files
committed
integrate apg into new API
1 parent 0ce2c91 commit 60305ce

File tree

3 files changed

+36
-31
lines changed

3 files changed

+36
-31
lines changed

examples/cli/main.cpp

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -764,17 +764,25 @@ int main(int argc, const char* argv[]) {
764764

765765
parse_args(argc, argv, params);
766766

767-
sd_guidance_params_t guidance_params = {params.cfg_scale,
768-
params.img_cfg_scale,
769-
params.min_cfg,
770-
params.guidance,
771-
{
772-
params.skip_layers.data(),
773-
params.skip_layers.size(),
774-
params.skip_layer_start,
775-
params.skip_layer_end,
776-
params.slg_scale,
777-
}};
767+
sd_guidance_params_t guidance_params = {
768+
params.cfg_scale,
769+
params.img_cfg_scale,
770+
params.min_cfg,
771+
params.guidance,
772+
{
773+
params.skip_layers.data(),
774+
params.skip_layers.size(),
775+
params.skip_layer_start,
776+
params.skip_layer_end,
777+
params.slg_scale,
778+
},
779+
{
780+
params.apg_eta,
781+
params.apg_momentum,
782+
params.apg_norm_threshold,
783+
params.apg_norm_smoothing,
784+
},
785+
};
778786

779787
sd_set_log_callback(sd_log_cb, (void*)&params);
780788

@@ -998,7 +1006,7 @@ int main(int argc, const char* argv[]) {
9981006
params.input_id_images_path.c_str(),
9991007
};
10001008

1001-
results = generate_image(sd_ctx, &img_gen_params, {params.apg_eta, params.apg_momentum, params.apg_norm_threshold});
1009+
results = generate_image(sd_ctx, &img_gen_params);
10021010
expected_num_results = params.batch_count;
10031011
} else if (params.mode == VID_GEN) {
10041012
sd_vid_gen_params_t vid_gen_params = {

stable-diffusion.cpp

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -848,7 +848,6 @@ class StableDiffusionGGML {
848848
int start_merge_step,
849849
SDCondition id_cond,
850850
std::vector<ggml_tensor*> ref_latents = {},
851-
sd_apg_params_t apg_params = {1, 0, 0, 0},
852851
ggml_tensor* denoise_mask = nullptr) {
853852
std::vector<int> skip_layers(guidance.slg.layers, guidance.slg.layers + guidance.slg.layer_count);
854853

@@ -913,7 +912,7 @@ class StableDiffusionGGML {
913912
struct ggml_tensor* denoised = ggml_dup_tensor(work_ctx, x);
914913

915914
std::vector<float> apg_momentum_buffer;
916-
if (apg_params.momentum != 0)
915+
if (guidance.apg.momentum != 0)
917916
apg_momentum_buffer.resize((size_t)ggml_nelements(denoised));
918917

919918
auto denoise = [&](ggml_tensor* input, float sigma, int step) -> ggml_tensor* {
@@ -1096,14 +1095,14 @@ class StableDiffusionGGML {
10961095
// classic CFG (img_cfg_scale == cfg_scale != 1)
10971096
delta = positive_data[i] - negative_data[i];
10981097
}
1099-
if (apg_params.momentum != 0) {
1100-
delta += apg_params.momentum * apg_momentum_buffer[i];
1098+
if (guidance.apg.momentum != 0) {
1099+
delta += guidance.apg.momentum * apg_momentum_buffer[i];
11011100
apg_momentum_buffer[i] = delta;
11021101
}
1103-
if (apg_params.norm_treshold > 0 || log_cfg_norm) {
1102+
if (guidance.apg.norm_treshold > 0 || log_cfg_norm) {
11041103
diff_norm += delta * delta;
11051104
}
1106-
if (apg_params.eta != 1.0f) {
1105+
if (guidance.apg.eta != 1.0f) {
11071106
cond_norm_sq += positive_data[i] * positive_data[i];
11081107
dot += positive_data[i] * delta;
11091108
}
@@ -1112,30 +1111,30 @@ class StableDiffusionGGML {
11121111
if (log_cfg_norm) {
11131112
LOG_INFO("CFG Delta norm: %.2f", sqrtf(diff_norm));
11141113
}
1115-
if (apg_params.norm_treshold > 0) {
1114+
if (guidance.apg.norm_treshold > 0) {
11161115
diff_norm = sqrtf(diff_norm);
1117-
if (apg_params.norm_treshold_smoothing <= 0) {
1118-
apg_scale_factor = std::min(1.0f, apg_params.norm_treshold / diff_norm);
1116+
if (guidance.apg.norm_treshold_smoothing <= 0) {
1117+
apg_scale_factor = std::min(1.0f, guidance.apg.norm_treshold / diff_norm);
11191118
} else {
11201119
// Experimental: smooth saturate
1121-
float x = apg_params.norm_treshold / diff_norm;
1122-
apg_scale_factor = x / std::pow(1 + std::pow(x, 1.0 / apg_params.norm_treshold_smoothing), apg_params.norm_treshold_smoothing);
1120+
float x = guidance.apg.norm_treshold / diff_norm;
1121+
apg_scale_factor = x / std::pow(1 + std::pow(x, 1.0 / guidance.apg.norm_treshold_smoothing), guidance.apg.norm_treshold_smoothing);
11231122
}
11241123
}
1125-
if (apg_params.eta != 1.0f) {
1124+
if (guidance.apg.eta != 1.0f) {
11261125
dot *= apg_scale_factor;
11271126
// pre-normalize (avoids one square root and ne_elements extra divs)
11281127
dot /= cond_norm_sq;
11291128
}
11301129

11311130
for (int i = 0; i < ne_elements; i++) {
11321131
deltas[i] *= apg_scale_factor;
1133-
if (apg_params.eta != 1.0f) {
1132+
if (guidance.apg.eta != 1.0f) {
11341133
float apg_parallel = dot * positive_data[i];
11351134
float apg_orthogonal = deltas[i] - apg_parallel;
11361135

11371136
// tweak deltas
1138-
deltas[i] = apg_orthogonal + apg_params.eta * apg_parallel;
1137+
deltas[i] = apg_orthogonal + guidance.apg.eta * apg_parallel;
11391138
}
11401139
}
11411140
}
@@ -1636,7 +1635,6 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
16361635
std::string input_id_images_path,
16371636
std::vector<ggml_tensor*> ref_latents,
16381637
ggml_tensor* concat_latent = NULL,
1639-
sd_apg_params_t apg_params = {},
16401638
ggml_tensor* denoise_mask = NULL) {
16411639
if (seed < 0) {
16421640
// Generally, when using the provided command line, the seed is always >0.
@@ -1906,7 +1904,6 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
19061904
start_merge_step,
19071905
id_cond,
19081906
ref_latents,
1909-
apg_params,
19101907
denoise_mask);
19111908

19121909
// struct ggml_tensor* x_0 = load_tensor_from_file(ctx, "samples_ddim.bin");
@@ -1981,7 +1978,7 @@ ggml_tensor* generate_init_latent(sd_ctx_t* sd_ctx,
19811978
return init_latent;
19821979
}
19831980

1984-
sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params, sd_apg_params_t apg_params) {
1981+
sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params) {
19851982
int width = sd_img_gen_params->width;
19861983
int height = sd_img_gen_params->height;
19871984
LOG_DEBUG("generate_image %dx%d", width, height);
@@ -2181,7 +2178,6 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
21812178
sd_img_gen_params->input_id_images_path,
21822179
ref_latents,
21832180
concat_latent,
2184-
apg_params,
21852181
denoise_mask);
21862182

21872183
size_t t2 = ggml_time_ms();

stable-diffusion.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@ typedef struct {
168168
float min_cfg;
169169
float distilled_guidance;
170170
sd_slg_params_t slg;
171+
sd_apg_params_t apg;
171172
} sd_guidance_params_t;
172173

173174
typedef struct {
@@ -236,7 +237,7 @@ SD_API void free_sd_ctx(sd_ctx_t* sd_ctx);
236237

237238
SD_API void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params);
238239
SD_API char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params);
239-
SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params, sd_apg_params_t apg_params);
240+
SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params);
240241

241242
SD_API void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params);
242243
SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* sd_vid_gen_params); // broken

0 commit comments

Comments
 (0)