Skip to content

Commit 9dd7d3d

Browse files
committed
finetuning code.
1 parent 74ae950 commit 9dd7d3d

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

49 files changed

+2033
-228
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause

datasets:
  blip_diffusion_finetune: # name of the dataset builder
    # data_dir: ${env.data_dir}/datasets
    data_type: images # [images|videos|features]

    build_info:
      # Be careful not to append a minus sign (-) before "images" to avoid itemizing
      images:
        # Directory holding the subject images; intended to be overridden
        # by the user's run config (empty by default).
        storage: ""

lavis/configs/models/blip-diffusion/blip_diffusion_base.yaml

+2-2
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@ model:
1414
preprocess:
1515
vis_processor:
1616
train:
17-
name: "blip_diffusion_image_eval"
17+
name: "blip_diffusion_inp_image_eval"
1818
eval:
19-
name: "blip_diffusion_image_eval"
19+
name: "blip_diffusion_inp_image_eval"
2020
text_processor:
2121
train:
2222
name: "blip_caption"

lavis/datasets/builders/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,12 @@
3737
Flickr30kBuilder,
3838
)
3939
from lavis.datasets.builders.dialogue_builder import AVSDDialBuilder
40+
from lavis.datasets.builders.text_to_image_generation_builder import BlipDiffusionFinetuneBuilder
4041

4142
from lavis.common.registry import registry
4243

4344
__all__ = [
45+
"BlipDiffusionFinetuneBuilder",
4446
"COCOCapBuilder",
4547
"COCORetrievalBuilder",
4648
"COCOVQABuilder",

lavis/datasets/builders/base_dataset_builder.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@ def __init__(self, cfg=None):
4040
self.vis_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}
4141
self.text_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}
4242

43+
# additional processors, each specified by a name in string.
44+
self.kw_processors = {}
45+
4346
def build_datasets(self):
4447
# download, split, etc...
4548
# only called on 1 GPU/TPU in distributed
@@ -73,7 +76,12 @@ def build_processors(self):
7376

7477
self.text_processors["train"] = self._build_proc_from_cfg(txt_train_cfg)
7578
self.text_processors["eval"] = self._build_proc_from_cfg(txt_eval_cfg)
76-
79+
80+
kw_proc_cfg = self.config.get("kw_processor")
81+
if kw_proc_cfg is not None:
82+
for name, cfg in kw_proc_cfg.items():
83+
self.kw_processors[name] = self._build_proc_from_cfg(cfg)
84+
7785
@staticmethod
7886
def _build_proc_from_cfg(cfg):
7987
return (
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
"""
2+
Copyright (c) 2022, salesforce.com, inc.
3+
All rights reserved.
4+
SPDX-License-Identifier: BSD-3-Clause
5+
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6+
"""
7+
8+
from lavis.common.registry import registry
9+
from lavis.datasets.datasets.subject_driven_t2i_dataset import (
10+
SubjectDrivenTextToImageDataset,
11+
)
12+
from lavis.datasets.builders.base_dataset_builder import BaseDatasetBuilder
13+
14+
@registry.register_builder("blip_diffusion_finetune")
class BlipDiffusionFinetuneBuilder(BaseDatasetBuilder):
    """Builder for the subject-driven text-to-image finetuning dataset
    used by BLIP-Diffusion.

    Unlike most builders, it has no annotations to download: the user
    supplies a local directory of subject images via the config's
    ``build_info.images.storage`` and a ``build_info.subject_text``.
    """

    train_dataset_cls = SubjectDrivenTextToImageDataset

    DATASET_CONFIG_DICT = {
        "default": "configs/datasets/blip_diffusion_datasets/defaults.yaml"
    }

    def _download_ann(self):
        # Images are provided locally by the user; nothing to download.
        pass

    def build(self):
        """Build and return the datasets.

        Returns:
            dict: ``{"train": SubjectDrivenTextToImageDataset}`` — this
            task has a training split only.
        """
        self.build_processors()

        info = self.config.build_info

        # kw_processors are the extra named processors declared under
        # "kw_processor" in the config (see BaseDatasetBuilder).
        return {
            "train": self.train_dataset_cls(
                image_dir=info.images.storage,
                subject_text=info.subject_text,
                inp_image_processor=self.kw_processors["inp_vis_processor"],
                tgt_image_processor=self.kw_processors["tgt_vis_processor"],
                txt_processor=self.text_processors["eval"],
            )
        }
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
"""
2+
Copyright (c) 2022, salesforce.com, inc.
3+
All rights reserved.
4+
SPDX-License-Identifier: BSD-3-Clause
5+
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6+
"""
7+
8+
import os
9+
10+
from PIL import Image
11+
from torch.utils.data import Dataset
12+
from torch.utils.data.dataloader import default_collate
13+
14+
class SubjectDrivenTextToImageDataset(Dataset):
    """Subject-driven text-to-image finetuning dataset (BLIP-Diffusion).

    Wraps a small directory of images of a single subject and presents it
    as a virtually large dataset (each image repeated ``repetition`` times)
    so that one "epoch" covers many optimization steps.
    """

    # Recognized image extensions; compared case-insensitively, so "JPG",
    # "Jpg", "WebP" etc. are all accepted.
    _IMAGE_EXTS = frozenset({"jpg", "jpeg", "png", "webp"})

    def __init__(
        self,
        image_dir,
        subject_text,
        inp_image_processor,
        tgt_image_processor,
        txt_processor,
        repetition=100000,
    ):
        """
        Args:
            image_dir (str): directory containing the subject images.
            subject_text (str): the subject word/phrase, e.g. "dog";
                lowercased and passed through ``txt_processor``.
            inp_image_processor: transform for the conditioning (input) image.
            tgt_image_processor: transform for the generation target image.
            txt_processor: text processor applied to the subject and caption.
            repetition (int): virtual multiplier applied to the dataset length.

        Raises:
            ValueError: if ``image_dir`` contains no recognized images.
        """
        self.subject = txt_processor(subject_text.lower())
        self.image_dir = image_dir

        self.inp_image_transform = inp_image_processor
        self.tgt_image_transform = tgt_image_processor

        self.text_processor = txt_processor

        image_paths = [
            os.path.join(image_dir, fname)
            for fname in os.listdir(image_dir)
            # strip the leading dot from the extension, match case-insensitively
            if os.path.splitext(fname)[1][1:].lower() in self._IMAGE_EXTS
        ]
        if not image_paths:
            # Fail fast with a clear message instead of a ZeroDivisionError
            # later in __getitem__ (index % 0).
            raise ValueError(
                f"No images with extensions {sorted(self._IMAGE_EXTS)} found in {image_dir}."
            )
        # Absolute, sorted paths: os.listdir order is arbitrary, so sorting
        # makes index -> image mapping deterministic across runs.
        self.image_paths = sorted(os.path.abspath(p) for p in image_paths)
        self.repetition = repetition

    def __len__(self):
        # Virtual length: each real image appears `repetition` times.
        return len(self.image_paths) * self.repetition

    @property
    def len_without_repeat(self):
        """Number of distinct images on disk."""
        return len(self.image_paths)

    def collater(self, samples):
        return default_collate(samples)

    def __getitem__(self, index):
        # Wrap the virtual index back onto the real images.
        image_path = self.image_paths[index % len(self.image_paths)]
        image = Image.open(image_path).convert("RGB")

        # For fine-tuning, we use the same caption for all images
        # maybe worth trying different captions for different images
        caption = self.text_processor(f"a {self.subject}")

        return {
            "inp_image": self.inp_image_transform(image),
            "tgt_image": self.tgt_image_transform(image),
            "caption": caption,
            "subject_text": self.subject,
        }

lavis/models/base_model.py

+2
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,8 @@ def load_checkpoint_from_config(self, cfg, **kwargs):
101101
assert "Found load_finetuned is False, but pretrain_path is None."
102102
self.load_from_pretrained(url_or_filename=pretrain_path, **kwargs)
103103

    def before_training(self, **kwargs):
        # Hook invoked once right before the training loop starts; a no-op
        # by default, for subclasses to override (counterpart of
        # before_evaluation below).
        pass
104106

105107
def before_evaluation(self, **kwargs):
106108
pass

0 commit comments

Comments
 (0)