Skip to content

Commit ac82a1c

Browse files
authored
support kontext inference (#114)
* support kontext inference
* fix kontext unittest
* fix
1 parent 4420f27 commit ac82a1c

File tree

6 files changed

+35
-4
lines changed

6 files changed

+35
-4
lines changed

diffsynth_engine/pipelines/controlnet_helper.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -8,8 +8,8 @@
88

99
@dataclass
1010
class ControlNetParams:
11-
scale: float
1211
image: ImageType
12+
scale: float = 1.0
1313
model: Optional[nn.Module] = None
1414
mask: Optional[ImageType] = None
1515
control_start: float = 0

diffsynth_engine/pipelines/flux_image.py

Lines changed: 11 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -419,9 +419,10 @@ class ControlType(Enum):
419419
normal = "normal"
420420
bfl_control = "bfl_control"
421421
bfl_fill = "bfl_fill"
422+
bfl_kontext = "bfl_kontext"
422423

423424
def get_in_channel(self):
424-
if self == ControlType.normal:
425+
if self in [ControlType.normal, ControlType.bfl_kontext]:
425426
return 64
426427
elif self == ControlType.bfl_control:
427428
return 128
@@ -764,9 +765,15 @@ def predict_noise(
764765
current_step: int,
765766
total_step: int,
766767
):
768+
origin_latents_shape = latents.shape
767769
if self.control_type != ControlType.normal:
768770
controlnet_param = controlnet_params[0]
769-
latents = torch.cat((latents, controlnet_param.image * controlnet_param.scale), dim=1)
771+
if self.control_type == ControlType.bfl_kontext:
772+
latents = torch.cat((latents, controlnet_param.image * controlnet_param.scale), dim=2)
773+
image_ids = image_ids.repeat(1, 2, 1)
774+
image_ids[:, image_ids.shape[1] // 2 :, 0] += 1
775+
else:
776+
latents = torch.cat((latents, controlnet_param.image * controlnet_param.scale), dim=1)
770777
latents = latents.to(self.dtype)
771778
controlnet_params = []
772779

@@ -797,6 +804,8 @@ def predict_noise(
797804
controlnet_double_block_output=double_block_output,
798805
controlnet_single_block_output=single_block_output,
799806
)
807+
if self.control_type == ControlType.bfl_kontext:
808+
noise_pred = noise_pred[:, :, : origin_latents_shape[2], : origin_latents_shape[3]]
800809
return noise_pred
801810

802811
def prepare_latents(

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,8 @@ dependencies = [
3030
"pillow",
3131
"imageio[ffmpeg]",
3232
"yunchang ; sys_platform == 'linux'",
33-
"onnxruntime"
33+
"onnxruntime",
34+
"opencv-python"
3435
]
3536

3637
[project.optional-dependencies]
1010 KB
Loading
565 KB
Loading

tests/test_pipelines/test_flux_bfl_image.py

Lines changed: 21 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -108,5 +108,26 @@ def test_fill_txt2img(self):
108108
self.assertImageEqualAndSaveFailed(image, "flux/flux_bfl_fill.png", threshold=0.99)
109109

110110

111+
class TestFLUXBFLKontextImage(ImageTestCase):
112+
@classmethod
113+
def setUpClass(cls):
114+
kontext_model_path = fetch_model(
115+
"black-forest-labs/FLUX.1-Kontext-dev", revision="master", path="flux1-kontext-dev.safetensors"
116+
)
117+
cls.pipe = FluxImagePipeline.from_pretrained(kontext_model_path, control_type=ControlType.bfl_kontext).eval()
118+
119+
def test_kontext_image(self):
120+
image = self.pipe(
121+
prompt="Make the wall color to red",
122+
height=1024,
123+
width=1024,
124+
controlnet_params=ControlNetParams(image=self.get_input_image("flux_kontext_input.png")),
125+
cfg_scale=1.0,
126+
seed=42,
127+
num_inference_steps=30,
128+
)
129+
self.assertImageEqualAndSaveFailed(image, "flux/flux_bfl_kontext.png", threshold=0.99)
130+
131+
111132
if __name__ == "__main__":
112133
unittest.main()

0 commit comments

Comments
 (0)