Commit 9199552

Merge pull request #5 from mkshing/v0.2.0
v0.2.0
2 parents 7f47d31 + c645374 commit 9199552

20 files changed: +3,727 −386 lines

README.md

Lines changed: 90 additions & 12 deletions
@@ -12,10 +12,17 @@ My summary tweet is found [here](https://twitter.com/mk1stats/status/16428655051
 left: LoRA, right: SVDiff
 
 
-Compared with LoRA, the number of trainable parameters is 0.6 M less parameters and the file size is only <1MB (LoRA: 3.1MB)!!
+Compared with LoRA, there are 0.5M fewer trainable parameters and the file size is only 1.2MB (LoRA: 3.1MB)!!
 
 ![kumamon](assets/kumamon.png)
 
+## Updates
+### 2023.4.11
+- Released v0.2.0 (please see [here](https://github.com/mkshing/svdiff-pytorch/releases/tag/v0.2.0) for the details)
+- Add [Single Image Editing](#single-image-editing)
+![chair-result](assets/chair-result.png)
+<br>"photo of a ~~pink~~ blue chair with black legs"
+
 ## Installation
 ```
 $ pip install svdiff-pytorch
@@ -26,9 +33,10 @@ $ git clone https://github.com/mkshing/svdiff-pytorch
 $ pip install -r requirements.txt
 ```
 
-## Training
-The following example script is for "Single-Subject Generation", which is a domain-tuning on a single object or concept (using 3-5 images). (See Section 4.1)
+## Single-Subject Generation
+"Single-Subject Generation" is domain tuning on a single object or concept, using 3-5 images. (See Section 4.1)
 
+### Training
 According to the paper, the learning rate for SVDiff needs to be 1000 times larger than the lr used for fine-tuning.
 
 ```bash
@@ -48,29 +56,32 @@ accelerate launch train_svdiff.py \
   --resolution=512 \
   --train_batch_size=1 \
   --gradient_accumulation_steps=1 \
-  --learning_rate=5e-3 \
+  --learning_rate=1e-3 \
+  --learning_rate_1d=1e-6 \
+  --train_text_encoder \
   --lr_scheduler="constant" \
   --lr_warmup_steps=0 \
   --num_class_images=200 \
-  --max_train_steps=800
+  --max_train_steps=500
 ```
 
-
-## Inference
+### Inference
 
 ```python
 from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
 import torch
 
-from svdiff_pytorch import load_unet_for_svdiff
+from svdiff_pytorch import load_unet_for_svdiff, load_text_encoder_for_svdiff
 
 pretrained_model_name_or_path = "runwayml/stable-diffusion-v1-5"
-spectral_shifts_ckpt = "spectral_shifts.safetensors-path"
-unet = load_unet_for_svdiff(pretrained_model_name_or_path, spectral_shifts_ckpt=spectral_shifts_ckpt, subfolder="unet")
+spectral_shifts_ckpt_dir = "ckpt-dir-path"
+unet = load_unet_for_svdiff(pretrained_model_name_or_path, spectral_shifts_ckpt=spectral_shifts_ckpt_dir, subfolder="unet")
+text_encoder = load_text_encoder_for_svdiff(pretrained_model_name_or_path, spectral_shifts_ckpt=spectral_shifts_ckpt_dir, subfolder="text_encoder")
 # load pipe
 pipe = StableDiffusionPipeline.from_pretrained(
     pretrained_model_name_or_path,
     unet=unet,
+    text_encoder=text_encoder,
 )
 pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
 pipe.to("cuda")
@@ -82,14 +93,14 @@ You can use the following CLI too! Once it's done, you will see `grid.png` for t
 ```bash
 python inference.py \
   --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
-  --spectral_shifts_ckpt="spectral_shifts.safetensors-path" \
+  --spectral_shifts_ckpt="ckpt-dir-path" \
   --prompt="A picture of a sks dog in a bucket" \
   --scheduler_type="dpm_solver++" \
   --num_inference_steps=25 \
   --num_images_per_prompt=2
 ```
 
-## Gradio
+### Gradio
 You can also try SVDiff-pytorch in a UI with [gradio](https://gradio.app/). This demo supports both training and inference!
 
 [![Open in Spaces](https://huggingface.co/datasets/huggingface/badges/raw/main/open-in-hf-spaces-sm.svg)](https://huggingface.co/spaces/svdiff-library/SVDiff-Training-UI)
@@ -103,7 +114,73 @@ $ export HF_TOKEN="YOUR_HF_TOKEN_HERE"
 $ python app.py
 ```
 
+## Single Image Editing
+### Training
+In Single Image Editing, your instance prompt should be just the description of your input image **without the identifier**.
+
+```bash
+export MODEL_NAME="runwayml/stable-diffusion-v1-5"
+export INSTANCE_DIR="dir-path-to-input-image"
+export CLASS_DIR="path-to-class-images"
+export OUTPUT_DIR="path-to-save-model"
+
+accelerate launch train_svdiff.py \
+  --pretrained_model_name_or_path=$MODEL_NAME \
+  --instance_data_dir=$INSTANCE_DIR \
+  --class_data_dir=$CLASS_DIR \
+  --output_dir=$OUTPUT_DIR \
+  --with_prior_preservation --prior_loss_weight=1.0 \
+  --instance_prompt="photo of a pink chair with black legs" \
+  --class_prompt="photo of a chair" \
+  --resolution=512 \
+  --train_batch_size=1 \
+  --gradient_accumulation_steps=1 \
+  --learning_rate=1e-3 \
+  --learning_rate_1d=1e-6 \
+  --train_text_encoder \
+  --lr_scheduler="constant" \
+  --lr_warmup_steps=0 \
+  --num_class_images=200 \
+  --max_train_steps=500
+```
+
+### Inference
+
+```python
+import torch
+from PIL import Image
+from diffusers import DDIMScheduler
+from svdiff_pytorch import load_unet_for_svdiff, load_text_encoder_for_svdiff, StableDiffusionPipelineWithDDIMInversion
+
+pretrained_model_name_or_path = "runwayml/stable-diffusion-v1-5"
+spectral_shifts_ckpt_dir = "ckpt-dir-path"
+image = "path-to-image"
+source_prompt = "prompt-for-image"
+target_prompt = "prompt-you-want-to-generate"
+
+unet = load_unet_for_svdiff(pretrained_model_name_or_path, spectral_shifts_ckpt=spectral_shifts_ckpt_dir, subfolder="unet")
+text_encoder = load_text_encoder_for_svdiff(pretrained_model_name_or_path, spectral_shifts_ckpt=spectral_shifts_ckpt_dir, subfolder="text_encoder")
+# load pipe
+pipe = StableDiffusionPipelineWithDDIMInversion.from_pretrained(
+    pretrained_model_name_or_path,
+    unet=unet,
+    text_encoder=text_encoder,
+)
+pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+pipe.to("cuda")
+
+# (optional) DDIM inversion
+# if you skip it, set inv_latents = None
+image = Image.open(image).convert("RGB").resize((512, 512))
+# SVDiff uses guidance_scale=1 in DDIM inversion
+inv_latents = pipe.invert(source_prompt, image=image, guidance_scale=1.0).latents
+
+image = pipe(target_prompt, latents=inv_latents).images[0]
+```
+
+
 ## Additional Features
+
 ### Spectral Shift Scaling
 
 ![scale](assets/scale.png)
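
For context, a conceptual sketch of what spectral-shift scaling does, assuming the SVD reparameterization described in the SVDiff paper (all names here are illustrative, not the library's internals):

```python
import torch

# Conceptual sketch, not the library's internals: SVDiff freezes a weight
# matrix W = U @ diag(sigma) @ Vt (its SVD) and trains only the 1-D
# "spectral shift" delta; scaling delta at inference interpolates the edit.
def scaled_weight(U, sigma, Vt, delta, scale=1.0):
    # ReLU keeps the shifted singular values non-negative, as in the paper
    return U @ torch.diag(torch.relu(sigma + scale * delta)) @ Vt

W = torch.randn(8, 8)
U, sigma, Vt = torch.linalg.svd(W)
delta = 0.01 * torch.randn_like(sigma)  # stands in for a trained shift
W_edit = scaled_weight(U, sigma, Vt, delta, scale=0.8)  # 0.0 = original weight, 1.0 = full edit
```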
@@ -165,6 +242,7 @@ And, add `--enable_tome_merging` to your training arguments!
 - [x] Training
 - [x] Inference
 - [x] Scaling spectral shifts
+- [x] Support Single Image Editing
 - [ ] Support multiple spectral shifts (Section 3.2)
 - [ ] Cut-Mix-Unmix (Section 3.3)
 - [ ] SVDiff + LoRA

assets/chair-result.png

490 KB

inference.py

Lines changed: 45 additions & 2 deletions
@@ -1,10 +1,13 @@
 import argparse
+import os
 from tqdm import tqdm
 import random
 import torch
+import huggingface_hub
+from transformers import CLIPTextModel
 from diffusers import StableDiffusionPipeline
 from diffusers.utils import is_xformers_available
-from svdiff_pytorch import load_unet_for_svdiff, SCHEDULER_MAPPING, image_grid
+from svdiff_pytorch import load_unet_for_svdiff, load_text_encoder_for_svdiff, SCHEDULER_MAPPING, image_grid
 
 
 def parse_args():
@@ -14,7 +17,7 @@ def parse_args():
     # diffusers config
     parser.add_argument("--prompt", type=str, nargs="?", default="a photo of *s", help="the prompt to render")
     parser.add_argument("--num_inference_steps", type=int, default=50, help="number of sampling steps")
-    parser.add_argument("--guidance_scale", type=float, default=1.0, help="unconditional guidance scale")
+    parser.add_argument("--guidance_scale", type=float, default=7.5, help="unconditional guidance scale")
     parser.add_argument("--num_images_per_prompt", type=int, default=1, help="number of images per prompt")
     parser.add_argument("--height", type=int, default=512, help="image height, in pixel space",)
     parser.add_argument("--width", type=int, default=512, help="image width, in pixel space",)
@@ -27,6 +30,33 @@ def parse_args():
     return args
 
 
+def load_text_encoder(pretrained_model_name_or_path, spectral_shifts_ckpt, device, fp16=False, hf_hub_kwargs=None):
+    if os.path.isdir(spectral_shifts_ckpt):
+        spectral_shifts_ckpt = os.path.join(spectral_shifts_ckpt, "spectral_shifts_te.safetensors")
+    elif not os.path.exists(spectral_shifts_ckpt):
+        # download from hub
+        hf_hub_kwargs = {} if hf_hub_kwargs is None else hf_hub_kwargs
+        try:
+            spectral_shifts_ckpt = huggingface_hub.hf_hub_download(spectral_shifts_ckpt, filename="spectral_shifts_te.safetensors", **hf_hub_kwargs)
+        except huggingface_hub.utils.EntryNotFoundError:
+            return CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder="text_encoder", torch_dtype=torch.float16 if fp16 else None).to(device)
+    if not os.path.exists(spectral_shifts_ckpt):
+        return CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder="text_encoder", torch_dtype=torch.float16 if fp16 else None).to(device)
+    text_encoder = load_text_encoder_for_svdiff(
+        pretrained_model_name_or_path=pretrained_model_name_or_path,
+        spectral_shifts_ckpt=spectral_shifts_ckpt,
+        subfolder="text_encoder",
+    )
+    # first perform svd and cache
+    for module in text_encoder.modules():
+        if hasattr(module, "perform_svd"):
+            module.perform_svd()
+    if fp16:
+        text_encoder = text_encoder.to(device, dtype=torch.float16)
+    return text_encoder
+
+
 def main():
     args = parse_args()
     device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -40,10 +70,18 @@ def main():
             module.perform_svd()
     if args.fp16:
         unet = unet.to(device, dtype=torch.float16)
+    text_encoder = load_text_encoder(
+        pretrained_model_name_or_path=args.pretrained_model_name_or_path,
+        spectral_shifts_ckpt=args.spectral_shifts_ckpt,
+        fp16=args.fp16,
+        device=device
+    )
+
     # load pipe
     pipe = StableDiffusionPipeline.from_pretrained(
         args.pretrained_model_name_or_path,
         unet=unet,
+        text_encoder=text_encoder,
         requires_safety_checker=False,
         safety_checker=None,
         feature_extractor=None,
@@ -67,6 +105,11 @@ def main():
         for module in pipe.unet.modules():
             if hasattr(module, "set_scale"):
                 module.set_scale(scale=args.spectral_shifts_scale)
+        if not isinstance(pipe.text_encoder, CLIPTextModel):
+            for module in pipe.text_encoder.modules():
+                if hasattr(module, "set_scale"):
+                    module.set_scale(scale=args.spectral_shifts_scale)
+
         print(f"Set spectral_shifts_scale to {args.spectral_shifts_scale}!")
 
     if args.seed == "random_seed":

requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@ diffusers==0.14.0
 accelerate
 torchvision
 safetensors
-transformers>=4.25.1
+transformers>=4.25.1, <=4.27.3
 ftfy
 tensorboard
 Jinja2
