Skip to content

Commit 14064f6

Browse files
committed
Added support for Qwen-Image-Edit-Plus, added Arch.qwen_e_p
1 parent 231c239 commit 14064f6

File tree

7 files changed: +62 −23 lines

ai_diffusion/comfy_workflow.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -606,6 +606,24 @@ def text_encode_qwen_image_edit(
606606
"TextEncodeQwenImageEdit", 1, clip=clip, vae=vae, image=image, prompt=prompt
607607
)
608608

609+
def text_encode_qwen_image_edit_plus(
610+
self, clip: Output, vae: Output | None, images: list[Output], prompt: str | Output
611+
):
612+
image1 = images[0] if len(images) > 0 else None
613+
image2 = images[1] if len(images) > 1 else None
614+
image3 = images[2] if len(images) > 2 else None
615+
616+
return self.add(
617+
"TextEncodeQwenImageEditPlus",
618+
1,
619+
clip=clip,
620+
vae=vae,
621+
image1=image1,
622+
image2=image2,
623+
image3=image3,
624+
prompt=prompt,
625+
)
626+
609627
def background_region(self, conditioning: Output):
610628
return self.add("ETN_BackgroundRegion", 1, conditioning=conditioning)
611629

ai_diffusion/resolution.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,8 @@ def compute(extent: Extent, arch: Arch, style: Style | None = None):
152152
Arch.sd3: (512, 1536, 512**2, 1536**2),
153153
Arch.flux: (256, 2048, 512**2, 2048**2),
154154
Arch.qwen: (256, 2048, 512**2, 2048**2),
155+
Arch.qwen_e: (256, 2048, 512**2, 2048**2),
156+
Arch.qwen_e_p: (256, 2048, 512**2, 2048**2),
155157
}[arch]
156158
else:
157159
range_offset = multiple_of(round(0.2 * style.preferred_resolution), 8)

ai_diffusion/resources.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ class Arch(Enum):
8585
chroma = "Chroma"
8686
qwen = "Qwen"
8787
qwen_e = "Qwen Edit"
88+
qwen_e_p = "Qwen Edit Plus"
8889

8990
auto = "Automatic"
9091
all = "All"
@@ -109,13 +110,12 @@ def from_string(string: str, model_type: str = "eps", filename: str | None = Non
109110
return Arch.illu_v
110111
if string == "chroma":
111112
return Arch.chroma
112-
if (
113-
string in ("qwen", "qwen_image", "qwen-image")
114-
and filename
115-
and "edit" in filename.lower()
116-
):
117-
return Arch.qwen_e
118-
if string in ("qwen", "qwen_image", "qwen-image"):
113+
if string == "qwen-image" and filename and "edit" in filename.lower():
114+
if "2509" in filename.lower():
115+
return Arch.qwen_e_p
116+
else:
117+
return Arch.qwen_e
118+
if string == "qwen-image":
119119
return Arch.qwen
120120
return None
121121

@@ -167,7 +167,7 @@ def supports_cfg(self):
167167

168168
@property
169169
def is_edit(self): # edit models make changes to input images
170-
return self in [Arch.flux_k, Arch.qwen_e]
170+
return self in [Arch.flux_k, Arch.qwen_e, Arch.qwen_e_p]
171171

172172
@property
173173
def is_sdxl_like(self):
@@ -178,6 +178,10 @@ def is_sdxl_like(self):
178178
def is_flux_like(self):
179179
return self in [Arch.flux, Arch.flux_k]
180180

181+
@property
182+
def is_qwen_like(self):
183+
return self in [Arch.qwen, Arch.qwen_e, Arch.qwen_e_p]
184+
181185
@property
182186
def text_encoders(self):
183187
match self:
@@ -191,7 +195,7 @@ def text_encoders(self):
191195
return ["clip_l", "t5"]
192196
case Arch.chroma:
193197
return ["t5"]
194-
case Arch.qwen | Arch.qwen_e:
198+
case Arch.qwen | Arch.qwen_e | Arch.qwen_e_p:
195199
return ["qwen"]
196200
raise ValueError(f"Unsupported architecture: {self}")
197201

@@ -208,6 +212,7 @@ def list():
208212
Arch.chroma,
209213
Arch.qwen,
210214
Arch.qwen_e,
215+
Arch.qwen_e_p,
211216
]
212217

213218

@@ -714,6 +719,7 @@ def is_required(kind: ResourceKind, arch: Arch, identifier: ControlMode | Upscal
714719
resource_id(ResourceKind.vae, Arch.chroma, "default"): ["flux", "ae.s"],
715720
resource_id(ResourceKind.vae, Arch.qwen, "default"): ["qwen"],
716721
resource_id(ResourceKind.vae, Arch.qwen_e, "default"): ["qwen"],
722+
resource_id(ResourceKind.vae, Arch.qwen_e_p, "default"): ["qwen"],
717723
}
718724
# fmt: on
719725

ai_diffusion/ui/server.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,7 @@ def _workload_matches(self, item: PackageItem):
202202
Arch.chroma,
203203
Arch.qwen,
204204
Arch.qwen_e,
205+
Arch.qwen_e_p,
205206
]
206207
)
207208

ai_diffusion/ui/style.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -765,8 +765,8 @@ def _change_style(self):
765765
self._read_style(self.current_style)
766766

767767
def _open_checkpoints_folder(self):
768-
arch = arch = resolve_arch(self.current_style, root.connection.client_if_connected)
769-
if arch.is_flux_like or arch in (Arch.chroma, Arch.qwen, Arch.qwen_e):
768+
arch = resolve_arch(self.current_style, root.connection.client_if_connected)
769+
if arch.is_flux_like or arch == Arch.chroma or arch.is_qwen_like:
770770
self._open_folder(Path("models/diffusion_models"))
771771
else:
772772
self._open_folder(Path("models/checkpoints"))
@@ -883,6 +883,8 @@ def _enable_checkpoint_advanced(self):
883883
valid_archs = (Arch.auto, Arch.sdxl, Arch.illu, Arch.illu_v)
884884
elif arch.is_flux_like:
885885
valid_archs = (Arch.auto, Arch.flux, Arch.flux_k)
886+
elif arch.is_qwen_like:
887+
valid_archs = (Arch.auto, Arch.qwen, Arch.qwen_e, Arch.qwen_e_p)
886888
else:
887889
valid_archs = (Arch.auto, arch)
888890
with SignalBlocker(self._arch_select):

ai_diffusion/ui/theme.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def checkpoint_icon(arch: Arch, format: FileFormat | None = None, client: Client
6868
return icon("sd-version-chroma")
6969
elif arch is Arch.qwen:
7070
return icon("sd-version-qwen")
71-
elif arch is Arch.qwen_e:
71+
elif arch in (Arch.qwen_e, Arch.qwen_e_p):
7272
return icon("sd-version-qwen-e")
7373
else:
7474
log.warning(f"Unresolved SD version {arch}, cannot fetch icon")

ai_diffusion/workflow.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -100,12 +100,12 @@ def load_checkpoint_with_lora(w: ComfyWorkflow, checkpoint: CheckpointInput, mod
100100
case (FileFormat.diffusion, Quantization.none):
101101
model = w.load_diffusion_model(model_info.filename)
102102
case (FileFormat.diffusion, Quantization.svdq):
103-
if model_info.arch in (Arch.flux, Arch.flux_k):
103+
if model_info.arch.is_flux_like:
104104
cache = 0.12 if checkpoint.dynamic_caching else 0.0
105105
model = w.nunchaku_load_flux_diffusion_model(
106106
model_info.filename, cache_threshold=cache
107107
)
108-
elif model_info.arch in (Arch.qwen, Arch.qwen_e):
108+
elif model_info.arch.is_qwen_like:
109109
# WIP #2072 replace by customizable parameters
110110
model = w.nunchaku_load_qwen_diffusion_model(
111111
model_info.filename,
@@ -139,7 +139,7 @@ def load_checkpoint_with_lora(w: ComfyWorkflow, checkpoint: CheckpointInput, mod
139139
case Arch.chroma:
140140
clip = w.load_clip(te["t5"], type="chroma")
141141
clip = w.t5_tokenizer_options(clip, min_padding=1, min_length=0)
142-
case Arch.qwen | Arch.qwen_e:
142+
case Arch.qwen | Arch.qwen_e | Arch.qwen_e_p:
143143
clip = w.load_clip(te["qwen"], type="qwen_image")
144144
case _:
145145
raise RuntimeError(f"No text encoder for model architecture {arch.name}")
@@ -653,18 +653,28 @@ def apply_edit_conditioning(
653653

654654
extra_input = [c.image for c in control_layers if c.mode.is_ip_adapter]
655655
if len(extra_input) == 0:
656-
if arch == Arch.qwen_e:
656+
if arch == Arch.qwen_e_p:
657+
return w.text_encode_qwen_image_edit_plus(clip, vae, [input_image], positive)
658+
elif arch == Arch.qwen_e:
657659
# Don't use VAE to force the reference latent
658660
cond = w.text_encode_qwen_image_edit(clip, None, input_image, positive)
659661
return w.reference_latent(cond, input_latent)
660662

661-
input = w.image_stitch([input_image] + [i.load(w) for i in extra_input])
662-
latent = vae_encode(w, vae, input, tiled_vae)
663-
if arch == Arch.qwen_e:
664-
# Don't use VAE to force the reference latent
665-
cond = w.text_encode_qwen_image_edit(clip, None, input, positive)
666-
cond = w.reference_latent(cond, latent)
667-
return cond
663+
if arch == Arch.qwen_e_p:
664+
return w.text_encode_qwen_image_edit_plus(
665+
clip,
666+
vae,
667+
[input_image] + [i.load(w) for i in extra_input],
668+
positive,
669+
)
670+
else:
671+
input = w.image_stitch([input_image] + [i.load(w) for i in extra_input])
672+
latent = vae_encode(w, vae, input, tiled_vae)
673+
if arch == Arch.qwen_e:
674+
# Don't use VAE to force the reference latent
675+
cond = w.text_encode_qwen_image_edit(clip, None, input, positive)
676+
cond = w.reference_latent(cond, latent)
677+
return cond
668678

669679

670680
def scale(

0 commit comments

Comments (0)