@@ -95,7 +95,7 @@ class StableDiffusionGGML {
     std::shared_ptr<DiffusionModel> diffusion_model;
     std::shared_ptr<AutoEncoderKL> first_stage_model;
     std::shared_ptr<TinyAutoEncoder> tae_first_stage;
-    std::shared_ptr<ControlNet> control_net;
+    std::shared_ptr<ControlNet> control_net = NULL;
     std::shared_ptr<PhotoMakerIDEncoder> pmid_model;
     std::shared_ptr<LoraModel> pmid_lora;
     std::shared_ptr<PhotoMakerIDEmbed> pmid_id_embeds;
@@ -314,6 +314,11 @@ class StableDiffusionGGML {
             // TODO: shift_factor
         }

+        if (version == VERSION_FLEX_2) {
+            // Might need vae encode for control cond
+            vae_decode_only = false;
+        }
+
         if (version == VERSION_SVD) {
             clip_vision = std::make_shared<FrozenCLIPVisionEmbedder>(backend, model_loader.tensor_storages_types);
             clip_vision->alloc_params_buffer();
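
Flex.2 may need the VAE encoder to build its control conditioning (that is what the "Might need vae encode for control cond" comment refers to), so for `VERSION_FLEX_2` the loader stops forcing the first-stage model into decode-only mode.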
@@ -922,7 +927,7 @@ class StableDiffusionGGML {

         std::vector<struct ggml_tensor*> controls;

-        if (control_hint != NULL) {
+        if (control_hint != NULL && control_net != NULL) {
             control_net->compute(n_threads, noised_input, control_hint, timesteps, cond.c_crossattn, cond.c_vector);
             controls = control_net->controls;
             // print_ggml_tensor(controls[12]);
@@ -961,7 +966,7 @@ class StableDiffusionGGML {
         float* negative_data = NULL;
         if (has_unconditioned) {
             // uncond
-            if (control_hint != NULL) {
+            if (control_hint != NULL && control_net != NULL) {
                 control_net->compute(n_threads, noised_input, control_hint, timesteps, uncond.c_crossattn, uncond.c_vector);
                 controls = control_net->controls;
             }
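
Both ControlNet call sites now check `control_net != NULL` in addition to `control_hint != NULL`, matching the new `= NULL` default on the member: with Flex.2 a control hint can be supplied while no ControlNet model is loaded, because the hint is folded into the concat conditioning instead (see the `generate_image` hunks below), so a non-null hint no longer implies a valid `control_net` pointer.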
@@ -1511,6 +1516,8 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
         int64_t mask_channels = 1;
         if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
             mask_channels = 8 * 8;  // flatten the whole mask
+        } else if (sd_ctx->sd->version == VERSION_FLEX_2) {
+            mask_channels = 1 + init_latent->ne[2];
         }
         auto empty_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], mask_channels + init_latent->ne[2], 1);
         // no mask, set the whole image as masked
@@ -1524,6 +1531,11 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
                     for (int64_t c = init_latent->ne[2]; c < empty_latent->ne[2]; c++) {
                         ggml_tensor_set_f32(empty_latent, 1, x, y, c);
                     }
+                } else if (sd_ctx->sd->version == VERSION_FLEX_2) {
+                    for (int64_t c = 0; c < empty_latent->ne[2]; c++) {
+                        // 0x16,1x1,0x16
+                        ggml_tensor_set_f32(empty_latent, c == init_latent->ne[2], x, y, c);
+                    }
                 } else {
                     ggml_tensor_set_f32(empty_latent, 1, x, y, 0);
                     for (int64_t c = 1; c < empty_latent->ne[2]; c++) {
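
The channel bookkeeping is easier to see with concrete numbers. Below is a minimal standalone sketch, assuming a 16-channel Flux-style latent (the patch itself only says `mask_channels = 1 + init_latent->ne[2]`); the `0x16,1x1,0x16` comments abbreviate this layout of 16 zeros, a single 1, then 16 more zeros:

```cpp
// Sketch only, not from the patch: Flex.2 concat-channel layout.
#include <cstdio>

int main() {
    const int latent_channels = 16;                   // init_latent->ne[2], assumed
    const int mask_channels   = 1 + latent_channels;  // 1 mask + 16 control channels
    const int concat_channels = mask_channels + latent_channels;  // empty_latent->ne[2]

    // Per (x, y) position the "empty" conditioning is:
    //   channels [0, 16)  masked-image latent -> 0
    //   channel  16       inpaint mask        -> 1 (whole image masked)
    //   channels [17, 33) control latent      -> 0
    printf("image 0..%d | mask %d | control %d..%d | total %d\n",
           latent_channels - 1, latent_channels,
           latent_channels + 1, concat_channels - 1, concat_channels);
    return 0;
}
```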
@@ -1532,7 +1544,36 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx,
                 }
             }
         }
-        if (concat_latent == NULL) {
+        if (sd_ctx->sd->version == VERSION_FLEX_2 && image_hint != NULL && sd_ctx->sd->control_net == NULL) {
+            bool no_inpaint = concat_latent == NULL;
+            if (no_inpaint) {
+                concat_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], mask_channels + init_latent->ne[2], 1);
+            }
+            // fill in the control image here
+            struct ggml_tensor* control_latents = NULL;
+            if (!sd_ctx->sd->use_tiny_autoencoder) {
+                struct ggml_tensor* control_moments = sd_ctx->sd->encode_first_stage(work_ctx, image_hint);
+                control_latents = sd_ctx->sd->get_first_stage_encoding(work_ctx, control_moments);
+            } else {
+                control_latents = sd_ctx->sd->encode_first_stage(work_ctx, image_hint);
+            }
+            for (int64_t x = 0; x < concat_latent->ne[0]; x++) {
+                for (int64_t y = 0; y < concat_latent->ne[1]; y++) {
+                    if (no_inpaint) {
+                        for (int64_t c = 0; c < concat_latent->ne[2] - control_latents->ne[2]; c++) {
+                            // 0x16,1x1,0x16
+                            ggml_tensor_set_f32(concat_latent, c == init_latent->ne[2], x, y, c);
+                        }
+                    }
+                    for (int64_t c = 0; c < control_latents->ne[2]; c++) {
+                        float v = ggml_tensor_get_f32(control_latents, x, y, c);
+                        ggml_tensor_set_f32(concat_latent, v, x, y, concat_latent->ne[2] - control_latents->ne[2] + c);
+                    }
+                }
+            }
+            // Disable controlnet
+            image_hint = NULL;
+        } else if (concat_latent == NULL) {
             concat_latent = empty_latent;
         }
         cond.c_concat = concat_latent;
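
When a hint image is given but no ControlNet model is loaded, the branch above VAE-encodes `image_hint`, writes it into the last `control_latents->ne[2]` channels of `concat_latent`, and then nulls `image_hint` so the ControlNet path is skipped. A plain-array sketch of that trailing-channel copy (dimensions and the indexing helper are made up for illustration):

```cpp
// Sketch only: writing a VAE-encoded control image into the trailing
// channels of the concat buffer, mirroring
//   ggml_tensor_set_f32(concat_latent, v, x, y,
//                       concat_latent->ne[2] - control_latents->ne[2] + c);
// Plain arrays stand in for ggml tensors; all sizes are assumed.
#include <vector>

int main() {
    const int W = 4, H = 4, C_lat = 16;      // latent width/height/channels, assumed
    const int C_concat = C_lat + 1 + C_lat;  // image + mask + control = 33
    std::vector<float> concat(W * H * C_concat, 0.0f);
    std::vector<float> control(W * H * C_lat, 0.5f);  // stand-in encoder output

    auto at = [&](int x, int y, int c) { return (c * H + y) * W + x; };
    for (int x = 0; x < W; x++)
        for (int y = 0; y < H; y++)
            for (int c = 0; c < C_lat; c++)
                // the control latent occupies the last C_lat channels
                concat[at(x, y, C_concat - C_lat + c)] = control[at(x, y, c)];
    return 0;
}
```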
@@ -1819,6 +1860,8 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
         int64_t mask_channels = 1;
         if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
             mask_channels = 8 * 8;  // flatten the whole mask
+        } else if (sd_ctx->sd->version == VERSION_FLEX_2) {
+            mask_channels = 1 + init_latent->ne[2];
         }
         ggml_tensor* masked_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
         // Restore init_img (encode_first_stage has side effects) TODO: remove the side effects?
@@ -1850,6 +1893,19 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
                             ggml_tensor_set_f32(concat_latent, m, ix, iy, masked_latent->ne[2] + x * 8 + y);
                         }
                     }
+                } else if (sd_ctx->sd->version == VERSION_FLEX_2) {
+                    float m = ggml_tensor_get_f32(mask_img, mx, my);
+                    // masked image
+                    for (int k = 0; k < masked_latent->ne[2]; k++) {
+                        float v = ggml_tensor_get_f32(masked_latent, ix, iy, k);
+                        ggml_tensor_set_f32(concat_latent, v, ix, iy, k);
+                    }
+                    // downsampled mask
+                    ggml_tensor_set_f32(concat_latent, m, ix, iy, masked_latent->ne[2]);
+                    // control (todo: support this)
+                    for (int k = 0; k < masked_latent->ne[2]; k++) {
+                        ggml_tensor_set_f32(concat_latent, 0, ix, iy, masked_latent->ne[2] + 1 + k);
+                    }
                 } else {
                     float m = ggml_tensor_get_f32(mask_img, mx, my);
                     ggml_tensor_set_f32(concat_latent, m, ix, iy, 0);
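
In `img2img` the Flex.2 concat conditioning is assembled per pixel as [masked-image latent | downsampled mask | control channels]; the control slots are explicitly zeroed because, per the `todo` comment, hint support has not been wired into this path yet.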