Skip to content

Commit 4f438de

Browse files
authored
Add VisualCloze (#11377)
* VisualCloze * style quality * add docs * add docs * typo * Update docs/source/en/api/pipelines/visualcloze.md * delete einops * style quality * Update src/diffusers/pipelines/visualcloze/pipeline_visualcloze.py * reorg * refine doc * style quality * typo * typo * Update src/diffusers/image_processor.py * add comment * test * style * Modified based on review * style * restore image_processor * update example url * style * fix-copies * VisualClozeGenerationPipeline * combine * tests docs * remove VisualClozeUpsamplingPipeline * style * quality * test examples * quality style * typo * make fix-copies * fix test_callback_cfg and test_save_load_dduf in VisualClozePipelineFastTests * add EXAMPLE_DOC_STRING to VisualClozeGenerationPipeline * delete maybe_free_model_hooks from pipeline_visualcloze_combined * Apply suggestions from code review * fix test_save_load_local test; add reason for skipping cfg test * more save_load test fixes * fix tests in generation pipeline tests
1 parent 98cc6d0 commit 4f438de

13 files changed

+2694
-0
lines changed

docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,8 @@
575575
title: UniDiffuser
576576
- local: api/pipelines/value_guided_sampling
577577
title: Value-guided sampling
578+
- local: api/pipelines/visualcloze
579+
title: VisualCloze
578580
- local: api/pipelines/wan
579581
title: Wan
580582
- local: api/pipelines/wuerstchen

docs/source/en/api/pipelines/overview.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ The table below lists all the pipelines currently available in 🤗 Diffusers an
8989
| [UniDiffuser](unidiffuser) | text2image, image2text, image variation, text variation, unconditional image generation, unconditional audio generation |
9090
| [Value-guided planning](value_guided_sampling) | value guided sampling |
9191
| [Wuerstchen](wuerstchen) | text2image |
92+
| [VisualCloze](visualcloze) | text2image, image2image, subject driven generation, inpainting, style transfer, image restoration, image editing, [depth,normal,edge,pose]2image, [depth,normal,edge,pose]-estimation, virtual try-on, image relighting |
9293

9394
## DiffusionPipeline
9495

Lines changed: 300 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,300 @@
1+
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
-->
15+
16+
# VisualCloze
17+
18+
[VisualCloze: A Universal Image Generation Framework via Visual In-Context Learning](https://arxiv.org/abs/2504.07960) is an innovative in-context learning based universal image generation framework that offers key capabilities:
19+
1. Support for various in-domain tasks
20+
2. Generalization to unseen tasks through in-context learning
21+
3. Unify multiple tasks into one step and generate both target image and intermediate results
22+
4. Support reverse-engineering conditions from target images
23+
24+
## Overview
25+
26+
The abstract from the paper is:
27+
28+
*Recent progress in diffusion models significantly advances various image generation tasks. However, the current mainstream approach remains focused on building task-specific models, which have limited efficiency when supporting a wide range of different needs. While universal models attempt to address this limitation, they face critical challenges, including generalizable task instruction, appropriate task distributions, and unified architectural design. To tackle these challenges, we propose VisualCloze, a universal image generation framework, which supports a wide range of in-domain tasks, generalization to unseen ones, unseen unification of multiple tasks, and reverse generation. Unlike existing methods that rely on language-based task instruction, leading to task ambiguity and weak generalization, we integrate visual in-context learning, allowing models to identify tasks from visual demonstrations. Meanwhile, the inherent sparsity of visual task distributions hampers the learning of transferable knowledge across tasks. To this end, we introduce Graph200K, a graph-structured dataset that establishes various interrelated tasks, enhancing task density and transferable knowledge. Furthermore, we uncover that our unified image generation formulation shared a consistent objective with image infilling, enabling us to leverage the strong generative priors of pre-trained infilling models without modifying the architectures. The codes, dataset, and models are available at https://visualcloze.github.io.*
29+
30+
## Inference
31+
32+
### Model loading
33+
34+
VisualCloze is a two-stage cascade pipeline, containing `VisualClozeGenerationPipeline` and `VisualClozeUpsamplingPipeline`.
35+
- In `VisualClozeGenerationPipeline`, each image is downsampled before concatenating images into a grid layout, avoiding excessively high resolutions. VisualCloze releases two models suitable for diffusers, i.e., [VisualClozePipeline-384](https://huggingface.co/VisualCloze/VisualClozePipeline-384) and [VisualClozePipeline-512](https://huggingface.co/VisualCloze/VisualClozePipeline-384), which downsample images to resolutions of 384 and 512, respectively.
36+
- `VisualClozeUpsamplingPipeline` uses [SDEdit](https://arxiv.org/abs/2108.01073) to enable high-resolution image synthesis.
37+
38+
The `VisualClozePipeline` integrates both stages to support convenient end-to-end sampling, while also allowing users to utilize each pipeline independently as needed.
39+
40+
### Input Specifications
41+
42+
#### Task and Content Prompts
43+
- Task prompt: Required to describe the generation task intention
44+
- Content prompt: Optional description or caption of the target image
45+
- When content prompt is not needed, pass `None`
46+
- For batch inference, pass `List[str|None]`
47+
48+
#### Image Input Format
49+
- Format: `List[List[Image|None]]`
50+
- Structure:
51+
- All rows except the last represent in-context examples
52+
- Last row represents the current query (target image set to `None`)
53+
- For batch inference, pass `List[List[List[Image|None]]]`
54+
55+
#### Resolution Control
56+
- Default behavior:
57+
- Initial generation in the first stage: area of ${pipe.resolution}^2$
58+
- Upsampling in the second stage: 3x factor
59+
- Custom resolution: Adjust using `upsampling_height` and `upsampling_width` parameters
60+
61+
### Examples
62+
63+
For comprehensive examples covering a wide range of tasks, please refer to the [Online Demo](https://huggingface.co/spaces/VisualCloze/VisualCloze) and [GitHub Repository](https://github.com/lzyhha/VisualCloze). Below are simple examples for three cases: mask-to-image conversion, edge detection, and subject-driven generation.
64+
65+
#### Example for mask2image
66+
67+
```python
68+
import torch
69+
from diffusers import VisualClozePipeline
70+
from diffusers.utils import load_image
71+
72+
pipe = VisualClozePipeline.from_pretrained("VisualCloze/VisualClozePipeline-384", resolution=384, torch_dtype=torch.bfloat16)
73+
pipe.to("cuda")
74+
75+
# Load in-context images (make sure the paths are correct and accessible)
76+
image_paths = [
77+
# in-context examples
78+
[
79+
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_incontext-example-1_mask.jpg'),
80+
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_incontext-example-1_image.jpg'),
81+
],
82+
# query with the target image
83+
[
84+
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_query_mask.jpg'),
85+
None, # No image needed for the target image
86+
],
87+
]
88+
89+
# Task and content prompt
90+
task_prompt = "In each row, a logical task is demonstrated to achieve [IMAGE2] an aesthetically pleasing photograph based on [IMAGE1] sam 2-generated masks with rich color coding."
91+
content_prompt = """Majestic photo of a golden eagle perched on a rocky outcrop in a mountainous landscape.
92+
The eagle is positioned in the right foreground, facing left, with its sharp beak and keen eyes prominently visible.
93+
Its plumage is a mix of dark brown and golden hues, with intricate feather details.
94+
The background features a soft-focus view of snow-capped mountains under a cloudy sky, creating a serene and grandiose atmosphere.
95+
The foreground includes rugged rocks and patches of green moss. Photorealistic, medium depth of field,
96+
soft natural lighting, cool color palette, high contrast, sharp focus on the eagle, blurred background,
97+
tranquil, majestic, wildlife photography."""
98+
99+
# Run the pipeline
100+
image_result = pipe(
101+
task_prompt=task_prompt,
102+
content_prompt=content_prompt,
103+
image=image_paths,
104+
upsampling_width=1344,
105+
upsampling_height=768,
106+
upsampling_strength=0.4,
107+
guidance_scale=30,
108+
num_inference_steps=30,
109+
max_sequence_length=512,
110+
generator=torch.Generator("cpu").manual_seed(0)
111+
).images[0][0]
112+
113+
# Save the resulting image
114+
image_result.save("visualcloze.png")
115+
```
116+
117+
#### Example for edge-detection
118+
119+
```python
120+
import torch
121+
from diffusers import VisualClozePipeline
122+
from diffusers.utils import load_image
123+
124+
pipe = VisualClozePipeline.from_pretrained("VisualCloze/VisualClozePipeline-384", resolution=384, torch_dtype=torch.bfloat16)
125+
pipe.to("cuda")
126+
127+
# Load in-context images (make sure the paths are correct and accessible)
128+
image_paths = [
129+
# in-context examples
130+
[
131+
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_edgedetection_incontext-example-1_image.jpg'),
132+
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_edgedetection_incontext-example-1_edge.jpg'),
133+
],
134+
[
135+
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_edgedetection_incontext-example-2_image.jpg'),
136+
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_edgedetection_incontext-example-2_edge.jpg'),
137+
],
138+
# query with the target image
139+
[
140+
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_edgedetection_query_image.jpg'),
141+
None, # No image needed for the target image
142+
],
143+
]
144+
145+
# Task and content prompt
146+
task_prompt = "Each row illustrates a pathway from [IMAGE1] a sharp and beautifully composed photograph to [IMAGE2] edge map with natural well-connected outlines using a clear logical task."
147+
content_prompt = ""
148+
149+
# Run the pipeline
150+
image_result = pipe(
151+
task_prompt=task_prompt,
152+
content_prompt=content_prompt,
153+
image=image_paths,
154+
upsampling_width=864,
155+
upsampling_height=1152,
156+
upsampling_strength=0.4,
157+
guidance_scale=30,
158+
num_inference_steps=30,
159+
max_sequence_length=512,
160+
generator=torch.Generator("cpu").manual_seed(0)
161+
).images[0][0]
162+
163+
# Save the resulting image
164+
image_result.save("visualcloze.png")
165+
```
166+
167+
#### Example for subject-driven generation
168+
169+
```python
170+
import torch
171+
from diffusers import VisualClozePipeline
172+
from diffusers.utils import load_image
173+
174+
pipe = VisualClozePipeline.from_pretrained("VisualCloze/VisualClozePipeline-384", resolution=384, torch_dtype=torch.bfloat16)
175+
pipe.to("cuda")
176+
177+
# Load in-context images (make sure the paths are correct and accessible)
178+
image_paths = [
179+
# in-context examples
180+
[
181+
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_subjectdriven_incontext-example-1_reference.jpg'),
182+
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_subjectdriven_incontext-example-1_depth.jpg'),
183+
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_subjectdriven_incontext-example-1_image.jpg'),
184+
],
185+
[
186+
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_subjectdriven_incontext-example-2_reference.jpg'),
187+
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_subjectdriven_incontext-example-2_depth.jpg'),
188+
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_subjectdriven_incontext-example-2_image.jpg'),
189+
],
190+
# query with the target image
191+
[
192+
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_subjectdriven_query_reference.jpg'),
193+
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_subjectdriven_query_depth.jpg'),
194+
None, # No image needed for the target image
195+
],
196+
]
197+
198+
# Task and content prompt
199+
task_prompt = """Each row describes a process that begins with [IMAGE1] an image containing the key object,
200+
[IMAGE2] depth map revealing gray-toned spatial layers and results in
201+
[IMAGE3] an image with artistic qualitya high-quality image with exceptional detail."""
202+
content_prompt = """A vintage porcelain collector's item. Beneath a blossoming cherry tree in early spring,
203+
this treasure is photographed up close, with soft pink petals drifting through the air and vibrant blossoms framing the scene."""
204+
205+
# Run the pipeline
206+
image_result = pipe(
207+
task_prompt=task_prompt,
208+
content_prompt=content_prompt,
209+
image=image_paths,
210+
upsampling_width=1024,
211+
upsampling_height=1024,
212+
upsampling_strength=0.2,
213+
guidance_scale=30,
214+
num_inference_steps=30,
215+
max_sequence_length=512,
216+
generator=torch.Generator("cpu").manual_seed(0)
217+
).images[0][0]
218+
219+
# Save the resulting image
220+
image_result.save("visualcloze.png")
221+
```
222+
223+
#### Utilize each pipeline independently
224+
225+
```python
226+
import torch
227+
from diffusers import VisualClozeGenerationPipeline, FluxFillPipeline as VisualClozeUpsamplingPipeline
228+
from diffusers.utils import load_image
229+
from PIL import Image
230+
231+
pipe = VisualClozeGenerationPipeline.from_pretrained(
232+
"VisualCloze/VisualClozePipeline-384", resolution=384, torch_dtype=torch.bfloat16
233+
)
234+
pipe.to("cuda")
235+
236+
image_paths = [
237+
# in-context examples
238+
[
239+
load_image(
240+
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_incontext-example-1_mask.jpg"
241+
),
242+
load_image(
243+
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_incontext-example-1_image.jpg"
244+
),
245+
],
246+
# query with the target image
247+
[
248+
load_image(
249+
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_query_mask.jpg"
250+
),
251+
None, # No image needed for the target image
252+
],
253+
]
254+
task_prompt = "In each row, a logical task is demonstrated to achieve [IMAGE2] an aesthetically pleasing photograph based on [IMAGE1] sam 2-generated masks with rich color coding."
255+
content_prompt = "Majestic photo of a golden eagle perched on a rocky outcrop in a mountainous landscape. The eagle is positioned in the right foreground, facing left, with its sharp beak and keen eyes prominently visible. Its plumage is a mix of dark brown and golden hues, with intricate feather details. The background features a soft-focus view of snow-capped mountains under a cloudy sky, creating a serene and grandiose atmosphere. The foreground includes rugged rocks and patches of green moss. Photorealistic, medium depth of field, soft natural lighting, cool color palette, high contrast, sharp focus on the eagle, blurred background, tranquil, majestic, wildlife photography."
256+
257+
# Stage 1: Generate initial image
258+
image = pipe(
259+
task_prompt=task_prompt,
260+
content_prompt=content_prompt,
261+
image=image_paths,
262+
guidance_scale=30,
263+
num_inference_steps=30,
264+
max_sequence_length=512,
265+
generator=torch.Generator("cpu").manual_seed(0),
266+
).images[0][0]
267+
268+
# Stage 2 (optional): Upsample the generated image
269+
pipe_upsample = VisualClozeUpsamplingPipeline.from_pipe(pipe)
270+
pipe_upsample.to("cuda")
271+
272+
mask_image = Image.new("RGB", image.size, (255, 255, 255))
273+
274+
image = pipe_upsample(
275+
image=image,
276+
mask_image=mask_image,
277+
prompt=content_prompt,
278+
width=1344,
279+
height=768,
280+
strength=0.4,
281+
guidance_scale=30,
282+
num_inference_steps=30,
283+
max_sequence_length=512,
284+
generator=torch.Generator("cpu").manual_seed(0),
285+
).images[0]
286+
287+
image.save("visualcloze.png")
288+
```
289+
290+
## VisualClozePipeline
291+
292+
[[autodoc]] VisualClozePipeline
293+
- all
294+
- __call__
295+
296+
## VisualClozeGenerationPipeline
297+
298+
[[autodoc]] VisualClozeGenerationPipeline
299+
- all
300+
- __call__

src/diffusers/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -520,6 +520,8 @@
520520
"VersatileDiffusionPipeline",
521521
"VersatileDiffusionTextToImagePipeline",
522522
"VideoToVideoSDPipeline",
523+
"VisualClozeGenerationPipeline",
524+
"VisualClozePipeline",
523525
"VQDiffusionPipeline",
524526
"WanImageToVideoPipeline",
525527
"WanPipeline",
@@ -1100,6 +1102,8 @@
11001102
VersatileDiffusionPipeline,
11011103
VersatileDiffusionTextToImagePipeline,
11021104
VideoToVideoSDPipeline,
1105+
VisualClozeGenerationPipeline,
1106+
VisualClozePipeline,
11031107
VQDiffusionPipeline,
11041108
WanImageToVideoPipeline,
11051109
WanPipeline,

src/diffusers/pipelines/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,7 @@
281281
_import_structure["mochi"] = ["MochiPipeline"]
282282
_import_structure["musicldm"] = ["MusicLDMPipeline"]
283283
_import_structure["omnigen"] = ["OmniGenPipeline"]
284+
_import_structure["visualcloze"] = ["VisualClozePipeline", "VisualClozeGenerationPipeline"]
284285
_import_structure["paint_by_example"] = ["PaintByExamplePipeline"]
285286
_import_structure["pia"] = ["PIAPipeline"]
286287
_import_structure["pixart_alpha"] = ["PixArtAlphaPipeline", "PixArtSigmaPipeline"]
@@ -727,6 +728,7 @@
727728
UniDiffuserPipeline,
728729
UniDiffuserTextDecoder,
729730
)
731+
from .visualcloze import VisualClozeGenerationPipeline, VisualClozePipeline
730732
from .wan import WanImageToVideoPipeline, WanPipeline, WanVideoToVideoPipeline
731733
from .wuerstchen import (
732734
WuerstchenCombinedPipeline,

0 commit comments

Comments
 (0)