
Commit 2b6d9b1

Fix Flex 2 inpaint
1 parent 8d5d16a

File tree

2 files changed: +21 -10 lines changed

ggml_extend.hpp

Lines changed: 8 additions & 4 deletions
@@ -380,7 +380,8 @@ __STATIC_INLINE__ void sd_mask_to_tensor(const uint8_t* image_data,
 
 __STATIC_INLINE__ void sd_apply_mask(struct ggml_tensor* image_data,
                                      struct ggml_tensor* mask,
-                                     struct ggml_tensor* output) {
+                                     struct ggml_tensor* output,
+                                     float masked_value = 0.5f) {
     int64_t width = output->ne[0];
     int64_t height = output->ne[1];
     int64_t channels = output->ne[2];
@@ -389,11 +390,14 @@ __STATIC_INLINE__ void sd_apply_mask(struct ggml_tensor* image_data,
     GGML_ASSERT(output->type == GGML_TYPE_F32);
     for (int ix = 0; ix < width; ix++) {
         for (int iy = 0; iy < height; iy++) {
-            float m = ggml_tensor_get_f32(mask, ix, iy);
+            int mx = (int)(ix * rescale_mx);
+            int my = (int)(iy * rescale_my);
+            float m = ggml_tensor_get_f32(mask, mx, my);
             m = round(m);  // inpaint models need binary masks
-            ggml_tensor_set_f32(mask, m, ix, iy);
+            ggml_tensor_set_f32(mask, m, mx, my);
             for (int k = 0; k < channels; k++) {
-                float value = (1 - m) * (ggml_tensor_get_f32(image_data, ix, iy, k) - .5) + .5;
+                float value = ggml_tensor_get_f32(image_data, ix, iy, k);
+                value = (1 - m) * (value - masked_value) + masked_value;
                 ggml_tensor_set_f32(output, value, ix, iy, k);
             }
         }
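
For reference, the rewritten helper now samples the mask at rescaled coordinates (rescale_mx and rescale_my are defined outside the shown context), so a pixel-space mask can be applied to an output of a different resolution, and it blends masked elements toward a configurable masked_value instead of the previously hard-coded 0.5. Per element the rule is value' = (1 - m) * (value - masked_value) + masked_value: where the binarized mask m is 1 the output becomes masked_value, where it is 0 the input passes through unchanged. A minimal standalone check of that formula on plain floats (not the repository's ggml tensors):

// Minimal standalone check of the blend used by sd_apply_mask after this commit.
#include <cmath>
#include <cstdio>

// value' = (1 - m) * (value - masked_value) + masked_value
static float blend(float value, float m, float masked_value) {
    m = std::round(m);  // inpaint models need binary masks
    return (1 - m) * (value - masked_value) + masked_value;
}

int main() {
    // Unmasked (m = 0): the input passes through unchanged.
    std::printf("%.2f\n", blend(0.8f, 0.0f, 0.5f));  // 0.80
    // Masked (m = 1) with the default 0.5: mid-gray for [0, 1] pixel data.
    std::printf("%.2f\n", blend(0.8f, 1.0f, 0.5f));  // 0.50
    // Masked (m = 1) with masked_value = 0: zeroed, as used for Flex 2 latents.
    std::printf("%.2f\n", blend(0.8f, 1.0f, 0.0f));  // 0.00
    return 0;
}

With masked_value left at its default of 0.5f, existing pixel-space callers keep the old behaviour (masked pixels become mid-gray for data in [0, 1]); the Flex 2 path in stable-diffusion.cpp below passes 0 so that masked latent positions are zeroed.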

stable-diffusion.cpp

Lines changed: 13 additions & 6 deletions
@@ -2006,14 +2006,21 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
             } else if (sd_ctx->sd->version == VERSION_FLEX_2) {
                 mask_channels = 1 + init_latent->ne[2];
             }
-            ggml_tensor* masked_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
-            sd_apply_mask(init_img, mask_img, masked_img);
             ggml_tensor* masked_latent = NULL;
-            if (!sd_ctx->sd->use_tiny_autoencoder) {
-                ggml_tensor* moments = sd_ctx->sd->encode_first_stage(work_ctx, masked_img);
-                masked_latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, moments);
+            if (sd_ctx->sd->version != VERSION_FLEX_2) {
+                // most inpaint models mask before vae
+                ggml_tensor* masked_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
+                sd_apply_mask(init_img, mask_img, masked_img);
+                if (!sd_ctx->sd->use_tiny_autoencoder) {
+                    ggml_tensor* moments = sd_ctx->sd->encode_first_stage(work_ctx, masked_img);
+                    masked_latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, moments);
+                } else {
+                    masked_latent = sd_ctx->sd->encode_first_stage(work_ctx, masked_img);
+                }
             } else {
-                masked_latent = sd_ctx->sd->encode_first_stage(work_ctx, masked_img);
+                // mask after vae
+                masked_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], init_latent->ne[2], 1);
+                sd_apply_mask(init_latent, mask_img, masked_latent, 0.);
             }
             concat_latent = ggml_new_tensor_4d(work_ctx,
                                                GGML_TYPE_F32,
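
In generate_image, the Flex 2 branch no longer masks the RGB image before encoding; it applies the mask directly to the already-encoded init_latent with masked_value = 0, while every other inpaint model keeps the old mask-then-encode order (through the full VAE or the tiny autoencoder). The sketch below illustrates that decision only; Tensor, fake_encode, mask_toward and build_masked_latent are hypothetical stand-ins for the repository's ggml tensors, sd_ctx->sd->encode_first_stage() and sd_apply_mask():

// Hypothetical stand-ins; the real code works on ggml tensors via
// sd_ctx->sd->encode_first_stage() and sd_apply_mask(). This only illustrates
// the branch and the masked_value choice introduced by this commit.
#include <cmath>
#include <vector>

struct Tensor { std::vector<float> data; int w = 0, h = 0, c = 0; };

// Placeholder VAE: the real encoder downsamples to latent resolution and changes
// the channel count; an identity mapping is enough for this sketch.
static Tensor fake_encode(const Tensor& img) { return img; }

// Same blend as in sd_apply_mask: push values toward masked_value where mask == 1.
// The mask is assumed to match the spatial size of `t`; the real helper samples a
// pixel-space mask at rescaled coordinates (see the ggml_extend.hpp hunk above).
static void mask_toward(Tensor& t, const std::vector<float>& mask, float masked_value) {
    for (int iy = 0; iy < t.h; iy++)
        for (int ix = 0; ix < t.w; ix++) {
            float m = std::round(mask[(size_t)iy * t.w + ix]);
            for (int k = 0; k < t.c; k++) {
                size_t idx = ((size_t)iy * t.w + ix) * (size_t)t.c + k;
                t.data[idx] = (1 - m) * (t.data[idx] - masked_value) + masked_value;
            }
        }
}

static Tensor build_masked_latent(bool is_flex_2, const Tensor& init_img,
                                  const std::vector<float>& mask) {
    if (!is_flex_2) {
        // Most inpaint models: mask the RGB image first, then encode it.
        Tensor masked_img = init_img;
        mask_toward(masked_img, mask, /*masked_value=*/0.5f);
        return fake_encode(masked_img);
    }
    // Flex 2: mask the encoded latent directly (the real code reuses init_latent),
    // zeroing the masked positions.
    Tensor latent = fake_encode(init_img);
    mask_toward(latent, mask, /*masked_value=*/0.0f);
    return latent;
}

A plausible reading of the two defaults: 0.5 is the neutral mid-gray for pixel data in [0, 1], while 0 is the neutral value for roughly zero-centered VAE latents, so each path pushes masked regions toward the neutral value of its own domain.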
