diff --git a/dalle_pytorch/dalle_pytorch.py b/dalle_pytorch/dalle_pytorch.py
index fe294465..76c2e254 100644
--- a/dalle_pytorch/dalle_pytorch.py
+++ b/dalle_pytorch/dalle_pytorch.py
@@ -55,6 +55,20 @@ def top_k(logits, thres = 0.5):
     probs.scatter_(1, ind, val)
     return probs
 
+class SharedEmbedding(nn.Embedding):
+    def __init__(self, linear, start_index, end_index, **kwargs):
+        super().__init__(end_index - start_index, linear.weight.shape[1], **kwargs)
+        del self.weight
+
+        self.linear = linear
+        self.start_index = start_index
+        self.end_index = end_index
+
+    def forward(self, input):
+        return F.embedding(
+            input, self.linear.weight[self.start_index:self.end_index], self.padding_idx, self.max_norm,
+            self.norm_type, self.scale_grad_by_freq, self.sparse)
+
 # discrete vae class
 
 class ResBlock(nn.Module):
@@ -326,7 +340,10 @@ def __init__(
         stable = False,
         sandwich_norm = False,
         shift_tokens = True,
-        rotary_emb = True
+        rotary_emb = True,
+        shared_attn_ids = None,
+        shared_ff_ids = None,
+        share_input_output_emb = False,
     ):
         super().__init__()
         assert isinstance(vae, (DiscreteVAE, OpenAIDiscreteVAE, VQGanVAE)), 'vae must be an instance of DiscreteVAE'
@@ -338,9 +355,6 @@ def __init__(
 
         num_text_tokens = num_text_tokens + text_seq_len # reserve unique padding tokens for each position (text seq len)
 
-        self.text_emb = nn.Embedding(num_text_tokens, dim)
-        self.image_emb = nn.Embedding(num_image_tokens, dim)
-
         self.text_pos_emb = nn.Embedding(text_seq_len + 1, dim) if not rotary_emb else always(0) # +1 for <bos>
         self.image_pos_emb = AxialPositionalEmbedding(dim, axial_shape = (image_fmap_size, image_fmap_size)) if not rotary_emb else always(0)
 
@@ -374,7 +388,9 @@ def __init__(
             stable = stable,
             sandwich_norm = sandwich_norm,
             shift_tokens = shift_tokens,
-            rotary_emb = rotary_emb
+            rotary_emb = rotary_emb,
+            shared_attn_ids = shared_attn_ids,
+            shared_ff_ids = shared_ff_ids,
         )
 
         self.stable = stable
@@ -387,6 +403,13 @@ def __init__(
             nn.Linear(dim, self.total_tokens),
         )
 
+        if share_input_output_emb:
+            self.text_emb = SharedEmbedding(self.to_logits[1], 0, num_text_tokens)
+            self.image_emb = SharedEmbedding(self.to_logits[1], num_text_tokens, total_tokens)
+        else:
+            self.text_emb = nn.Embedding(num_text_tokens, dim)
+            self.image_emb = nn.Embedding(num_image_tokens, dim)
+
         seq_range = torch.arange(seq_len)
         logits_range = torch.arange(total_tokens)
 
@@ -417,7 +440,7 @@ def generate_texts(
             text_tokens = torch.tensor([[0]]).cuda()
         else:
             text_tokens = torch.tensor(tokenizer.tokenizer.encode(text)).cuda().unsqueeze(0)
-        
+
         for _ in range(text_tokens.shape[1], text_seq_len):
             device = text_tokens.device
 
@@ -443,9 +466,9 @@ def generate_texts(
             filtered_logits = top_k(logits, thres = filter_thres)
             probs = F.softmax(filtered_logits / temperature, dim = -1)
             sample = torch.multinomial(probs, 1)
-            
+
             text_tokens = torch.cat((text_tokens, sample), dim=-1)
-            
+
         padding_tokens = set(np.arange(self.text_seq_len) + (self.num_text_tokens - self.text_seq_len))
         texts = [tokenizer.tokenizer.decode(text_token, pad_tokens=padding_tokens) for text_token in text_tokens]
         return text_tokens, texts
diff --git a/dalle_pytorch/transformer.py b/dalle_pytorch/transformer.py
index 81bc7b4a..c7322a4c 100644
--- a/dalle_pytorch/transformer.py
+++ b/dalle_pytorch/transformer.py
@@ -1,3 +1,4 @@
+from collections.abc import Iterable
 from functools import partial
 from itertools import islice, cycle
 
@@ -21,9 +22,7 @@ def default(val, d):
     return val if exists(val) else d
 
 def cast_tuple(val, depth = 1):
-    if isinstance(val, list):
-        val = tuple(val)
-    return val if isinstance(val, tuple) else (val,) * depth
+    return val if isinstance(val, Iterable) else (val,) * depth
 
 # classes
 
@@ -150,7 +149,9 @@ def __init__(
         stable = False,
         sandwich_norm = False,
         shift_tokens = False,
-        rotary_emb = True
+        rotary_emb = True,
+        shared_attn_ids = None,
+        shared_ff_ids = None,
     ):
         super().__init__()
         layers = nn.ModuleList([])
@@ -160,7 +161,13 @@ def __init__(
         attn_types = cast_tuple(attn_types)
         attn_type_layer = islice(cycle(attn_types), depth)
 
-        for ind, sparse_attn, attn_type in zip(range(depth), sparse_layer, attn_type_layer):
+        shared_attn_ids = cycle(default(shared_attn_ids, range(depth)))
+        shared_ff_ids = cycle(default(shared_ff_ids, range(depth)))
+        shared_attn_layers = {}
+        shared_ff_layers = {}
+
+        for (ind, sparse_attn, attn_type, attn_id, ff_id) in \
+            zip(range(depth), sparse_layer, attn_type_layer, shared_attn_ids, shared_ff_ids):
             if attn_type == 'full':
                 attn_class = partial(Attention, stable = stable)
             elif attn_type == 'sparse':
@@ -176,12 +183,21 @@ def __init__(
             else:
                 raise ValueError(f'attention type "{attn_type}" is not valid')
 
-            if attn_type != 'mlp':
-                attn = attn_class(dim, causal = causal, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout)
-            else:
-                attn = attn_class(dim = dim, causal = causal, dim_ff = dim * 4)
-
-            ff = FeedForward(dim, mult = ff_mult, dropout = ff_dropout)
+            attn, reused_attn_type = shared_attn_layers.get(attn_id, (None, None))
+            if not exists(attn):
+                if attn_type != 'mlp':
+                    attn = attn_class(dim, causal = causal, seq_len = seq_len, heads = heads, dim_head = dim_head, dropout = attn_dropout)
+                else:
+                    attn = attn_class(dim = dim, causal = causal, dim_ff = dim * 4)
+                shared_attn_layers[attn_id] = (attn, attn_type)
+            elif attn_type != reused_attn_type:
+                raise ValueError('attn_types do not match shared_attn_ids '
+                                 f'(ind = {ind}, attn_type = "{attn_type}", reused_attn_type = "{reused_attn_type}")')
+
+            ff = shared_ff_layers.get(ff_id)
+            if not exists(ff):
+                ff = FeedForward(dim, mult = ff_mult, dropout = ff_dropout)
+                shared_ff_layers[ff_id] = ff
 
             if shift_tokens:
                 attn, ff = map(lambda t: PreShiftToken(t, image_size = image_fmap_size, seq_len = seq_len), (attn, ff))
diff --git a/train_dalle.py b/train_dalle.py
index 7f07079c..0eaa11c9 100644
--- a/train_dalle.py
+++ b/train_dalle.py
@@ -46,9 +46,9 @@
                     help='path to your folder of images and text for learning the DALL-E')
 
 parser.add_argument(
-    '--wds', 
-    type = str, 
-    default='', 
+    '--wds',
+    type = str,
+    default='',
     help = 'Comma separated list of WebDataset (1) image and (2) text column names. Must contain 2 values, e.g. img,cap.'
 )
 
@@ -134,6 +134,10 @@
 
 model_group.add_argument('--rotary_emb', help = 'Use rotary embeddings', action = 'store_true')
 
+model_group.add_argument('--shared_attn_ids', default = None, type = str, help = 'Comma separated list of shared attention layer ids. Default: sharing is disabled')
+
+model_group.add_argument('--shared_ff_ids', default = None, type = str, help = 'Comma separated list of shared feed forward layer ids. Default: sharing is disabled')
+
 args = parser.parse_args()
 
 # helpers
@@ -191,6 +195,8 @@ def cp_path_to_dir(cp_path, tag):
 ROTARY_EMB = args.rotary_emb
 
 ATTN_TYPES = tuple(args.attn_types.split(','))
+SHARED_ATTN_IDS = tuple(args.shared_attn_ids.split(',')) if exists(args.shared_attn_ids) else None
+SHARED_FF_IDS = tuple(args.shared_ff_ids.split(',')) if exists(args.shared_ff_ids) else None
 
 DEEPSPEED_CP_AUX_FILENAME = 'auxiliary.pt'
 
@@ -303,6 +309,8 @@ def cp_path_to_dir(cp_path, tag):
     stable=STABLE,
     shift_tokens=SHIFT_TOKENS,
     rotary_emb=ROTARY_EMB,
+    shared_attn_ids=SHARED_ATTN_IDS,
+    shared_ff_ids=SHARED_FF_IDS,
 )
 
 resume_epoch = 0
@@ -368,7 +376,7 @@ def filter_dataset(item): # For e.g. C@H which (rarely) has no caption available
         if myimg not in item:
             return False
         return True
-    
+
     w_dataset = wds.WebDataset(DATASET, handler=wds.warn_and_continue)
     filtered_dataset = w_dataset.select(filter_dataset)
     ds = filtered_dataset.map_dict(**image_text_mapping).map_dict(**image_mapping).to_tuple(mycap, myimg).batched(BATCH_SIZE, partial=True)
@@ -600,7 +608,7 @@ def save_model(path, epoch=0):
 
         if i % SAVE_EVERY_N_STEPS == 0:
            save_model(DALLE_OUTPUT_FILE_NAME, epoch=epoch)
-        
+
         if i % 100 == 0:
             if distr_backend.is_root_worker():
                 sample_text = text[:1]
@@ -633,7 +641,7 @@ def save_model(path, epoch=0):
         distr_scheduler.step(avg_loss)
 
     save_model(DALLE_OUTPUT_FILE_NAME, epoch=epoch)
-    
+
     if distr_backend.is_root_worker():
         # save trained model to wandb as an artifact every epoch's end
 
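
Usage sketch (not part of the diff): the snippet below shows how the new layer-sharing and shared input/output embedding options introduced above might be wired up from Python. The model sizes and id patterns are illustrative assumptions, not values taken from this change.

import torch
from dalle_pytorch import DiscreteVAE, DALLE

# toy VAE purely for illustration (all sizes are assumptions)
vae = DiscreteVAE(
    image_size = 256,
    num_layers = 3,
    num_tokens = 1024
)

dalle = DALLE(
    dim = 512,
    vae = vae,
    num_text_tokens = 10000,
    text_seq_len = 128,
    depth = 4,
    heads = 8,
    shared_attn_ids = (0, 1, 0, 1),    # blocks 0/2 reuse one attention layer, blocks 1/3 another
    shared_ff_ids = (0, 0, 0, 0),      # all four blocks reuse a single feed forward layer
    share_input_output_emb = True      # tie token embeddings to the final linear projection
)

text = torch.randint(0, 10000, (1, 128))
images = torch.randn(1, 3, 256, 256)

loss = dalle(text, images, return_loss = True)
loss.backward()

The same attention/feed-forward sharing can be requested from the training script with, e.g., --shared_attn_ids 0,1,0,1 --shared_ff_ids 0,0,0,0; the comma separated ids are split into tuples and forwarded to the Transformer, where layers carrying equal ids reuse the same module. Note that share_input_output_emb is only exposed on the DALLE constructor in this change, not as a train_dalle.py flag.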