[WIP] PixArt-Sigma training pipeline #1341

Draft · wants to merge 18 commits into base: main
53 changes: 53 additions & 0 deletions library/nets/PA_utils.py
@@ -0,0 +1,53 @@
from collections.abc import Iterable
from itertools import repeat

import torch
from torch.utils.checkpoint import checkpoint

from mmcv import Registry  # mmcv 1.x; in mmcv 2.x Registry lives in mmengine

MODELS = Registry('models')

def _ntuple(n):
def parse(x):
if isinstance(x, Iterable) and not isinstance(x, str):
return x
return tuple(repeat(x, n))
return parse

to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
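# Illustrative behavior (not part of the diff): to_2tuple(7) -> (7, 7), while
# to_2tuple((3, 5)) -> (3, 5), since non-string iterables pass through as-is.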

def checkpoint_sequential(functions, step, input, *args, **kwargs):

    # Accept `preserve_rng_state` as a keyword-only argument in a
    # Python 2.7-compatible way (mirrors torch.utils.checkpoint_sequential).
    preserve = kwargs.pop('preserve_rng_state', True)
if kwargs:
raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))

def run_function(start, end, functions):
def forward(input):
for j in range(start, end + 1):
input = functions[j](input, *args)
return input
return forward

if isinstance(functions, torch.nn.Sequential):
functions = list(functions.children())

    # Run the last segment without checkpointing so the final activations
    # stay in the autograd graph.
end = -1
segment = len(functions) // step
for start in range(0, step * (segment - 1), step):
end = start + step - 1
input = checkpoint(run_function(start, end, functions), input, preserve_rng_state=preserve)
return run_function(end + 1, len(functions) - 1, functions)(input)
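
# Worked example of the segmentation above (illustrative): with
# len(functions) == 28 and step == 4, segment == 7, so the loop checkpoints
# blocks [0..3] through [20..23] and the final segment [24..27] runs
# unchanged, keeping the output in the autograd graph.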


def auto_grad_checkpoint(module, *args, **kwargs):
if getattr(module, 'grad_checkpointing', False):
if isinstance(module, Iterable):
gc_step = module[0].grad_checkpointing_step
return checkpoint_sequential(module, gc_step, *args, **kwargs)
else:
return checkpoint(module, *args, **kwargs)
return module(*args, **kwargs)
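
A minimal usage sketch of these helpers (the toy blocks are illustrative, but
`grad_checkpointing` and `grad_checkpointing_step` are the attribute names this
module actually reads):

import torch
import torch.nn as nn
from library.nets.PA_utils import auto_grad_checkpoint

blocks = nn.ModuleList([nn.Linear(16, 16) for _ in range(4)])
blocks.grad_checkpointing = True        # opt the whole list in
blocks[0].grad_checkpointing_step = 2   # checkpoint two blocks per segment

x = torch.randn(1, 16, requires_grad=True)
out = auto_grad_checkpoint(blocks, x)   # plain call when the flag is unset
out.sum().backward()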
313 changes: 313 additions & 0 deletions library/nets/PixArt.py
@@ -0,0 +1,313 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# GLIDE: https://github.com/openai/glide-text2im
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
# --------------------------------------------------------
import math
import torch
import torch.nn as nn
import os
import numpy as np
from timm.models.layers import DropPath
from timm.models.vision_transformer import PatchEmbed, Mlp

from .PA_utils import MODELS, to_2tuple, auto_grad_checkpoint

from .PixArt_blocks import t2i_modulate, CaptionEmbedder, AttentionKVCompress, MultiHeadCrossAttention, T2IFinalLayer, TimestepEmbedder, LabelEmbedder, FinalLayer

class PixArtBlock(nn.Module):
"""
A PixArt block with adaptive layer norm (adaLN-single) conditioning.
"""

def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0, input_size=None,
sampling=None, sr_ratio=1, qk_norm=False, **block_kwargs):
super().__init__()
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.attn = AttentionKVCompress(
hidden_size, num_heads=num_heads, qkv_bias=True, sampling=sampling, sr_ratio=sr_ratio,
qk_norm=qk_norm, **block_kwargs
)
self.cross_attn = MultiHeadCrossAttention(hidden_size, num_heads, **block_kwargs)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        # tanh-approximated GELU, for compatibility with older PyTorch versions
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.mlp = Mlp(in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5)
self.sampling = sampling
self.sr_ratio = sr_ratio

def forward(self, x, y, t, mask=None, **kwargs):
B, N, C = x.shape

shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None] + t.reshape(B, 6, -1)).chunk(6, dim=1)
x = x + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa)).reshape(B, N, C))
x = x + self.cross_attn(x, y, mask)
x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))

return x
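
    # Shape walkthrough of the adaLN-single modulation above (illustrative):
    # t arrives as (B, 6*C) from the shared t_block, is reshaped to (B, 6, C),
    # added to the learned (6, C) scale_shift_table, and chunked into six
    # (B, 1, C) tensors; t2i_modulate applies x * (1 + scale) + shift per
    # token, and gate_msa / gate_mlp scale the two residual branches.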


#################################################################################
#                               Core PixArt Model                               #
#################################################################################
@MODELS.register_module()
class PixArt(nn.Module):
"""
Diffusion model with a Transformer backbone.
"""

def __init__(
self,
input_size=32,
patch_size=2,
in_channels=4,
hidden_size=1152,
depth=28,
num_heads=16,
mlp_ratio=4.0,
class_dropout_prob=0.1,
pred_sigma=True,
drop_path: float = 0.,
caption_channels=4096,
pe_interpolation=1.0,
config=None,
model_max_length=120,
qk_norm=False,
kv_compress_config=None,
**kwargs,
):
super().__init__()
self.pred_sigma = pred_sigma
self.in_channels = in_channels
self.out_channels = in_channels * 2 if pred_sigma else in_channels
self.patch_size = patch_size
self.num_heads = num_heads
self.pe_interpolation = pe_interpolation
self.depth = depth

self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
self.t_embedder = TimestepEmbedder(hidden_size)
num_patches = self.x_embedder.num_patches
self.base_size = input_size // self.patch_size
# Will use fixed sin-cos embedding:
self.register_buffer("pos_embed", torch.zeros(1, num_patches, hidden_size))

approx_gelu = lambda: nn.GELU(approximate="tanh")
self.t_block = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
self.y_embedder = CaptionEmbedder(
in_channels=caption_channels, hidden_size=hidden_size, uncond_prob=class_dropout_prob,
act_layer=approx_gelu, token_num=model_max_length)
drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule
self.kv_compress_config = kv_compress_config
if kv_compress_config is None:
self.kv_compress_config = {
'sampling': None,
'scale_factor': 1,
'kv_compress_layer': [],
}
self.blocks = nn.ModuleList([
PixArtBlock(
hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i],
input_size=(input_size // patch_size, input_size // patch_size),
sampling=self.kv_compress_config['sampling'],
sr_ratio=int(
self.kv_compress_config['scale_factor']
) if i in self.kv_compress_config['kv_compress_layer'] else 1,
qk_norm=qk_norm,
)
for i in range(depth)
])
self.final_layer = T2IFinalLayer(hidden_size, patch_size, self.out_channels)

self.initialize_weights()

# if config:
# logger = get_root_logger(os.path.join(config.work_dir, 'train_log.log'))
# logger.warning(f"position embed interpolation: {self.pe_interpolation}, base size: {self.base_size}")
# logger.warning(f"kv compress config: {self.kv_compress_config}")
# else:
# print(f'Warning: position embed interpolation: {self.pe_interpolation}, base size: {self.base_size}')
# print(f"kv compress config: {self.kv_compress_config}")


def forward(self, x, timestep, y, mask=None, data_info=None, **kwargs):
"""
Forward pass of PixArt.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        timestep: (N,) tensor of diffusion timesteps
        y: (N, 1, 120, C) tensor of caption embeddings
"""
x = x.to(self.dtype)
timestep = timestep.to(self.dtype)
y = y.to(self.dtype)
pos_embed = self.pos_embed.to(self.dtype)
self.h, self.w = x.shape[-2]//self.patch_size, x.shape[-1]//self.patch_size
x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2
t = self.t_embedder(timestep.to(x.dtype)) # (N, D)
t0 = self.t_block(t)
y = self.y_embedder(y, self.training) # (N, 1, L, D)
if mask is not None:
if mask.shape[0] != y.shape[0]:
mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
mask = mask.squeeze(1).squeeze(1)
y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
y_lens = mask.sum(dim=1).tolist()
else:
y_lens = [y.shape[2]] * y.shape[0]
y = y.squeeze(1).view(1, -1, x.shape[-1])
for block in self.blocks:
            x = auto_grad_checkpoint(block, x, y, t0, y_lens)  # (N, T, D); supports gradient checkpointing
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
x = self.unpatchify(x) # (N, out_channels, H, W)
return x
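
    # Note on the masked branch above (illustrative): valid caption tokens from
    # all N samples are packed into one (1, sum(y_lens), D) tensor, and y_lens
    # lets the cross-attention treat them as variable-length sequences instead
    # of attending over padding.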

def forward_with_dpmsolver(self, x, timestep, y, mask=None, **kwargs):
"""
dpm solver donnot need variance prediction
"""
# https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
model_out = self.forward(x, timestep, y, mask)
return model_out.chunk(2, dim=1)[0]

def forward_with_cfg(self, x, timestep, y, cfg_scale, mask=None, **kwargs):
"""
Forward pass of PixArt, but also batches the unconditional forward pass for classifier-free guidance.
"""
# https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
half = x[: len(x) // 2]
combined = torch.cat([half, half], dim=0)
        model_out = self.forward(combined, timestep, y, mask, **kwargs)
model_out = model_out['x'] if isinstance(model_out, dict) else model_out
        eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
eps = torch.cat([half_eps, half_eps], dim=0)
return torch.cat([eps, rest], dim=1)
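
    # Illustrative call pattern: the sampler passes x with the conditional half
    # duplicated and y stacking caption embeddings with their null counterparts,
    # so one batched forward yields cond_eps and uncond_eps, recombined as
    # uncond_eps + cfg_scale * (cond_eps - uncond_eps).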

def unpatchify(self, x):
"""
x: (N, T, patch_size**2 * C)
imgs: (N, H, W, C)
"""
c = self.out_channels
p = self.x_embedder.patch_size[0]
h = w = int(x.shape[1] ** 0.5)
assert h * w == x.shape[1]

x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
return imgs
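
    # Worked shapes (illustrative): with T = 256 tokens, p = 2, c = 8:
    # (N, 256, 32) -> (N, 16, 16, 2, 2, 8) -> (N, 8, 16, 2, 16, 2) -> (N, 8, 32, 32).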

def initialize_weights(self):
# Initialize transformer layers:
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)

self.apply(_basic_init)

# Initialize (and freeze) pos_embed by sin-cos embedding:
pos_embed = get_2d_sincos_pos_embed(
self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5),
pe_interpolation=self.pe_interpolation, base_size=self.base_size
)
self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
w = self.x_embedder.proj.weight.data
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
nn.init.normal_(self.t_block[1].weight, std=0.02)

# Initialize caption embedding MLP:
nn.init.normal_(self.y_embedder.y_proj.fc1.weight, std=0.02)
nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02)

# Zero-out adaLN modulation layers in PixArt blocks:
for block in self.blocks:
nn.init.constant_(block.cross_attn.proj.weight, 0)
nn.init.constant_(block.cross_attn.proj.bias, 0)

# Zero-out output layers:
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)

@property
def dtype(self):
return next(self.parameters()).dtype


def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, pe_interpolation=1.0, base_size=16):
"""
grid_size: int of the grid height and width
return:
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
if isinstance(grid_size, int):
grid_size = to_2tuple(grid_size)
grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0]/base_size) / pe_interpolation
grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1]/base_size) / pe_interpolation
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
    grid = grid.reshape([2, 1, grid_size[0], grid_size[1]])  # matches meshgrid's (H, W) output; matters only for non-square grids

pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token and extra_tokens > 0:
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
return pos_embed
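
# Quick sanity check (illustrative): get_2d_sincos_pos_embed(8, 4) returns a
# (16, 8) array; the first half of each row encodes one grid axis and the
# second half the other, via the helpers below.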


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
assert embed_dim % 2 == 0

# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)

emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,)
out: (M, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float64)
omega /= embed_dim / 2.
omega = 1. / 10000 ** omega # (D/2,)

pos = pos.reshape(-1) # (M,)
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product

emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)

emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
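
# Worked values (illustrative): with embed_dim = 4, omega = [1.0, 0.01], so
# position m maps to [sin(m), sin(m / 100), cos(m), cos(m / 100)].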


#################################################################################
# PixArt Configs #
#################################################################################
@MODELS.register_module()
def PixArt_XL_2(**kwargs):
return PixArt(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)
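
A quick smoke test of the model as defined in this diff (a sketch, assuming the
PixArt_blocks module added in this PR provides the imported layers and that
captions come from a 4096-channel T5 encoder, the default here):

import torch
from library.nets.PixArt import PixArt_XL_2

model = PixArt_XL_2(input_size=32, in_channels=4, model_max_length=120)
x = torch.randn(2, 4, 32, 32)      # latent inputs
t = torch.randint(0, 1000, (2,))   # diffusion timesteps
y = torch.randn(2, 1, 120, 4096)   # caption embeddings
out = model(x, t, y)               # (2, 8, 32, 32): 4 eps + 4 sigma channels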