@@ -297,7 +297,7 @@ class StableDiffusionGGML {
             // TODO: shift_factor
         }

-        if (version == VERSION_FLEX_2) {
+        if (sd_version_is_control(version)) {
             // Might need vae encode for control cond
             vae_decode_only = false;
         }
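This hunk replaces the hard-coded Flex.2 check with the `sd_version_is_control(version)` predicate, so any control-capable model version keeps the VAE encoder around for the control condition. The predicate's definition is outside this diff; a minimal sketch of the shape such a helper presumably takes (which versions it covers is an assumption, not taken from the PR):

static inline bool sd_version_is_control(SDVersion version) {
    // Hypothetical body: Flex.2 is the only control-capable version visible
    // in this diff; the real helper may list others.
    return version == VERSION_FLEX_2;
}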
@@ -1722,6 +1722,17 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
     int W = width / 8;
     int H = height / 8;
     LOG_INFO("sampling using %s method", sampling_methods_str[sample_method]);
+
+    struct ggml_tensor* control_latent = NULL;
+    if (sd_version_is_control(sd_ctx->sd->version) && image_hint != NULL) {
+        if (!sd_ctx->sd->use_tiny_autoencoder) {
+            struct ggml_tensor* control_moments = sd_ctx->sd->encode_first_stage(work_ctx, image_hint);
+            control_latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, control_moments);
+        } else {
+            control_latent = sd_ctx->sd->encode_first_stage(work_ctx, image_hint);
+        }
+    }
+
     if (sd_version_is_inpaint(sd_ctx->sd->version)) {
         int64_t mask_channels = 1;
         if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
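Here the control-hint encode is hoisted out of the Flex.2 inpaint branch (see the next hunk) and runs once, up front: with the full VAE, `encode_first_stage` produces the posterior moments and `get_first_stage_encoding` turns them into a latent, while the tiny autoencoder returns the latent directly. A freestanding sketch of what the two-step path computes, assuming the usual reparameterized VAE encode with mean planes followed by logvar planes (layout and scaling here are illustrative, not lifted from the repo):

#include <cmath>
#include <cstdint>
#include <random>
#include <vector>

// Sketch: moments hold C mean planes followed by C logvar planes, each w*h
// floats; latent = scale * (mean + exp(0.5 * logvar) * eps), eps ~ N(0, 1).
std::vector<float> first_stage_encoding_sketch(const std::vector<float>& moments,
                                               int64_t w, int64_t h, int64_t c,
                                               float scale_factor) {
    std::mt19937 rng{std::random_device{}()};
    std::normal_distribution<float> gauss(0.0f, 1.0f);
    std::vector<float> latent(static_cast<size_t>(w * h * c));
    for (int64_t i = 0; i < w * h * c; i++) {
        float mean   = moments[i];              // mean planes come first
        float logvar = moments[i + w * h * c];  // then the logvar planes
        latent[i]    = scale_factor * (mean + std::exp(0.5f * logvar) * gauss(rng));
    }
    return latent;
}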
@@ -1754,50 +1765,53 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
                 }
             }
         }
-        if (sd_ctx->sd->version == VERSION_FLEX_2 && image_hint != NULL && sd_ctx->sd->control_net == NULL) {
+
+        if (sd_ctx->sd->version == VERSION_FLEX_2 && control_latent != NULL && sd_ctx->sd->control_net == NULL) {
             bool no_inpaint = concat_latent == NULL;
             if (no_inpaint) {
                 concat_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], mask_channels + init_latent->ne[2], 1);
             }
             // fill in the control image here
-            struct ggml_tensor* control_latents = NULL;
-            if (!sd_ctx->sd->use_tiny_autoencoder) {
-                struct ggml_tensor* control_moments = sd_ctx->sd->encode_first_stage(work_ctx, image_hint);
-                control_latents = sd_ctx->sd->get_first_stage_encoding(work_ctx, control_moments);
-            } else {
-                control_latents = sd_ctx->sd->encode_first_stage(work_ctx, image_hint);
-            }
-            for (int64_t x = 0; x < concat_latent->ne[0]; x++) {
-                for (int64_t y = 0; y < concat_latent->ne[1]; y++) {
+            for (int64_t x = 0; x < control_latent->ne[0]; x++) {
+                for (int64_t y = 0; y < control_latent->ne[1]; y++) {
                     if (no_inpaint) {
-                        for (int64_t c = 0; c < concat_latent->ne[2] - control_latents->ne[2]; c++) {
+                        for (int64_t c = 0; c < concat_latent->ne[2] - control_latent->ne[2]; c++) {
                             // 0x16,1x1,0x16
                             ggml_tensor_set_f32(concat_latent, c == init_latent->ne[2], x, y, c);
                         }
                     }
-                    for (int64_t c = 0; c < control_latents->ne[2]; c++) {
-                        float v = ggml_tensor_get_f32(control_latents, x, y, c);
-                        ggml_tensor_set_f32(concat_latent, v, x, y, concat_latent->ne[2] - control_latents->ne[2] + c);
+                    for (int64_t c = 0; c < control_latent->ne[2]; c++) {
+                        float v = ggml_tensor_get_f32(control_latent, x, y, c);
+                        ggml_tensor_set_f32(concat_latent, v, x, y, concat_latent->ne[2] - control_latent->ne[2] + c);
                     }
                 }
             }
-            // Disable controlnet
-            image_hint = NULL;
         } else if (concat_latent == NULL) {
             concat_latent = empty_latent;
         }
         cond.c_concat = concat_latent;
         uncond.c_concat = empty_latent;
         denoise_mask = NULL;
-    } else if (sd_version_is_unet_edit(sd_ctx->sd->version)) {
     } else if (sd_version_is_unet_edit(sd_ctx->sd->version)) {
         auto empty_latent = ggml_dup_tensor(work_ctx, init_latent);
         ggml_set_f32(empty_latent, 0);
         uncond.c_concat = empty_latent;
-        if (concat_latent == NULL) {
-            concat_latent = empty_latent;
+        cond.c_concat = ref_latents[0];
+        if (cond.c_concat == NULL) {
+            cond.c_concat = empty_latent;
+        }
+    } else if (sd_version_is_control(sd_ctx->sd->version)) {
+        LOG_DEBUG("HERE");
+        auto empty_latent = ggml_dup_tensor(work_ctx, init_latent);
+        ggml_set_f32(empty_latent, 0);
+        uncond.c_concat = empty_latent;
+        if (sd_version_is_control(sd_ctx->sd->version) && control_latent != NULL && sd_ctx->sd->control_net == NULL) {
+            cond.c_concat = control_latent;
         }
-        cond.c_concat = ref_latents[0];
+        if (cond.c_concat == NULL) {
+            cond.c_concat = empty_latent;
+        }
+        LOG_DEBUG("HERE");
     }
     SDCondition img_cond;
     if (uncond.c_crossattn != NULL &&
@@ -1956,6 +1970,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
     size_t t0 = ggml_time_ms();

     ggml_tensor* init_latent = NULL;
+    ggml_tensor* init_moments = NULL;
     ggml_tensor* concat_latent = NULL;
     ggml_tensor* denoise_mask = NULL;
     std::vector<float> sigmas = sd_ctx->sd->denoiser->get_sigmas(sd_img_gen_params->sample_steps);
@@ -1978,8 +1993,8 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
         sd_image_to_tensor(sd_img_gen_params->init_image.data, init_img);

         if (!sd_ctx->sd->use_tiny_autoencoder) {
-            ggml_tensor* moments = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
-            init_latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, moments);
+            init_moments = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
+            init_latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, init_moments);
         } else {
             init_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
         }
@@ -1988,8 +2003,8 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
         int64_t mask_channels = 1;
         if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
             mask_channels = 8 * 8;  // flatten the whole mask
-        } else if (sd_ctx->sd->version == VERSION_FLEX_2) {
-            mask_channels = 1 + init_latent->ne[2];
+        } else if (sd_ctx->sd->version == VERSION_FLEX_2) {
+            mask_channels = 1 + init_latent->ne[2];
         }
         ggml_tensor* masked_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
         sd_apply_mask(init_img, mask_img, masked_img);
@@ -2024,38 +2039,18 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
                             ggml_tensor_set_f32(concat_latent, m, ix, iy, masked_latent->ne[2] + x * 8 + y);
                         }
                     }
-                } else if (sd_ctx->sd->version == VERSION_FLEX_2) {
-                    float m = ggml_tensor_get_f32(mask_img, mx, my);
-                    // masked image
-                    for (int k = 0; k < masked_latent->ne[2]; k++) {
-                        float v = ggml_tensor_get_f32(masked_latent, ix, iy, k);
-                        ggml_tensor_set_f32(concat_latent, v, ix, iy, k);
-                    }
-                    // downsampled mask
-                    ggml_tensor_set_f32(concat_latent, m, ix, iy, masked_latent->ne[2]);
-                    // control (todo: support this)
-                    for (int k = 0; k < masked_latent->ne[2]; k++) {
-                        ggml_tensor_set_f32(concat_latent, 0, ix, iy, masked_latent->ne[2] + 1 + k);
-                    }
-                } else if (sd_ctx->sd->version == VERSION_FLEX_2) {
-                    float m = ggml_tensor_get_f32(mask_img, mx, my);
-                    // masked image
-                    for (int k = 0; k < masked_latent->ne[2]; k++) {
-                        float v = ggml_tensor_get_f32(masked_latent, ix, iy, k);
-                        ggml_tensor_set_f32(concat_latent, v, ix, iy, k);
-                    }
-                    // downsampled mask
-                    ggml_tensor_set_f32(concat_latent, m, ix, iy, masked_latent->ne[2]);
-                    // control (todo: support this)
-                    for (int k = 0; k < masked_latent->ne[2]; k++) {
-                        ggml_tensor_set_f32(concat_latent, 0, ix, iy, masked_latent->ne[2] + 1 + k);
-                    }
-                } else {
+                } else if (sd_ctx->sd->version == VERSION_FLEX_2) {
                     float m = ggml_tensor_get_f32(mask_img, mx, my);
-                    ggml_tensor_set_f32(concat_latent, m, ix, iy, 0);
+                    // masked image
                     for (int k = 0; k < masked_latent->ne[2]; k++) {
                         float v = ggml_tensor_get_f32(masked_latent, ix, iy, k);
-                        ggml_tensor_set_f32(concat_latent, v, ix, iy, k + mask_channels);
+                        ggml_tensor_set_f32(concat_latent, v, ix, iy, k);
+                    }
+                    // downsampled mask
+                    ggml_tensor_set_f32(concat_latent, m, ix, iy, masked_latent->ne[2]);
+                    // control (todo: support this)
+                    for (int k = 0; k < masked_latent->ne[2]; k++) {
+                        ggml_tensor_set_f32(concat_latent, 0, ix, iy, masked_latent->ne[2] + 1 + k);
                     }
                 }
             }
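Taken together, the two sites assume the same Flex.2 conditioning layout along the channel axis: the masked-image latent first, then one downsampled-mask channel, then the control latent. That is what the `// 0x16,1x1,0x16` comment encodes for a 16-channel latent; in the no-inpaint path of generate_image_internal the image slot is zeroed and the mask channel set to 1.0 before the control latent is copied in, which is the same layout with a dummy mask. A self-contained sketch of that packing, with a plain planar array standing in for ggml tensors (the struct and helper names here are hypothetical):

#include <cstdint>
#include <vector>

// Minimal sketch of the Flex.2 concat layout per latent pixel:
//   [0 .. C-1]    masked-image latent (C = 16 for Flex.2)
//   [C]           downsampled inpaint mask (1 channel)
//   [C+1 .. 2C]   control latent (zeros when no hint is given)
struct Latent {
    int64_t w, h, c;
    std::vector<float> data;
    Latent(int64_t w, int64_t h, int64_t c) : w(w), h(h), c(c), data(w * h * c, 0.0f) {}
    float get(int64_t x, int64_t y, int64_t k) const { return data[(k * h + y) * w + x]; }
    void set(int64_t x, int64_t y, int64_t k, float v) { data[(k * h + y) * w + x] = v; }
};

Latent pack_flex2_concat(const Latent& masked, const Latent& mask, const Latent* control) {
    int64_t C = masked.c;  // 16 for Flex.2
    Latent concat(masked.w, masked.h, C + 1 + C);
    for (int64_t x = 0; x < concat.w; x++) {
        for (int64_t y = 0; y < concat.h; y++) {
            for (int64_t k = 0; k < C; k++) {
                concat.set(x, y, k, masked.get(x, y, k));  // masked image
            }
            concat.set(x, y, C, mask.get(x, y, 0));        // mask channel
            for (int64_t k = 0; k < C; k++) {              // control slot
                concat.set(x, y, C + 1 + k, control ? control->get(x, y, k) : 0.0f);
            }
        }
    }
    return concat;
}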