This repository is the official implementation of CELL-E: Biological Zero-Shot Text-to-Image Synthesis for Protein Localization Prediction.
Create a virtual environment and install the required packages via:

```bash
pip install -r requirements.txt
```
Next, install `torch==1.7.1` and `torchvision==0.8.2` with the appropriate CUDA version.
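For example, assuming the default CUDA build is sufficient (see https://pytorch.org for wheels matching your CUDA version), the install might look like:

```bash
pip install torch==1.7.1 torchvision==0.8.2
```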
We used OpenCell for CELL-E, which has information on downloading the entire dataset. A `data_csv` is needed for the dataloader. You must generate a CSV file which contains the columns `nucleus_image_path`, `protein_image_path`, `metadata_path`, and `split` (train or val). It is assumed that this file exists within the same general data folder as the images and metadata files.
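As an illustration (not tooling provided by this repository), a minimal sketch of building such a CSV with pandas; the file names below are hypothetical:

```python
import pandas as pd

# Hypothetical paths, relative to the shared data folder.
rows = [
    {
        "nucleus_image_path": "images/ABC1_nucleus.png",
        "protein_image_path": "images/ABC1_protein.png",
        "metadata_path": "metadata/ABC1.json",
        "split": "train",
    },
    {
        "nucleus_image_path": "images/XYZ2_nucleus.png",
        "protein_image_path": "images/XYZ2_protein.png",
        "metadata_path": "metadata/XYZ2.json",
        "split": "val",
    },
]

pd.DataFrame(rows).to_csv("data.csv", index=False)
```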
Metadata is a JSON file which should accompany every protein sequence. If a sequence does not appear in the `data_csv`, it must appear in `metadata.json` with a key named `protein_sequence`.

Adding more information here can be useful for querying individual proteins. These fields can be retrieved via `retrieve_metadata`, which creates a `self.metadata` variable within the dataset object.
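For instance, a minimal metadata file could be written as follows; only `protein_sequence` is required, and the extra field shown is a hypothetical example:

```python
import json

metadata = {
    "protein_sequence": "MSKGEELFTGVVPILVELDGD",  # placeholder amino acid sequence
    "gene_name": "EXAMPLE1",  # hypothetical optional field for querying
}

with open("metadata/EXAMPLE1.json", "w") as f:
    json.dump(metadata, f)
```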
Training for CELL-E occurs in 2 (or 3) stages:
- Training Protein Threshold Image encoder
- (Optional, but recommended) Training a Nucleus Image (Conditional Image) Encoder
- Training CELL-E Transformer
There are two available image encoders in this repository: a discrete VAE (similar to the original OpenAI implementation) and a VQGAN (recommended). If using the protein threshold image, set `threshold: True` for the dataset.
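The threshold image is a binarized version of the protein image. As a rough illustration only (the repository's actual preprocessing may differ), per-image thresholding could look like:

```python
import torch

def to_threshold_image(protein_image: torch.Tensor) -> torch.Tensor:
    # Illustrative binarization: pixels above the per-image mean become 1.
    # This is an assumption; check the dataset code for the exact rule.
    return (protein_image > protein_image.mean()).float()
```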
The discrete VAE can be trained using the following code:

```python
import torch

from celle import DiscreteVAE

vae = DiscreteVAE(
    image_size = 256,
    num_layers = 3,            # number of downsamples - ex. 256 / (2 ** 3) = (32 x 32 feature map)
    channels = 1,
    num_tokens = 512,          # number of visual tokens. in the paper, they used 8192, but could be smaller for downsized projects
    codebook_dim = 512,        # codebook dimension
    hidden_dim = 64,           # hidden dimension
    num_resnet_blocks = 1,     # number of resnet blocks
    temperature = 0.9,         # gumbel softmax temperature, the lower this is, the harder the discretization
    straight_through = False,  # straight-through for gumbel softmax. unclear if it is better one way or the other
)

images = torch.randn(4, 1, 256, 256)  # dummy batch of single-channel images

loss = vae(images, return_loss = True)
loss.backward()
```
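After training, and assuming the `celle` DiscreteVAE keeps the dalle-pytorch interface (an assumption worth verifying against the class), images can be mapped to their discrete codes:

```python
# Assumes the dalle-pytorch-style API; verify against celle's DiscreteVAE.
codes = vae.get_codebook_indices(images)  # (batch, 32 * 32) token indices
```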
We use a slightly modified version of the taming-transformers code.
To train, run the following script:

```bash
python celle_taming_main.py --base configs/threshold_vqgan.yaml -t True
```

Please refer to the original repo for additional flags, such as `--gpus`.
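For example, following the PyTorch Lightning convention used by taming-transformers (note the trailing comma, which selects GPU indices rather than a GPU count), a single-GPU run might look like:

```bash
python celle_taming_main.py --base configs/threshold_vqgan.yaml -t True --gpus 0,
```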
To train, run the following script:

```bash
python celle_main.py --base configs/celle.yaml -t True
```

Specify `--gpus` in the same format as VQGAN.
CELL-E contains the following options from dalle-pytorch (an illustrative config sketch follows this list):

- `ckpt_path`: Resume previous CELL-E training. Saved model with state_dict
- `vqgan_model_path`: Saved protein image model (with state_dict) for the protein image encoder
- `vqgan_config_path`: Saved protein image model yaml
- `condition_model_path`: (Optional) Saved condition (nucleus) model (with state_dict) for the condition image encoder
- `condition_config_path`: (Optional) Saved condition (nucleus) model yaml
- `num_images`: 1 if only using the protein image encoder, 2 if including the condition image encoder
- `image_key`: `nucleus`, `target`, or `threshold`
- `dim`: Dimension of language model embedding (768 for BERT)
- `num_text_tokens`: Total number of tokens in the language model (30 for BERT)
- `text_seq_len`: Total number of amino acids considered
- `depth`: Transformer model depth; deeper is usually better at the cost of VRAM
- `heads`: Number of heads used in multi-headed attention
- `dim_head`: Size of attention heads
- `reversible`: See https://github.com/lucidrains/DALLE-pytorch#scaling-depth
- `attn_dropout`: Attention dropout rate in training
- `ff_dropout`: Feed-forward dropout rate in training
- `attn_types`: See https://github.com/lucidrains/DALLE-pytorch#sparse-attention. Sparse attention is not supported.
- `loss_img_weight`: Weighting applied to image reconstruction (text weight = 1)
- `loss_cond_weight`: Weighting applied to condition image reconstruction
- `stable`: Norms weights (for when exploding gradients occur)
- `sandwich_norm`: See https://github.com/lucidrains/x-transformers#sandwich-norm
- `shift_tokens`: Applies shift in feature dimension. Only applied to images.
- `rotary_emb`: Rotary embedding scheme for positional encoding
- `text_embedding`: Language model used. `no_text`, `unirep`, `bert`, `esm1b`, `onehot`, and `aadescriptors` are available.
- `fixed_embedding`: Setting to `True` allows for protein sequence embeddings to be updated during training
- `learning_rate`: Learning rate for the Adam optimizer
- `monitor`: Param used to save models
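As a hedged illustration of how a few of these options might be inspected or overridden programmatically (the nesting under `model.params` is an assumption and should be verified against the provided configs/celle.yaml):

```python
from omegaconf import OmegaConf

# Load the provided config.
config = OmegaConf.load("configs/celle.yaml")

# Hypothetical overrides of options from the list above.
overrides = OmegaConf.create({
    "model": {
        "params": {
            "image_key": "threshold",
            "num_images": 2,
            "depth": 16,
            "learning_rate": 3e-4,
        }
    }
})
config = OmegaConf.merge(config, overrides)

print(OmegaConf.to_yaml(config.model))  # inspect the resulting model block
```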
To generate images, set the saved model as the `ckpt_path`. This method can be unstable, so refer to Demo.ipynb for another way of loading.
```python
import torch
from omegaconf import OmegaConf

from celle_main import instantiate_from_config

device = "cuda" if torch.cuda.is_available() else "cpu"

configs = OmegaConf.load("configs/celle.yaml")
model = instantiate_from_config(configs.model).to(device)

# `sequence` (amino acid sequence) and `nucleus` (nucleus image tensor)
# must be defined beforehand.
model.generate_images(text=sequence,
                      condition=nucleus,
                      return_logits=True,
                      progress=True,
                      use_cache=True)
```

Please cite us if you decide to use our code for any part of your research.
CELL-E: Biological Zero-Shot Text-to-Image Synthesis for Protein Localization Prediction
Emaad Khwaja, Yun S. Song, Bo Huang
bioRxiv 2022.05.27.493774; doi: https://doi.org/10.1101/2022.05.27.493774
Huge shoutout to @lucidrains for putting out dalle-pytorch, which this code is based on. This work would not have been possible without that invaluable contribution.


