Skip to content

Commit f03d84c

Browse files
committed
Refactor preview to match the other callbacks
1 parent a976e74 commit f03d84c

File tree

5 files changed

+84
-88
lines changed

5 files changed

+84
-88
lines changed

examples/cli/main.cpp

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -146,10 +146,10 @@ struct SDParams {
146146
bool chroma_use_t5_mask = false;
147147
int chroma_t5_mask_pad = 1;
148148

149-
sd_preview_policy_t preview_method = SD_PREVIEW_NONE;
150-
int preview_interval = 1;
151-
std::string preview_path = "preview.png";
152-
bool taesd_preview = false;
149+
sd_preview_t preview_method = SD_PREVIEW_NONE;
150+
int preview_interval = 1;
151+
std::string preview_path = "preview.png";
152+
bool taesd_preview = false;
153153
};
154154

155155
void print_params(SDParams params) {
@@ -713,7 +713,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
713713
invalid_arg = true;
714714
break;
715715
}
716-
params.preview_method = (sd_preview_policy_t)preview_method;
716+
params.preview_method = (sd_preview_t)preview_method;
717717
} else if (arg == "--preview-interval") {
718718
if (++i >= argc) {
719719
invalid_arg = true;
@@ -907,6 +907,7 @@ int main(int argc, const char* argv[]) {
907907
preview_path = params.preview_path.c_str();
908908

909909
sd_set_log_callback(sd_log_cb, (void*)&params);
910+
sd_set_preview_callback((sd_preview_cb_t)step_callback, params.preview_method, params.preview_interval);
910911

911912
if (params.verbose) {
912913
print_params(params);
@@ -1117,10 +1118,7 @@ int main(int argc, const char* argv[]) {
11171118
params.skip_layers.size(),
11181119
params.slg_scale,
11191120
params.skip_layer_start,
1120-
params.skip_layer_end,
1121-
params.preview_method,
1122-
params.preview_interval,
1123-
(step_callback_t)step_callback);
1121+
params.skip_layer_end);
11241122
} else if (params.mode == IMG2IMG || params.mode == IMG2VID) {
11251123
sd_image_t input_image = {(uint32_t)params.width,
11261124
(uint32_t)params.height,
@@ -1189,10 +1187,7 @@ int main(int argc, const char* argv[]) {
11891187
params.skip_layers.size(),
11901188
params.slg_scale,
11911189
params.skip_layer_start,
1192-
params.skip_layer_end,
1193-
params.preview_method,
1194-
params.preview_interval,
1195-
(step_callback_t)step_callback);
1190+
params.skip_layer_end);
11961191
}
11971192
} else { // EDIT
11981193
results = edit(sd_ctx,

stable-diffusion.cpp

Lines changed: 33 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -845,7 +845,7 @@ class StableDiffusionGGML {
845845
int step,
846846
struct ggml_tensor* latents,
847847
enum SDVersion version,
848-
sd_preview_policy_t preview_mode,
848+
sd_preview_t preview_mode,
849849
ggml_tensor* result,
850850
std::function<void(int, sd_image_t)> step_callback) {
851851
const uint32_t channel = 3;
@@ -958,14 +958,11 @@ class StableDiffusionGGML {
958958
int start_merge_step,
959959
SDCondition id_cond,
960960
std::vector<ggml_tensor*> ref_latents = {},
961-
std::vector<int> skip_layers = {},
962-
float slg_scale = 0,
963-
float skip_layer_start = 0.01,
964-
float skip_layer_end = 0.2,
965-
ggml_tensor* noise_mask = nullptr,
966-
sd_preview_policy_t preview_mode = SD_PREVIEW_NONE,
967-
int preview_interval = 1,
968-
std::function<void(int, sd_image_t)> step_callback = nullptr) {
961+
std::vector<int> skip_layers = {},
962+
float slg_scale = 0,
963+
float skip_layer_start = 0.01,
964+
float skip_layer_end = 0.2,
965+
ggml_tensor* noise_mask = nullptr) {
969966
size_t steps = sigmas.size() - 1;
970967
// noise = load_tensor_from_file(work_ctx, "./rand0.bin");
971968
// print_ggml_tensor(noise);
@@ -997,7 +994,8 @@ class StableDiffusionGGML {
997994
struct ggml_tensor* denoised = ggml_dup_tensor(work_ctx, x);
998995

999996
struct ggml_tensor* preview_tensor = NULL;
1000-
if (preview_mode != SD_PREVIEW_NONE && preview_mode != SD_PREVIEW_PROJ) {
997+
auto sd_preview_mode = sd_get_preview_mode();
998+
if (sd_preview_mode != SD_PREVIEW_NONE && sd_preview_mode != SD_PREVIEW_PROJ) {
1001999
preview_tensor = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32,
10021000
(denoised->ne[0] * 8),
10031001
(denoised->ne[1] * 8),
@@ -1149,10 +1147,11 @@ class StableDiffusionGGML {
11491147
pretty_progress(step, (int)steps, (t1 - t0) / 1000000.f);
11501148
// LOG_INFO("step %d sampling completed taking %.2fs", step, (t1 - t0) * 1.0f / 1000000);
11511149
}
1152-
1153-
if (step_callback != nullptr) {
1154-
if (step % preview_interval == 0) {
1155-
preview_image(work_ctx, step, denoised, version, preview_mode, preview_tensor, step_callback);
1150+
auto sd_preview_cb = sd_get_preview_callback();
1151+
auto sd_preview_mode = sd_get_preview_mode();
1152+
if (sd_preview_cb != NULL) {
1153+
if (step % sd_get_preview_interval() == 0) {
1154+
preview_image(work_ctx, step, denoised, version, sd_preview_mode, preview_tensor, sd_preview_cb);
11561155
}
11571156
}
11581157
return denoised;
@@ -1385,14 +1384,11 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
13851384
bool normalize_input,
13861385
std::string input_id_images_path,
13871386
std::vector<ggml_tensor*> ref_latents,
1388-
std::vector<int> skip_layers = {},
1389-
float slg_scale = 0,
1390-
float skip_layer_start = 0.01,
1391-
float skip_layer_end = 0.2,
1392-
ggml_tensor* masked_image = NULL,
1393-
sd_preview_policy_t preview_mode = SD_PREVIEW_NONE,
1394-
int preview_interval = 1,
1395-
std::function<void(int, sd_image_t)> step_callback = nullptr) {
1387+
std::vector<int> skip_layers = {},
1388+
float slg_scale = 0,
1389+
float skip_layer_start = 0.01,
1390+
float skip_layer_end = 0.2,
1391+
ggml_tensor* masked_image = NULL) {
13961392
if (seed < 0) {
13971393
// Generally, when using the provided command line, the seed is always >0.
13981394
// However, to prevent potential issues if 'stable-diffusion.cpp' is invoked as a library
@@ -1650,10 +1646,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
16501646
slg_scale,
16511647
skip_layer_start,
16521648
skip_layer_end,
1653-
noise_mask,
1654-
preview_mode,
1655-
preview_interval,
1656-
step_callback);
1649+
noise_mask);
16571650

16581651
// struct ggml_tensor* x_0 = load_tensor_from_file(ctx, "samples_ddim.bin");
16591652
// print_ggml_tensor(x_0);
@@ -1745,14 +1738,11 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
17451738
float style_ratio,
17461739
bool normalize_input,
17471740
const char* input_id_images_path_c_str,
1748-
int* skip_layers = NULL,
1749-
size_t skip_layers_count = 0,
1750-
float slg_scale = 0,
1751-
float skip_layer_start = 0.01,
1752-
float skip_layer_end = 0.2,
1753-
sd_preview_policy_t preview_mode = SD_PREVIEW_NONE,
1754-
int preview_interval = 1,
1755-
step_callback_t step_callback = NULL) {
1741+
int* skip_layers = NULL,
1742+
size_t skip_layers_count = 0,
1743+
float slg_scale = 0,
1744+
float skip_layer_start = 0.01,
1745+
float skip_layer_end = 0.2) {
17561746
std::vector<int> skip_layers_vec(skip_layers, skip_layers + skip_layers_count);
17571747
LOG_DEBUG("txt2img %dx%d", width, height);
17581748
if (sd_ctx == NULL) {
@@ -1770,7 +1760,8 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
17701760
if (sd_ctx->sd->stacked_id) {
17711761
params.mem_size += static_cast<size_t>(10 * 1024 * 1024); // 10 MB
17721762
}
1773-
if (preview_mode != SD_PREVIEW_NONE && preview_mode != SD_PREVIEW_PROJ) {
1763+
auto sd_preview_mode = sd_get_preview_mode();
1764+
if (sd_preview_mode != SD_PREVIEW_NONE && sd_preview_mode != SD_PREVIEW_PROJ) {
17741765
params.mem_size *= 2;
17751766
}
17761767
params.mem_size += width * height * 3 * sizeof(float);
@@ -1820,10 +1811,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
18201811
slg_scale,
18211812
skip_layer_start,
18221813
skip_layer_end,
1823-
NULL,
1824-
preview_mode,
1825-
preview_interval,
1826-
step_callback);
1814+
NULL);
18271815

18281816
size_t t1 = ggml_time_ms();
18291817

@@ -1853,14 +1841,11 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
18531841
float style_ratio,
18541842
bool normalize_input,
18551843
const char* input_id_images_path_c_str,
1856-
int* skip_layers = NULL,
1857-
size_t skip_layers_count = 0,
1858-
float slg_scale = 0,
1859-
float skip_layer_start = 0.01,
1860-
float skip_layer_end = 0.2,
1861-
sd_preview_policy_t preview_mode = SD_PREVIEW_NONE,
1862-
int preview_interval = 1,
1863-
step_callback_t step_callback = NULL) {
1844+
int* skip_layers = NULL,
1845+
size_t skip_layers_count = 0,
1846+
float slg_scale = 0,
1847+
float skip_layer_start = 0.01,
1848+
float skip_layer_end = 0.2) {
18641849
std::vector<int> skip_layers_vec(skip_layers, skip_layers + skip_layers_count);
18651850
LOG_DEBUG("img2img %dx%d", width, height);
18661851
if (sd_ctx == NULL) {
@@ -2008,10 +1993,7 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
20081993
slg_scale,
20091994
skip_layer_start,
20101995
skip_layer_end,
2011-
masked_image,
2012-
preview_mode,
2013-
preview_interval,
2014-
step_callback);
1996+
masked_image);
20151997

20161998
size_t t2 = ggml_time_ms();
20171999

@@ -2117,8 +2099,6 @@ SD_API sd_image_t* img2vid(sd_ctx_t* sd_ctx,
21172099
{},
21182100
{},
21192101
0, 0, 0,
2120-
NULL,
2121-
(sd_preview_policy_t)0, 1,
21222102
NULL);
21232103

21242104
int64_t t2 = ggml_time_ms();

stable-diffusion.h

Lines changed: 13 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -112,31 +112,32 @@ enum sd_log_level_t {
112112
SD_LOG_ERROR
113113
};
114114

115-
enum sd_preview_policy_t {
115+
enum sd_preview_t {
116116
SD_PREVIEW_NONE,
117117
SD_PREVIEW_PROJ,
118118
SD_PREVIEW_TAE,
119119
SD_PREVIEW_VAE,
120120
N_PREVIEWS
121121
};
122122

123+
typedef struct {
124+
uint32_t width;
125+
uint32_t height;
126+
uint32_t channel;
127+
uint8_t* data;
128+
} sd_image_t;
129+
123130
typedef void (*sd_log_cb_t)(enum sd_log_level_t level, const char* text, void* data);
124131
typedef void (*sd_progress_cb_t)(int step, int steps, float time, void* data);
132+
typedef void (*sd_preview_cb_t)(int, sd_image_t);
133+
125134

126135
SD_API void sd_set_log_callback(sd_log_cb_t sd_log_cb, void* data);
127136
SD_API void sd_set_progress_callback(sd_progress_cb_t cb, void* data);
128-
SD_API sd_progress_cb_t sd_get_progress_callback();
129-
SD_API void* sd_get_progress_callback_data();
137+
SD_API void sd_set_preview_callback(sd_preview_cb_t cb, sd_preview_t mode, int interval);
130138
SD_API int32_t get_num_physical_cores();
131139
SD_API const char* sd_get_system_info();
132140

133-
typedef struct {
134-
uint32_t width;
135-
uint32_t height;
136-
uint32_t channel;
137-
uint8_t* data;
138-
} sd_image_t;
139-
140141
typedef struct sd_ctx_t sd_ctx_t;
141142

142143
SD_API sd_ctx_t* new_sd_ctx(const char* model_path,
@@ -168,8 +169,6 @@ SD_API sd_ctx_t* new_sd_ctx(const char* model_path,
168169

169170
SD_API void free_sd_ctx(sd_ctx_t* sd_ctx);
170171

171-
typedef void (*step_callback_t)(int, sd_image_t);
172-
173172
SD_API sd_image_t* txt2img(sd_ctx_t* sd_ctx,
174173
const char* prompt,
175174
const char* negative_prompt,
@@ -192,10 +191,7 @@ SD_API sd_image_t* txt2img(sd_ctx_t* sd_ctx,
192191
size_t skip_layers_count,
193192
float slg_scale,
194193
float skip_layer_start,
195-
float skip_layer_end,
196-
sd_preview_policy_t preview_mode,
197-
int preview_interval,
198-
step_callback_t step_callback);
194+
float skip_layer_end);
199195

200196
SD_API sd_image_t* img2img(sd_ctx_t* sd_ctx,
201197
sd_image_t init_image,
@@ -222,10 +218,7 @@ SD_API sd_image_t* img2img(sd_ctx_t* sd_ctx,
222218
size_t skip_layers_count,
223219
float slg_scale,
224220
float skip_layer_start,
225-
float skip_layer_end,
226-
sd_preview_policy_t preview_mode,
227-
int preview_interval,
228-
step_callback_t step_callback);
221+
float skip_layer_end);
229222

230223
SD_API sd_image_t* img2vid(sd_ctx_t* sd_ctx,
231224
sd_image_t init_image,

util.cpp

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,10 @@ int32_t get_num_physical_cores() {
247247
static sd_progress_cb_t sd_progress_cb = NULL;
248248
void* sd_progress_cb_data = NULL;
249249

250+
static sd_preview_cb_t sd_preview_cb = NULL;
251+
sd_preview_t sd_preview_mode = SD_PREVIEW_NONE;
252+
int sd_preview_interval = 1;
253+
250254
std::u32string utf8_to_utf32(const std::string& utf8_str) {
251255
std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> converter;
252256
return converter.from_bytes(utf8_str);
@@ -420,10 +424,27 @@ void sd_set_progress_callback(sd_progress_cb_t cb, void* data) {
420424
sd_progress_cb = cb;
421425
sd_progress_cb_data = data;
422426
}
423-
sd_progress_cb_t sd_get_progress_callback(){
427+
void sd_set_preview_callback(sd_preview_cb_t cb, sd_preview_t mode = SD_PREVIEW_PROJ, int interval = 1) {
428+
sd_preview_cb = cb;
429+
sd_preview_mode = mode;
430+
sd_preview_interval = interval;
431+
}
432+
433+
sd_preview_cb_t sd_get_preview_callback() {
434+
return sd_preview_cb;
435+
}
436+
437+
sd_preview_t sd_get_preview_mode() {
438+
return sd_preview_mode;
439+
}
440+
int sd_get_preview_interval() {
441+
return sd_preview_interval;
442+
}
443+
444+
sd_progress_cb_t sd_get_progress_callback() {
424445
return sd_progress_cb;
425446
}
426-
void* sd_get_progress_callback_data(){
447+
void* sd_get_progress_callback_data() {
427448
return sd_progress_cb_data;
428449
}
429450
const char* sd_get_system_info() {

util.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,13 @@ std::string trim(const std::string& s);
5454

5555
std::vector<std::pair<std::string, float>> parse_prompt_attention(const std::string& text);
5656

57+
sd_progress_cb_t sd_get_progress_callback();
58+
void* sd_get_progress_callback_data();
59+
60+
sd_preview_cb_t sd_get_preview_callback();
61+
sd_preview_t sd_get_preview_mode();
62+
int sd_get_preview_interval();
63+
5764
#define LOG_DEBUG(format, ...) log_printf(SD_LOG_DEBUG, __FILE__, __LINE__, format, ##__VA_ARGS__)
5865
#define LOG_INFO(format, ...) log_printf(SD_LOG_INFO, __FILE__, __LINE__, format, ##__VA_ARGS__)
5966
#define LOG_WARN(format, ...) log_printf(SD_LOG_WARN, __FILE__, __LINE__, format, ##__VA_ARGS__)

0 commit comments

Comments
 (0)