Hello, I'm training DALL-E with the script below (from a Jupyter notebook, so a `tokenizer` variable already exists).
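The tokenizer itself isn't shown in the post. As a minimal sketch only, assuming the Hugging Face `tokenizers` BPE API (the script below calls `tokenizer.encode(...).ids` and `tokenizer.decode(...)`, which matches it), it might have been built like this; the corpus file `captions.txt` is hypothetical:

```python
# Hypothetical sketch of a tokenizer compatible with the training script's
# tokenizer.encode(text).ids / tokenizer.decode(ids) calls.
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace

tokenizer = Tokenizer(BPE(unk_token = '[UNK]'))
tokenizer.pre_tokenizer = Whitespace()
trainer = BpeTrainer(vocab_size = 16384, special_tokens = ['[UNK]'])  # matches VOCAB_SIZE below
tokenizer.train(files = ['captions.txt'], trainer = trainer)          # hypothetical caption corpus
```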
Full script:

```python
from random import choice
from pathlib import Path
# torch
import torch
from torch.optim import Adam
from torch.nn.utils import clip_grad_norm_
import gc
# vision imports
from PIL import Image
from torchvision import transforms as T
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder
from torchvision.utils import make_grid, save_image
# dalle related classes and utils
from dalle_pytorch import OpenAIDiscreteVAE, DiscreteVAE, DALLE
# helpers
def exists(val):
    return val is not None
# constants
VAE_PATH = None # path to your trained discrete VAE
DALLE_PATH = None # path to your partially trained DALL-E
IMAGE_TEXT_FOLDER = "./dataset" # path to your folder of images and text for learning the DALL-E
RESUME = exists(DALLE_PATH)
EPOCHS = 10
BATCH_SIZE = 16
LEARNING_RATE = 1e-4
GRAD_CLIP_NORM = 0.5
MODEL_DIM = 256
VOCAB_SIZE = 16384
TEXT_SEQ_LEN = 64
DEPTH = 12
HEADS = 8
DIM_HEAD = 16
IS_REVERSIBLE = True
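# (editor's note: with the pretrained OpenAI VAE used below, IMAGE_SIZE is 256
# and each image becomes a 32x32 = 1024-token grid, so the transformer sees
# TEXT_SEQ_LEN + 1024 = 1088 tokens per sample)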
# reconstitute vae
if RESUME:
    dalle_path = Path(DALLE_PATH)
    assert dalle_path.exists(), 'DALL-E model file does not exist'
    loaded_obj = torch.load(str(dalle_path))
    dalle_params, vae_params, weights = loaded_obj['hparams'], loaded_obj['vae_params'], loaded_obj['weights']
    vae = DiscreteVAE(**vae_params)
    dalle_params = dict(
        vae = vae,
        **dalle_params
    )
    IMAGE_SIZE = vae_params['image_size']
else:
    if exists(VAE_PATH):
        vae_path = Path(VAE_PATH)
        assert vae_path.exists(), 'VAE model file does not exist'
        loaded_obj = torch.load(str(vae_path))
        vae_params, weights = loaded_obj['hparams'], loaded_obj['weights']
        vae = DiscreteVAE(**vae_params)
        vae.load_state_dict(weights)
    else:
        print("using OpenAI's pretrained VAE for encoding images to tokens")
        vae_params = None
        vae = OpenAIDiscreteVAE()
    IMAGE_SIZE = vae.image_size
    dalle_params = dict(
        vae = vae,
        num_text_tokens = VOCAB_SIZE,
        text_seq_len = TEXT_SEQ_LEN,
        dim = MODEL_DIM,
        depth = DEPTH,
        heads = HEADS,
        dim_head = DIM_HEAD,
        attn_types = ('full', 'sparse'),
        reversible = IS_REVERSIBLE
    )
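# (editor's note: dalle_pytorch cycles attn_types across the DEPTH layers, so
# ('full', 'sparse') alternates full and sparse attention layer by layer)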
# helpers
def save_model(path):
    save_obj = {
        # drop the vae module from hparams so the RESUME branch above can
        # rebuild it from vae_params and re-inject it (saving the module
        # inside hparams would make dict(vae = vae, **dalle_params) fail
        # with a duplicate 'vae' key on resume)
        'hparams': {k: v for k, v in dalle_params.items() if k != 'vae'},
        'vae_params': vae_params,
        'weights': dalle.state_dict()
    }
    torch.save(save_obj, path)
# dataset loading
class TextImageDataset(Dataset):
    def __init__(self, folder, text_len = 256, image_size = 128):
        super().__init__()
        path = Path(folder)
        text_files = [*path.glob('**/*.txt')]
        image_files = [
            *path.glob('**/*.png'),
            *path.glob('**/*.jpg'),
            *path.glob('**/*.jpeg')
        ]
        text_files = {t.stem: t for t in text_files}
        image_files = {i.stem: i for i in image_files}
        keys = (image_files.keys() & text_files.keys())
        self.keys = list(keys)
        self.text_len = text_len
        self.text_files = {k: v for k, v in text_files.items() if k in keys}
        self.image_files = {k: v for k, v in image_files.items() if k in keys}
        self.image_transform = T.Compose([
            T.CenterCrop(image_size),
            T.Resize(image_size),
            T.ToTensor(),
            T.Lambda(lambda t: t.expand(3, -1, -1)),  # broadcast grayscale to 3 channels
            T.Normalize((0.5,) * 3, (0.5,) * 3)
        ])

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, ind):
        key = self.keys[ind]
        text_file = self.text_files[key]
        image_file = self.image_files[key]
        image = Image.open(image_file)
        descriptions = text_file.read_text().split('\n')
        descriptions = list(filter(lambda t: len(t) > 0, descriptions))
        description = choice(descriptions)
        # truncate, then pad with 0 up to text_len (the previous hard-coded 64
        # silently ignored the text_len argument)
        ids = tokenizer.encode(description).ids[:self.text_len]
        tokenized_text = torch.tensor(ids + [0] * (self.text_len - len(ids)))
        mask = tokenized_text != 0  # attention mask: True on real tokens, False on padding
        image_tensor = self.image_transform(image)
        return tokenized_text, image_tensor, mask
# create dataset and dataloader
ds = TextImageDataset(
    IMAGE_TEXT_FOLDER,
    text_len = TEXT_SEQ_LEN,
    image_size = IMAGE_SIZE
)
assert len(ds) > 0, 'dataset is empty'
print(f'{len(ds)} image-text pairs found for training')
dl = DataLoader(ds, batch_size = BATCH_SIZE, shuffle = True, drop_last = True)
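# (editor's note, optional sanity check: make_grid and save_image are imported
# above but never used; dumping one un-normalized batch verifies the
# crop/resize/normalize pipeline before a long run)
# text, images, mask = next(iter(dl))
# save_image(make_grid(images * 0.5 + 0.5, nrow = 4), './batch_preview.png')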
# initialize DALL-E
dalle = DALLE(**dalle_params).cuda()
if RESUME:
    dalle.load_state_dict(weights)
# optimizer
opt = Adam(dalle.parameters(), lr = LEARNING_RATE)
# experiment tracker
import wandb
wandb.init(project = 'kodalle', resume = RESUME)
wandb.config.depth = DEPTH
wandb.config.heads = HEADS
wandb.config.dim_head = DIM_HEAD
# training
for epoch in range(EPOCHS):
    for i, (text, images, mask) in enumerate(dl):
        text, images, mask = map(lambda t: t.cuda(), (text, images, mask))
        loss = dalle(text, images, mask = mask, return_loss = True)
        loss.backward()
        clip_grad_norm_(dalle.parameters(), GRAD_CLIP_NORM)
        opt.step()
        opt.zero_grad()
        log = {}
        if i % 10 == 0:
            print(epoch, i, f'loss - {loss.item()}')
            log = {
                **log,
                'epoch': epoch,
                'iter': i,
                'loss': loss.item()
            }
        if i % 1000 == 0:
            sample_text = text[:1]
            token_list = sample_text.tolist()
            decoded_text = tokenizer.decode(token_list[0])
            image = dalle.generate_images(
                text[:1],
                mask = mask[:1],
                filter_thres = 0.9  # top-k sampling at 0.9
            )
            save_model('./dalle.pt')
            wandb.save('./dalle.pt')
            log = {
                **log,
                'image': wandb.Image(image, caption = decoded_text)
            }
        wandb.log(log)
        del loss
        del log
        gc.collect()
save_model('./dalle-final.pt')
wandb.save('./dalle-final.pt')
```

Now about 7 epochs have passed, but the result is not so good:

[generated sample omitted; caption: "A woman standing on a kitchen counter with canned vegetables."]

Did I miss something, or is it a bug?
Answered by kjsman (Mar 3, 2021):
Likely because of #61