
Commit fb604b7: Support Flex-2
Parent: 5650e56

6 files changed (+105 / -11 lines)


examples/cli/main.cpp

Lines changed: 1 addition & 1 deletion
@@ -1025,7 +1025,7 @@ int main(int argc, const char* argv[]) {
     }
 
     sd_image_t* control_image = NULL;
-    if (params.controlnet_path.size() > 0 && params.control_image_path.size() > 0) {
+    if (params.control_image_path.size() > 0) {
         int c = 0;
         control_image_buffer = stbi_load(params.control_image_path.c_str(), &params.width, &params.height, &c, 3);
         if (control_image_buffer == NULL) {

flux.hpp

Lines changed: 24 additions & 4 deletions
@@ -984,7 +984,8 @@ namespace Flux {
                             struct ggml_tensor* pe,
                             struct ggml_tensor* mod_index_arange = NULL,
                             std::vector<ggml_tensor*> ref_latents = {},
-                            std::vector<int> skip_layers = {}) {
+                            std::vector<int> skip_layers = {},
+                            SDVersion version = VERSION_FLUX) {
            // Forward pass of DiT.
            // x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
            // timestep: (N,) tensor of diffusion timesteps
@@ -1007,14 +1008,30 @@
            auto img = process_img(ctx, x);
            uint64_t img_tokens = img->ne[1];
 
-            if (c_concat != NULL) {
+            if (version == VERSION_FLUX_FILL) {
+                GGML_ASSERT(c_concat != NULL);
                 ggml_tensor* masked = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0);
                 ggml_tensor* mask = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 8 * 8, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);
 
                 masked = process_img(ctx, masked);
                 mask = process_img(ctx, mask);
 
                 img = ggml_concat(ctx, img, ggml_concat(ctx, masked, mask, 0), 0);
+            } else if (version == VERSION_FLEX_2) {
+                GGML_ASSERT(c_concat != NULL);
+                ggml_tensor* masked = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0);
+                ggml_tensor* mask = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 1, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);
+                ggml_tensor* control = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * (C + 1));
+
+                masked = ggml_pad(ctx, masked, pad_w, pad_h, 0, 0);
+                mask = ggml_pad(ctx, mask, pad_w, pad_h, 0, 0);
+                control = ggml_pad(ctx, control, pad_w, pad_h, 0, 0);
+
+                masked = patchify(ctx, masked, patch_size);
+                mask = patchify(ctx, mask, patch_size);
+                control = patchify(ctx, control, patch_size);
+
+                img = ggml_concat(ctx, img, ggml_concat(ctx, ggml_concat(ctx, masked, mask, 0), control, 0), 0);
             }
 
             if (ref_latents.size() > 0) {
@@ -1055,13 +1072,15 @@
                   SDVersion version = VERSION_FLUX,
                   bool flash_attn = false,
                   bool use_mask = false)
-            : GGMLRunner(backend), use_mask(use_mask) {
+            : GGMLRunner(backend), version(version), use_mask(use_mask) {
            flux_params.flash_attn = flash_attn;
            flux_params.guidance_embed = false;
            flux_params.depth = 0;
            flux_params.depth_single_blocks = 0;
            if (version == VERSION_FLUX_FILL) {
                flux_params.in_channels = 384;
+            } else if (version == VERSION_FLEX_2) {
+                flux_params.in_channels = 196;
            }
            for (auto pair : tensor_types) {
                std::string tensor_name = pair.first;
@@ -1171,7 +1190,8 @@
                                             pe,
                                             mod_index_arange,
                                             ref_latents,
-                                            skip_layers);
+                                            skip_layers,
+                                            version);
 
            ggml_build_forward_expand(gf, out);
 
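Note on the in_channels values: Flux packs the latent into 2x2 patches, so in_channels is (latent channels + extra conditioning channels) * patch_size^2. Assuming the usual 16-channel Flux latent (C above), base Flux uses 16 * 4 = 64, Flux Fill adds a 16-channel masked image plus a 64-channel flattened mask for (16 + 16 + 64) * 4 = 384, and Flex-2 adds a 16-channel masked image, a 1-channel mask, and a 16-channel control latent for (16 + 16 + 1 + 16) * 4 = 196, which matches the three views taken from c_concat above. A quick back-of-the-envelope check, standalone and not part of the diff:

// Sanity check of flux_params.in_channels; assumes the usual 16-channel Flux
// latent and 2x2 patching (plain arithmetic, not repo code).
constexpr int LATENT_C   = 16;     // C in the forward pass above
constexpr int PATCH_AREA = 2 * 2;  // patch_size * patch_size

constexpr int flux_in  = LATENT_C * PATCH_AREA;                              // 64
constexpr int fill_in  = (LATENT_C + LATENT_C + 8 * 8) * PATCH_AREA;         // 384
constexpr int flex2_in = (LATENT_C + LATENT_C + 1 + LATENT_C) * PATCH_AREA;  // 196

static_assert(flux_in == 64 && fill_in == 384 && flex2_in == 196, "channel layout check");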

model.cpp

Lines changed: 3 additions & 0 deletions
@@ -1689,6 +1689,9 @@ SDVersion ModelLoader::get_sd_version() {
         if (is_inpaint) {
             return VERSION_FLUX_FILL;
         }
+        if (input_block_weight.ne[0] == 196) {
+            return VERSION_FLEX_2;
+        }
         return VERSION_FLUX;
     }
 
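Detection piggybacks on the width of the input-block (img_in) weight that get_sd_version() already inspects: 64 for base Flux, 384 for Flux Fill (caught by the is_inpaint check above), and 196 is unique to Flex-2 because of its 16 + 1 + 16 extra conditioning channels per 2x2 patch. A condensed sketch of the dispatch, with illustrative names rather than the loader's actual locals:

// Sketch of the Flux-family dispatch implied by the checks above.
// classify_flux and img_in_width are illustrative names, not ModelLoader API.
SDVersion classify_flux(int64_t img_in_width, bool is_inpaint) {
    if (is_inpaint) {
        return VERSION_FLUX_FILL;  // 384 input channels
    }
    if (img_in_width == 196) {
        return VERSION_FLEX_2;     // (16 + 16 + 1 + 16) * 4 packed channels
    }
    return VERSION_FLUX;           // 64 input channels
}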

model.h

Lines changed: 3 additions & 2 deletions
@@ -31,11 +31,12 @@ enum SDVersion {
     VERSION_SD3,
     VERSION_FLUX,
     VERSION_FLUX_FILL,
+    VERSION_FLEX_2,
     VERSION_COUNT,
 };
 
 static inline bool sd_version_is_flux(SDVersion version) {
-    if (version == VERSION_FLUX || version == VERSION_FLUX_FILL) {
+    if (version == VERSION_FLUX || version == VERSION_FLUX_FILL || version == VERSION_FLEX_2) {
         return true;
     }
     return false;
@@ -70,7 +71,7 @@ static inline bool sd_version_is_sdxl(SDVersion version) {
 }
 
 static inline bool sd_version_is_inpaint(SDVersion version) {
-    if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || version == VERSION_SDXL_INPAINT || version == VERSION_FLUX_FILL) {
+    if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || version == VERSION_SDXL_INPAINT || version == VERSION_FLUX_FILL || version == VERSION_FLEX_2) {
         return true;
     }
     return false;

stable-diffusion.cpp

Lines changed: 73 additions & 4 deletions
@@ -95,7 +95,7 @@ class StableDiffusionGGML {
     std::shared_ptr<DiffusionModel> diffusion_model;
     std::shared_ptr<AutoEncoderKL> first_stage_model;
     std::shared_ptr<TinyAutoEncoder> tae_first_stage;
-    std::shared_ptr<ControlNet> control_net;
+    std::shared_ptr<ControlNet> control_net = NULL;
     std::shared_ptr<PhotoMakerIDEncoder> pmid_model;
     std::shared_ptr<LoraModel> pmid_lora;
     std::shared_ptr<PhotoMakerIDEmbed> pmid_id_embeds;
@@ -314,6 +314,11 @@
             // TODO: shift_factor
         }
 
+        if (version == VERSION_FLEX_2) {
+            // Might need vae encode for control cond
+            vae_decode_only = false;
+        }
+
         if (version == VERSION_SVD) {
             clip_vision = std::make_shared<FrozenCLIPVisionEmbedder>(backend, model_loader.tensor_storages_types);
             clip_vision->alloc_params_buffer();
@@ -922,7 +927,7 @@
 
         std::vector<struct ggml_tensor*> controls;
 
-        if (control_hint != NULL) {
+        if (control_hint != NULL && control_net != NULL) {
             control_net->compute(n_threads, noised_input, control_hint, timesteps, cond.c_crossattn, cond.c_vector);
             controls = control_net->controls;
             // print_ggml_tensor(controls[12]);
@@ -961,7 +966,7 @@
         float* negative_data = NULL;
         if (has_unconditioned) {
             // uncond
-            if (control_hint != NULL) {
+            if (control_hint != NULL && control_net != NULL) {
                 control_net->compute(n_threads, noised_input, control_hint, timesteps, uncond.c_crossattn, uncond.c_vector);
                 controls = control_net->controls;
             }
@@ -1511,6 +1516,8 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
         int64_t mask_channels = 1;
         if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
             mask_channels = 8 * 8; // flatten the whole mask
+        } else if (sd_ctx->sd->version == VERSION_FLEX_2) {
+            mask_channels = 1 + init_latent->ne[2];
         }
         auto empty_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], mask_channels + init_latent->ne[2], 1);
         // no mask, set the whole image as masked
@@ -1524,6 +1531,11 @@
                 for (int64_t c = init_latent->ne[2]; c < empty_latent->ne[2]; c++) {
                     ggml_tensor_set_f32(empty_latent, 1, x, y, c);
                 }
+            } else if (sd_ctx->sd->version == VERSION_FLEX_2) {
+                for (int64_t c = 0; c < empty_latent->ne[2]; c++) {
+                    // 0x16,1x1,0x16
+                    ggml_tensor_set_f32(empty_latent, c == init_latent->ne[2], x, y, c);
+                }
             } else {
                 ggml_tensor_set_f32(empty_latent, 1, x, y, 0);
                 for (int64_t c = 1; c < empty_latent->ne[2]; c++) {
@@ -1532,7 +1544,36 @@
                 }
             }
         }
-        if (concat_latent == NULL) {
+        if (sd_ctx->sd->version == VERSION_FLEX_2 && image_hint != NULL && sd_ctx->sd->control_net == NULL) {
+            bool no_inpaint = concat_latent == NULL;
+            if (no_inpaint) {
+                concat_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], mask_channels + init_latent->ne[2], 1);
+            }
+            // fill in the control image here
+            struct ggml_tensor* control_latents = NULL;
+            if (!sd_ctx->sd->use_tiny_autoencoder) {
+                struct ggml_tensor* control_moments = sd_ctx->sd->encode_first_stage(work_ctx, image_hint);
+                control_latents = sd_ctx->sd->get_first_stage_encoding(work_ctx, control_moments);
+            } else {
+                control_latents = sd_ctx->sd->encode_first_stage(work_ctx, image_hint);
+            }
+            for (int64_t x = 0; x < concat_latent->ne[0]; x++) {
+                for (int64_t y = 0; y < concat_latent->ne[1]; y++) {
+                    if (no_inpaint) {
+                        for (int64_t c = 0; c < concat_latent->ne[2] - control_latents->ne[2]; c++) {
+                            // 0x16,1x1,0x16
+                            ggml_tensor_set_f32(concat_latent, c == init_latent->ne[2], x, y, c);
+                        }
+                    }
+                    for (int64_t c = 0; c < control_latents->ne[2]; c++) {
+                        float v = ggml_tensor_get_f32(control_latents, x, y, c);
+                        ggml_tensor_set_f32(concat_latent, v, x, y, concat_latent->ne[2] - control_latents->ne[2] + c);
+                    }
+                }
+            }
+            // Disable controlnet
+            image_hint = NULL;
+        } else if (concat_latent == NULL) {
             concat_latent = empty_latent;
         }
         cond.c_concat = concat_latent;
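For Flex-2 the conditioning concatenated to the latent therefore carries 16 + 1 + 16 channels: the masked image latent, a single mask channel, and a control latent. Because sd_version_is_inpaint() now returns true for VERSION_FLEX_2, this block also runs for plain txt2img, where the "0x16,1x1,0x16" fill means zeros for the masked image, 1 for the mask channel (whole image regenerated), and zeros for the control block. When a control image is supplied and no ControlNet model is loaded, the hint is VAE-encoded and written into the trailing 16 channels, and image_hint is then cleared so the ControlNet path stays off. A minimal sketch of the resulting per-pixel layout, using a hypothetical helper and plain std::array instead of ggml tensors:

// Per-pixel Flex-2 concat layout built above:
//   [0..15]  masked image latent
//   [16]     inpaint mask (1 = regenerate this pixel)
//   [17..32] control latent (VAE-encoded control image, or zeros)
// flex2_concat_pixel is an illustration, not repo code.
#include <array>

std::array<float, 33> flex2_concat_pixel(const std::array<float, 16>& masked_latent,
                                         float mask,
                                         const std::array<float, 16>& control_latent) {
    std::array<float, 33> out{};
    for (int c = 0; c < 16; c++) {
        out[c] = masked_latent[c];
    }
    out[16] = mask;
    for (int c = 0; c < 16; c++) {
        out[17 + c] = control_latent[c];
    }
    return out;
}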
@@ -1819,6 +1860,8 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
         int64_t mask_channels = 1;
         if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
             mask_channels = 8 * 8; // flatten the whole mask
+        } else if (sd_ctx->sd->version == VERSION_FLEX_2) {
+            mask_channels = 1 + init_latent->ne[2];
         }
         ggml_tensor* masked_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
         // Restore init_img (encode_first_stage has side effects) TODO: remove the side effects?
@@ -1850,6 +1893,32 @@
                             ggml_tensor_set_f32(concat_latent, m, ix, iy, masked_latent->ne[2] + x * 8 + y);
                         }
                     }
+                } else if (sd_ctx->sd->version == VERSION_FLEX_2) {
+                    float m = ggml_tensor_get_f32(mask_img, mx, my);
+                    // masked image
+                    for (int k = 0; k < masked_latent->ne[2]; k++) {
+                        float v = ggml_tensor_get_f32(masked_latent, ix, iy, k);
+                        ggml_tensor_set_f32(concat_latent, v, ix, iy, k);
+                    }
+                    // downsampled mask
+                    ggml_tensor_set_f32(concat_latent, m, ix, iy, masked_latent->ne[2]);
+                    // control (todo: support this)
+                    for (int k = 0; k < masked_latent->ne[2]; k++) {
+                        ggml_tensor_set_f32(concat_latent, 0, ix, iy, masked_latent->ne[2] + 1 + k);
+                    }
+                } else if (sd_ctx->sd->version == VERSION_FLEX_2) {
+                    float m = ggml_tensor_get_f32(mask_img, mx, my);
+                    // masked image
+                    for (int k = 0; k < masked_latent->ne[2]; k++) {
+                        float v = ggml_tensor_get_f32(masked_latent, ix, iy, k);
+                        ggml_tensor_set_f32(concat_latent, v, ix, iy, k);
+                    }
+                    // downsampled mask
+                    ggml_tensor_set_f32(concat_latent, m, ix, iy, masked_latent->ne[2]);
+                    // control (todo: support this)
+                    for (int k = 0; k < masked_latent->ne[2]; k++) {
+                        ggml_tensor_set_f32(concat_latent, 0, ix, iy, masked_latent->ne[2] + 1 + k);
+                    }
                 } else {
                     float m = ggml_tensor_get_f32(mask_img, mx, my);
                     ggml_tensor_set_f32(concat_latent, m, ix, iy, 0);
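img2img fills the same 33-channel layout per latent pixel: the VAE-encoded masked image goes into channels 0..15, the downsampled mask into channel 16, and the control block (channels 17..32) is left at zero until control conditioning is wired up here as well; note that only the first of the two identical VERSION_FLEX_2 branches above can ever match. A standalone restatement of that per-pixel fill, with an illustrative helper name and the repo's tensor accessor helpers:

// Illustrative restatement of the Flex-2 fill above; (ix, iy) index the latent,
// (mx, my) the full-resolution mask. Not a drop-in replacement for the loop body.
static void flex2_fill_pixel(ggml_tensor* concat_latent,
                             ggml_tensor* masked_latent,
                             ggml_tensor* mask_img,
                             int ix, int iy, int mx, int my) {
    const int latent_c = (int)masked_latent->ne[2];  // 16 for the Flux VAE
    for (int k = 0; k < latent_c; k++) {
        // channels 0..15: masked image latent
        ggml_tensor_set_f32(concat_latent, ggml_tensor_get_f32(masked_latent, ix, iy, k), ix, iy, k);
    }
    // channel 16: downsampled inpaint mask
    ggml_tensor_set_f32(concat_latent, ggml_tensor_get_f32(mask_img, mx, my), ix, iy, latent_c);
    for (int k = 0; k < latent_c; k++) {
        // channels 17..32: control latent, zeroed until img2img control is supported
        ggml_tensor_set_f32(concat_latent, 0.0f, ix, iy, latent_c + 1 + k);
    }
}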

vae.hpp

Lines changed: 1 addition & 0 deletions
@@ -559,6 +559,7 @@ struct AutoEncoderKL : public GGMLRunner {
                  bool decode_graph,
                  struct ggml_tensor** output,
                  struct ggml_context* output_ctx = NULL) {
+        GGML_ASSERT(!decode_only || decode_graph);
         auto get_graph = [&]() -> struct ggml_cgraph* {
             return build_graph(z, decode_graph);
         };
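The new assert encodes the implication "decode-only VAE implies only decode graphs may be requested", so a stray encode request fails loudly instead of building a graph against missing encoder weights. It is also why the constructor change above forces vae_decode_only = false for Flex-2: the control image may have to be encoded. A trivial truth-table check of the guard, standalone:

// The GGML_ASSERT above reads as: decode_only implies decode_graph.
#include <cassert>

static bool compute_allowed(bool decode_only, bool decode_graph) {
    return !decode_only || decode_graph;  // same expression as the assert
}

int main() {
    assert(compute_allowed(false, false));  // full VAE, encode graph: ok
    assert(compute_allowed(false, true));   // full VAE, decode graph: ok
    assert(compute_allowed(true, true));    // decode-only VAE, decode graph: ok
    assert(!compute_allowed(true, false));  // decode-only VAE, encode graph: rejected
    return 0;
}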
