Skip to content

Commit 8d5d16a

Browse files
committed
support for flux controls
1 parent c457adf commit 8d5d16a

File tree

5 files changed

+72
-59
lines changed

5 files changed

+72
-59
lines changed

examples/cli/main.cpp

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -905,7 +905,8 @@ int main(int argc, const char* argv[]) {
905905
input_image_buffer};
906906

907907
sd_image_t* control_image = NULL;
908-
if (params.control_net_path.size() > 0 && params.control_image_path.size() > 0) {
908+
if (params.control_image_path.size() > 0) {
909+
printf("load image from '%s'\n", params.control_image_path.c_str());
909910
int c = 0;
910911
control_image_buffer = stbi_load(params.control_image_path.c_str(), &params.width, &params.height, &c, 3);
911912
if (control_image_buffer == NULL) {

flux.hpp

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1032,6 +1032,14 @@ namespace Flux {
10321032
control = patchify(ctx, control, patch_size);
10331033

10341034
img = ggml_concat(ctx, img, ggml_concat(ctx, ggml_concat(ctx, masked, mask, 0), control, 0), 0);
1035+
} else if (version == VERSION_FLUX_CONTROLS) {
1036+
GGML_ASSERT(c_concat != NULL);
1037+
1038+
ggml_tensor* control = ggml_pad(ctx, c_concat, pad_w, pad_h, 0, 0);
1039+
1040+
control = patchify(ctx, control, patch_size);
1041+
1042+
img = ggml_concat(ctx, img, control, 0);
10351043
}
10361044

10371045
if (ref_latents.size() > 0) {
@@ -1079,6 +1087,8 @@ namespace Flux {
10791087
flux_params.depth_single_blocks = 0;
10801088
if (version == VERSION_FLUX_FILL) {
10811089
flux_params.in_channels = 384;
1090+
} else if (version == VERSION_FLUX_CONTROLS) {
1091+
flux_params.in_channels = 128;
10821092
} else if (version == VERSION_FLEX_2) {
10831093
flux_params.in_channels = 196;
10841094
}

model.cpp

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1685,10 +1685,12 @@ SDVersion ModelLoader::get_sd_version() {
16851685
}
16861686

16871687
if (is_flux) {
1688-
is_inpaint = input_block_weight.ne[0] == 384;
1689-
if (is_inpaint) {
1688+
if (input_block_weight.ne[0] == 384) {
16901689
return VERSION_FLUX_FILL;
16911690
}
1691+
if (input_block_weight.ne[0] == 128) {
1692+
return VERSION_FLUX_CONTROLS;
1693+
}
16921694
if(input_block_weight.ne[0] == 196){
16931695
return VERSION_FLEX_2;
16941696
}

model.h

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -31,12 +31,13 @@ enum SDVersion {
3131
VERSION_SD3,
3232
VERSION_FLUX,
3333
VERSION_FLUX_FILL,
34+
VERSION_FLUX_CONTROLS,
3435
VERSION_FLEX_2,
3536
VERSION_COUNT,
3637
};
3738

3839
static inline bool sd_version_is_flux(SDVersion version) {
39-
if (version == VERSION_FLUX || version == VERSION_FLUX_FILL || version == VERSION_FLEX_2 ) {
40+
if (version == VERSION_FLUX || version == VERSION_FLUX_FILL || version == VERSION_FLUX_CONTROLS || version == VERSION_FLEX_2 ) {
4041
return true;
4142
}
4243
return false;
@@ -88,8 +89,12 @@ static inline bool sd_version_is_unet_edit(SDVersion version) {
8889
return version == VERSION_SD1_PIX2PIX || version == VERSION_SDXL_PIX2PIX;
8990
}
9091

92+
static inline bool sd_version_is_control(SDVersion version) {
93+
return version == VERSION_FLUX_CONTROLS || version == VERSION_FLEX_2;
94+
}
95+
9196
static bool sd_version_is_inpaint_or_unet_edit(SDVersion version) {
92-
return sd_version_is_unet_edit(version) || sd_version_is_inpaint(version);
97+
return sd_version_is_unet_edit(version) || sd_version_is_inpaint(version)|| sd_version_is_control(version);
9398
}
9499

95100
enum PMVersion {

stable-diffusion.cpp

Lines changed: 49 additions & 54 deletions
Original file line number | Diff line number | Diff line change
@@ -297,7 +297,7 @@ class StableDiffusionGGML {
297297
// TODO: shift_factor
298298
}
299299

300-
if(version == VERSION_FLEX_2){
300+
if (sd_version_is_control(version)) {
301301
// Might need vae encode for control cond
302302
vae_decode_only = false;
303303
}
@@ -1722,6 +1722,17 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
17221722
int W = width / 8;
17231723
int H = height / 8;
17241724
LOG_INFO("sampling using %s method", sampling_methods_str[sample_method]);
1725+
1726+
struct ggml_tensor* control_latent = NULL;
1727+
if (sd_version_is_control(sd_ctx->sd->version) && image_hint != NULL) {
1728+
if (!sd_ctx->sd->use_tiny_autoencoder) {
1729+
struct ggml_tensor* control_moments = sd_ctx->sd->encode_first_stage(work_ctx, image_hint);
1730+
control_latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, control_moments);
1731+
} else {
1732+
control_latent = sd_ctx->sd->encode_first_stage(work_ctx, image_hint);
1733+
}
1734+
}
1735+
17251736
if (sd_version_is_inpaint(sd_ctx->sd->version)) {
17261737
int64_t mask_channels = 1;
17271738
if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
@@ -1754,50 +1765,53 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
17541765
}
17551766
}
17561767
}
1757-
if (sd_ctx->sd->version == VERSION_FLEX_2 && image_hint != NULL && sd_ctx->sd->control_net == NULL) {
1768+
1769+
if (sd_ctx->sd->version == VERSION_FLEX_2 && control_latent != NULL && sd_ctx->sd->control_net == NULL) {
17581770
bool no_inpaint = concat_latent == NULL;
17591771
if (no_inpaint) {
17601772
concat_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], mask_channels + init_latent->ne[2], 1);
17611773
}
17621774
// fill in the control image here
1763-
struct ggml_tensor* control_latents = NULL;
1764-
if (!sd_ctx->sd->use_tiny_autoencoder) {
1765-
struct ggml_tensor* control_moments = sd_ctx->sd->encode_first_stage(work_ctx, image_hint);
1766-
control_latents = sd_ctx->sd->get_first_stage_encoding(work_ctx, control_moments);
1767-
} else {
1768-
control_latents = sd_ctx->sd->encode_first_stage(work_ctx, image_hint);
1769-
}
1770-
for (int64_t x = 0; x < concat_latent->ne[0]; x++) {
1771-
for (int64_t y = 0; y < concat_latent->ne[1]; y++) {
1775+
for (int64_t x = 0; x < control_latent->ne[0]; x++) {
1776+
for (int64_t y = 0; y < control_latent->ne[1]; y++) {
17721777
if (no_inpaint) {
1773-
for (int64_t c = 0; c < concat_latent->ne[2] - control_latents->ne[2]; c++) {
1778+
for (int64_t c = 0; c < concat_latent->ne[2] - control_latent->ne[2]; c++) {
17741779
// 0x16,1x1,0x16
17751780
ggml_tensor_set_f32(concat_latent, c == init_latent->ne[2], x, y, c);
17761781
}
17771782
}
1778-
for (int64_t c = 0; c < control_latents->ne[2]; c++) {
1779-
float v = ggml_tensor_get_f32(control_latents, x, y, c);
1780-
ggml_tensor_set_f32(concat_latent, v, x, y, concat_latent->ne[2] - control_latents->ne[2] + c);
1783+
for (int64_t c = 0; c < control_latent->ne[2]; c++) {
1784+
float v = ggml_tensor_get_f32(control_latent, x, y, c);
1785+
ggml_tensor_set_f32(concat_latent, v, x, y, concat_latent->ne[2] - control_latent->ne[2] + c);
17811786
}
17821787
}
17831788
}
1784-
// Disable controlnet
1785-
image_hint = NULL;
17861789
} else if (concat_latent == NULL) {
17871790
concat_latent = empty_latent;
17881791
}
17891792
cond.c_concat = concat_latent;
17901793
uncond.c_concat = empty_latent;
17911794
denoise_mask = NULL;
1792-
} else if (sd_version_is_unet_edit(sd_ctx->sd->version)) {
17931795
} else if (sd_version_is_unet_edit(sd_ctx->sd->version)) {
17941796
auto empty_latent = ggml_dup_tensor(work_ctx, init_latent);
17951797
ggml_set_f32(empty_latent, 0);
17961798
uncond.c_concat = empty_latent;
1797-
if (concat_latent == NULL) {
1798-
concat_latent = empty_latent;
1799+
cond.c_concat = ref_latents[0];
1800+
if (cond.c_concat == NULL) {
1801+
cond.c_concat = empty_latent;
1802+
}
1803+
} else if (sd_version_is_control(sd_ctx->sd->version)) {
1804+
LOG_DEBUG("HERE");
1805+
auto empty_latent = ggml_dup_tensor(work_ctx, init_latent);
1806+
ggml_set_f32(empty_latent, 0);
1807+
uncond.c_concat = empty_latent;
1808+
if (sd_version_is_control(sd_ctx->sd->version) && control_latent != NULL && sd_ctx->sd->control_net == NULL) {
1809+
cond.c_concat = control_latent;
17991810
}
1800-
cond.c_concat = ref_latents[0];
1811+
if (cond.c_concat == NULL) {
1812+
cond.c_concat = empty_latent;
1813+
}
1814+
LOG_DEBUG("HERE");
18011815
}
18021816
SDCondition img_cond;
18031817
if (uncond.c_crossattn != NULL &&
@@ -1956,6 +1970,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
19561970
size_t t0 = ggml_time_ms();
19571971

19581972
ggml_tensor* init_latent = NULL;
1973+
ggml_tensor* init_moments = NULL;
19591974
ggml_tensor* concat_latent = NULL;
19601975
ggml_tensor* denoise_mask = NULL;
19611976
std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sd_img_gen_params->sample_steps);
@@ -1978,8 +1993,8 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
19781993
sd_image_to_tensor(sd_img_gen_params->init_image.data, init_img);
19791994

19801995
if (!sd_ctx->sd->use_tiny_autoencoder) {
1981-
ggml_tensor* moments = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
1982-
init_latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, moments);
1996+
init_moments = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
1997+
init_latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, init_moments);
19831998
} else {
19841999
init_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
19852000
}
@@ -1988,8 +2003,8 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
19882003
int64_t mask_channels = 1;
19892004
if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
19902005
mask_channels = 8 * 8; // flatten the whole mask
1991-
} else if (sd_ctx->sd->version == VERSION_FLEX_2) {
1992-
mask_channels = 1 + init_latent->ne[2];
2006+
} else if (sd_ctx->sd->version == VERSION_FLEX_2) {
2007+
mask_channels = 1 + init_latent->ne[2];
19932008
}
19942009
ggml_tensor* masked_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
19952010
sd_apply_mask(init_img, mask_img, masked_img);
@@ -2024,38 +2039,18 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
20242039
ggml_tensor_set_f32(concat_latent, m, ix, iy, masked_latent->ne[2] + x * 8 + y);
20252040
}
20262041
}
2027-
} else if (sd_ctx->sd->version == VERSION_FLEX_2) {
2028-
float m = ggml_tensor_get_f32(mask_img, mx, my);
2029-
// masked image
2030-
for (int k = 0; k < masked_latent->ne[2]; k++) {
2031-
float v = ggml_tensor_get_f32(masked_latent, ix, iy, k);
2032-
ggml_tensor_set_f32(concat_latent, v, ix, iy, k);
2033-
}
2034-
// downsampled mask
2035-
ggml_tensor_set_f32(concat_latent, m, ix, iy, masked_latent->ne[2]);
2036-
// control (todo: support this)
2037-
for (int k = 0; k < masked_latent->ne[2]; k++) {
2038-
ggml_tensor_set_f32(concat_latent, 0, ix, iy, masked_latent->ne[2] + 1 + k);
2039-
}
2040-
} else if (sd_ctx->sd->version == VERSION_FLEX_2) {
2041-
float m = ggml_tensor_get_f32(mask_img, mx, my);
2042-
// masked image
2043-
for (int k = 0; k < masked_latent->ne[2]; k++) {
2044-
float v = ggml_tensor_get_f32(masked_latent, ix, iy, k);
2045-
ggml_tensor_set_f32(concat_latent, v, ix, iy, k);
2046-
}
2047-
// downsampled mask
2048-
ggml_tensor_set_f32(concat_latent, m, ix, iy, masked_latent->ne[2]);
2049-
// control (todo: support this)
2050-
for (int k = 0; k < masked_latent->ne[2]; k++) {
2051-
ggml_tensor_set_f32(concat_latent, 0, ix, iy, masked_latent->ne[2] + 1 + k);
2052-
}
2053-
} else {
2042+
} else if (sd_ctx->sd->version == VERSION_FLEX_2) {
20542043
float m = ggml_tensor_get_f32(mask_img, mx, my);
2055-
ggml_tensor_set_f32(concat_latent, m, ix, iy, 0);
2044+
// masked image
20562045
for (int k = 0; k < masked_latent->ne[2]; k++) {
20572046
float v = ggml_tensor_get_f32(masked_latent, ix, iy, k);
2058-
ggml_tensor_set_f32(concat_latent, v, ix, iy, k + mask_channels);
2047+
ggml_tensor_set_f32(concat_latent, v, ix, iy, k);
2048+
}
2049+
// downsampled mask
2050+
ggml_tensor_set_f32(concat_latent, m, ix, iy, masked_latent->ne[2]);
2051+
// control (todo: support this)
2052+
for (int k = 0; k < masked_latent->ne[2]; k++) {
2053+
ggml_tensor_set_f32(concat_latent, 0, ix, iy, masked_latent->ne[2] + 1 + k);
20592054
}
20602055
}
20612056
}

0 commit comments

Comments
 (0)