Skip to content

Commit 1bef24d

Browse files
committed
Refactor preview to match the other callbacks
1 parent e62e307 commit 1bef24d

File tree

5 files changed

+59
-43
lines changed

5 files changed

+59
-43
lines changed

examples/cli/main.cpp

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -117,10 +117,10 @@ struct SDParams {
117117
bool chroma_use_t5_mask = false;
118118
int chroma_t5_mask_pad = 1;
119119

120-
sd_preview_policy_t preview_method = SD_PREVIEW_NONE;
121-
int preview_interval = 1;
122-
std::string preview_path = "preview.png";
123-
bool taesd_preview = false;
120+
sd_preview_t preview_method = SD_PREVIEW_NONE;
121+
int preview_interval = 1;
122+
std::string preview_path = "preview.png";
123+
bool taesd_preview = false;
124124
};
125125

126126
void print_params(SDParams params) {
@@ -595,7 +595,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
595595
preview);
596596
return -1;
597597
}
598-
params.preview_method = (sd_preview_policy_t)preview_method;
598+
params.preview_method = (sd_preview_t)preview_method;
599599
return 1;
600600
};
601601

@@ -796,6 +796,7 @@ int main(int argc, const char* argv[]) {
796796
}};
797797

798798
sd_set_log_callback(sd_log_cb, (void*)&params);
799+
sd_set_preview_callback((sd_preview_cb_t)step_callback, params.preview_method, params.preview_interval);
799800

800801
if (params.verbose) {
801802
print_params(params);
@@ -1018,7 +1019,7 @@ int main(int argc, const char* argv[]) {
10181019
params.input_id_images_path.c_str(),
10191020
};
10201021

1021-
results = generate_image(sd_ctx, &img_gen_params, params.preview_method, params.preview_interval,(step_callback_t)step_callback);
1022+
results = generate_image(sd_ctx, &img_gen_params);
10221023
expected_num_results = params.batch_count;
10231024
} else if (params.mode == VID_GEN) {
10241025
sd_vid_gen_params_t vid_gen_params = {
@@ -1036,7 +1037,7 @@ int main(int argc, const char* argv[]) {
10361037
params.augmentation_level,
10371038
};
10381039

1039-
results = generate_video(sd_ctx, &vid_gen_params, (step_callback_t)step_callback);
1040+
results = generate_video(sd_ctx, &vid_gen_params);
10401041
expected_num_results = params.video_frames;
10411042
}
10421043

stable-diffusion.cpp

Lines changed: 16 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -858,7 +858,7 @@ class StableDiffusionGGML {
858858
int step,
859859
struct ggml_tensor* latents,
860860
enum SDVersion version,
861-
sd_preview_policy_t preview_mode,
861+
sd_preview_t preview_mode,
862862
ggml_tensor* result,
863863
std::function<void(int, sd_image_t)> step_callback) {
864864
const uint32_t channel = 3;
@@ -969,10 +969,7 @@ class StableDiffusionGGML {
969969
int start_merge_step,
970970
SDCondition id_cond,
971971
std::vector<ggml_tensor*> ref_latents = {},
972-
ggml_tensor* denoise_mask = nullptr,
973-
sd_preview_policy_t preview_mode = SD_PREVIEW_NONE,
974-
int preview_interval = 1,
975-
std::function<void(int, sd_image_t)> step_callback = nullptr) {
972+
ggml_tensor* denoise_mask = nullptr) {
976973
std::vector<int> skip_layers(guidance.slg.layers, guidance.slg.layers + guidance.slg.layer_count);
977974

978975
float cfg_scale = guidance.txt_cfg;
@@ -1034,7 +1031,8 @@ class StableDiffusionGGML {
10341031
struct ggml_tensor* denoised = ggml_dup_tensor(work_ctx, x);
10351032

10361033
struct ggml_tensor* preview_tensor = NULL;
1037-
if (preview_mode != SD_PREVIEW_NONE && preview_mode != SD_PREVIEW_PROJ) {
1034+
auto sd_preview_mode = sd_get_preview_mode();
1035+
if (sd_preview_mode != SD_PREVIEW_NONE && sd_preview_mode != SD_PREVIEW_PROJ) {
10381036
preview_tensor = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32,
10391037
(denoised->ne[0] * 8),
10401038
(denoised->ne[1] * 8),
@@ -1216,10 +1214,11 @@ class StableDiffusionGGML {
12161214
pretty_progress(step, (int)steps, (t1 - t0) / 1000000.f);
12171215
// LOG_INFO("step %d sampling completed taking %.2fs", step, (t1 - t0) * 1.0f / 1000000);
12181216
}
1219-
1220-
if (step_callback != nullptr) {
1221-
if (step % preview_interval == 0) {
1222-
preview_image(work_ctx, step, denoised, version, preview_mode, preview_tensor, step_callback);
1217+
auto sd_preview_cb = sd_get_preview_callback();
1218+
auto sd_preview_mode = sd_get_preview_mode();
1219+
if (sd_preview_cb != NULL) {
1220+
if (step % sd_get_preview_interval() == 0) {
1221+
preview_image(work_ctx, step, denoised, version, sd_preview_mode, preview_tensor, sd_preview_cb);
12231222
}
12241223
}
12251224
return denoised;
@@ -1671,10 +1670,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
16711670
std::string input_id_images_path,
16721671
std::vector<ggml_tensor*> ref_latents,
16731672
ggml_tensor* concat_latent = NULL,
1674-
ggml_tensor* denoise_mask = NULL,
1675-
sd_preview_policy_t preview_mode = SD_PREVIEW_NONE,
1676-
int preview_interval = 1,
1677-
std::function<void(int, sd_image_t)> step_callback = nullptr) {
1673+
ggml_tensor* denoise_mask = NULL) {
16781674
if (seed < 0) {
16791675
// Generally, when using the provided command line, the seed is always >0.
16801676
// However, to prevent potential issues if 'stable-diffusion.cpp' is invoked as a library
@@ -1943,10 +1939,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
19431939
start_merge_step,
19441940
id_cond,
19451941
ref_latents,
1946-
denoise_mask,
1947-
preview_mode,
1948-
preview_interval,
1949-
step_callback);
1942+
denoise_mask);
19501943

19511944
// struct ggml_tensor* x_0 = load_tensor_from_file(ctx, "samples_ddim.bin");
19521945
// print_ggml_tensor(x_0);
@@ -2020,7 +2013,7 @@ ggml_tensor* generate_init_latent(sd_ctx_t* sd_ctx,
20202013
return init_latent;
20212014
}
20222015

2023-
sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params, sd_preview_policy_t preview_mode, int preview_interval, step_callback_t step_callback) {
2016+
sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params) {
20242017
int width = sd_img_gen_params->width;
20252018
int height = sd_img_gen_params->height;
20262019
LOG_DEBUG("generate_image %dx%d", width, height);
@@ -2039,7 +2032,8 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
20392032
if (sd_ctx->sd->stacked_id) {
20402033
params.mem_size += static_cast<size_t>(10 * 1024 * 1024); // 10 MB
20412034
}
2042-
if (preview_mode != SD_PREVIEW_NONE && preview_mode != SD_PREVIEW_PROJ) {
2035+
auto sd_preview_mode = sd_get_preview_mode();
2036+
if (sd_preview_mode != SD_PREVIEW_NONE && sd_preview_mode != SD_PREVIEW_PROJ) {
20432037
params.mem_size *= 2;
20442038
}
20452039
params.mem_size += width * height * 3 * sizeof(float) * 3;
@@ -2223,10 +2217,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
22232217
sd_img_gen_params->input_id_images_path,
22242218
ref_latents,
22252219
concat_latent,
2226-
denoise_mask,
2227-
preview_mode,
2228-
preview_interval,
2229-
step_callback);
2220+
denoise_mask);
22302221

22312222
size_t t2 = ggml_time_ms();
22322223

@@ -2235,7 +2226,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
22352226
return result_images;
22362227
}
22372228

2238-
SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* sd_vid_gen_params, step_callback_t step_callback) {
2229+
SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* sd_vid_gen_params) {
22392230
if (sd_ctx == NULL || sd_vid_gen_params == NULL) {
22402231
return NULL;
22412232
}
@@ -2319,8 +2310,6 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
23192310
-1,
23202311
SDCondition(NULL, NULL, NULL),
23212312
{},
2322-
NULL,
2323-
(sd_preview_policy_t)0, 1,
23242313
NULL);
23252314

23262315
int64_t t2 = ggml_time_ms();

stable-diffusion.h

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ enum sd_log_level_t {
111111
SD_LOG_ERROR
112112
};
113113

114-
enum sd_preview_policy_t {
114+
enum sd_preview_t {
115115
SD_PREVIEW_NONE,
116116
SD_PREVIEW_PROJ,
117117
SD_PREVIEW_TAE,
@@ -214,11 +214,11 @@ typedef struct sd_ctx_t sd_ctx_t;
214214

215215
typedef void (*sd_log_cb_t)(enum sd_log_level_t level, const char* text, void* data);
216216
typedef void (*sd_progress_cb_t)(int step, int steps, float time, void* data);
217+
typedef void (*sd_preview_cb_t)(int, sd_image_t);
217218

218219
SD_API void sd_set_log_callback(sd_log_cb_t sd_log_cb, void* data);
219220
SD_API void sd_set_progress_callback(sd_progress_cb_t cb, void* data);
220-
SD_API sd_progress_cb_t sd_get_progress_callback();
221-
SD_API void* sd_get_progress_callback_data();
221+
SD_API void sd_set_preview_callback(sd_preview_cb_t cb, sd_preview_t mode, int interval);
222222
SD_API int32_t get_num_physical_cores();
223223
SD_API const char* sd_get_system_info();
224224

@@ -237,14 +237,12 @@ SD_API char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params);
237237
SD_API sd_ctx_t* new_sd_ctx(const sd_ctx_params_t* sd_ctx_params);
238238
SD_API void free_sd_ctx(sd_ctx_t* sd_ctx);
239239

240-
typedef void (*step_callback_t)(int, sd_image_t);
241-
242240
SD_API void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params);
243241
SD_API char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params);
244-
SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params, sd_preview_policy_t preview_mode, int preview_interval, step_callback_t step_callback);
242+
SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params);
245243

246244
SD_API void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params);
247-
SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* sd_vid_gen_params, step_callback_t step_callback); // broken
245+
SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* sd_vid_gen_params); // broken
248246

249247
typedef struct upscaler_ctx_t upscaler_ctx_t;
250248

util.cpp

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,10 @@ int32_t get_num_physical_cores() {
247247
static sd_progress_cb_t sd_progress_cb = NULL;
248248
void* sd_progress_cb_data = NULL;
249249

250+
static sd_preview_cb_t sd_preview_cb = NULL;
251+
sd_preview_t sd_preview_mode = SD_PREVIEW_NONE;
252+
int sd_preview_interval = 1;
253+
250254
std::u32string utf8_to_utf32(const std::string& utf8_str) {
251255
std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> converter;
252256
return converter.from_bytes(utf8_str);
@@ -420,10 +424,27 @@ void sd_set_progress_callback(sd_progress_cb_t cb, void* data) {
420424
sd_progress_cb = cb;
421425
sd_progress_cb_data = data;
422426
}
423-
sd_progress_cb_t sd_get_progress_callback(){
427+
/**
 * Registers the global preview callback together with the preview mode and
 * the step interval used when previews are produced.
 *
 * @param cb       Callback invoked with (int, sd_image_t); passing NULL means
 *                 no preview callback is invoked during sampling.
 * @param mode     Preview mode later reported by sd_get_preview_mode().
 * @param interval Step interval later reported by sd_get_preview_interval().
 *                 Values below 1 are clamped to 1: the sampler evaluates
 *                 `step % interval`, which is undefined behaviour for 0.
 */
void sd_set_preview_callback(sd_preview_cb_t cb, sd_preview_t mode = SD_PREVIEW_PROJ, int interval = 1) {
    sd_preview_cb   = cb;
    sd_preview_mode = mode;
    // Never store a non-positive interval; readers use it as a modulo divisor.
    sd_preview_interval = interval < 1 ? 1 : interval;
}
432+
433+
// Returns the preview callback most recently stored by
// sd_set_preview_callback(), or NULL if none has been stored.
sd_preview_cb_t sd_get_preview_callback() {
    sd_preview_cb_t registered_cb = sd_preview_cb;
    return registered_cb;
}
436+
437+
// Returns the preview mode most recently stored by sd_set_preview_callback()
// (SD_PREVIEW_NONE until it has been called).
sd_preview_t sd_get_preview_mode() {
    sd_preview_t current_mode = sd_preview_mode;
    return current_mode;
}
440+
int sd_get_preview_interval() {
441+
return sd_preview_interval;
442+
}
443+
444+
sd_progress_cb_t sd_get_progress_callback() {
424445
return sd_progress_cb;
425446
}
426-
void* sd_get_progress_callback_data(){
447+
void* sd_get_progress_callback_data() {
427448
return sd_progress_cb_data;
428449
}
429450
const char* sd_get_system_info() {

util.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,13 @@ std::string trim(const std::string& s);
5757

5858
std::vector<std::pair<std::string, float>> parse_prompt_attention(const std::string& text);
5959

60+
sd_progress_cb_t sd_get_progress_callback();
61+
void* sd_get_progress_callback_data();
62+
63+
sd_preview_cb_t sd_get_preview_callback();
64+
sd_preview_t sd_get_preview_mode();
65+
int sd_get_preview_interval();
66+
6067
#define LOG_DEBUG(format, ...) log_printf(SD_LOG_DEBUG, __FILE__, __LINE__, format, ##__VA_ARGS__)
6168
#define LOG_INFO(format, ...) log_printf(SD_LOG_INFO, __FILE__, __LINE__, format, ##__VA_ARGS__)
6269
#define LOG_WARN(format, ...) log_printf(SD_LOG_WARN, __FILE__, __LINE__, format, ##__VA_ARGS__)

0 commit comments

Comments
 (0)