This repository provides code and instructions to reproduce the results presented in our work on CellCLIP. The proposed framework aligns Cell Painting image embeddings with perturbation-level textual descriptions, enabling biologically meaningful representations for downstream retrieval and matching tasks.
├── src/ # Core source files
│ ├── clip/ # CellCLIP and contrastive learning modules
│ │ ├── method.py # Contrastive loss implementations (e.g., CWCL, CLOOB, InfoNCE)
│ │ └── model.py # CellCLIP and CrossChannelFormer model definitions
│ ├── helpler.py # Utility functions
│ └── ... # Other supporting modules
│
├── configs/ # Model and training configuration files
├── preprocessing/ # Scripts for preprocessing images, molecules, and embeddings
│
├── main.py # Main training script
│
├── retrieval.py # Cross-modal retrieval evaluation
├── rxrx3-core_efaar_eval.py # Intra-modal evaluation on RxRx3-core (gene–gene recovery)
└── cpjump_matching_eval.py # Replicate detection and sister perturbation matching on CP-JUMP1
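For orientation, the sketch below shows a symmetric InfoNCE (CLIP-style) objective of the kind implemented in src/clip/method.py. It is a simplified stand-in, not the repository's exact code; the CWCL and CLOOB variants differ in detail.

import torch
import torch.nn.functional as F

def info_nce(img_emb, txt_emb, inv_tau=14.3):
    # img_emb, txt_emb: (batch, dim) paired embeddings; inv_tau: inverse temperature
    img = F.normalize(img_emb, dim=-1)
    txt = F.normalize(txt_emb, dim=-1)
    logits = inv_tau * img @ txt.t()                      # (batch, batch) similarity matrix
    labels = torch.arange(len(logits), device=logits.device)
    return 0.5 * (F.cross_entropy(logits, labels) +
                  F.cross_entropy(logits.t(), labels))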
To reproduce the results, follow these steps:
- Environment setup and installation
- Preprocessing Cell Painting images and associated metadata
- Training the proposed model
- Evaluating cross-modal and intra-modal retrieval performance
Set up a virtual environment with Python 3.11.5. Before starting, ensure all required packages are installed:
pip install -r requirements.txt
Create a src/constants.py file with the following content:
DATASET_DIR = "dataset_dir"
OUTDIR = "model_out_dir"
LOGDIR = "log_dir"
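These constants are imported by the training and preprocessing scripts; a hypothetical usage (the checkpoint filename here is illustrative):

import os
from src.constants import DATASET_DIR, OUTDIR, LOGDIR

ckpt_path = os.path.join(OUTDIR, "cellclip_latest.pt")   # hypothetical checkpoint name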
Add the repo directory to PYTHONPATH:
export PYTHONPATH="$PYTHONPATH:$PWD"
First, download the Cell Painting images and the corresponding metadata and labels from the links below:
- Bray2017: preprocessed data available at https://ml.jku.at/software/cellpainting/dataset/
- RxRx3-Core: download from RxRx3-core on Hugging Face
- CP-JUMP1: available via the instructions in the official repository
To normalize raw Cell Painting image values into the [0, 255] range, use:
python preprocessing/preprocess_images.py
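For intuition, the sketch below maps raw intensities into [0, 255], assuming per-channel min-max scaling; check preprocess_images.py for the exact scheme used in the repository.

import numpy as np

def to_uint8(img):
    # Min-max scale an (H, W, C) image to [0, 255] per channel.
    img = img.astype(np.float32)
    lo = img.min(axis=(0, 1), keepdims=True)
    hi = img.max(axis=(0, 1), keepdims=True)
    scaled = (img - lo) / np.maximum(hi - lo, 1e-8)
    return (scaled * 255.0).round().astype(np.uint8)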
Once the preprocessed images are ready, you can generate embeddings using our proposed CrossChannelFormer encoding scheme by running:
python preprocessing/convert_npz_to_avg_emb.py \
--model_card facebook/dino-vitb8 \ # Feature extractor to generate embeddings
--dataset bray2017 \
--input_dir path_to_dataset \
--aggregation_strategy mean \ # Aggregation method (e.g., mean, attention)
--n_crop 1 \ # Number of crops per image
--output_file dino-vitb8_ind.h5 # Output file path
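The resulting .h5 file can be inspected with h5py; the internal layout is not documented here, so list the keys first rather than assuming dataset names:

import h5py

with h5py.File("dino-vitb8_ind.h5", "r") as f:
    f.visit(print)            # print every group/dataset name in the file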
To generate molecule-level prompts or fingerprints for contrastive training:
python preprocessing/preprocess_molecules.py \
--dataset [bray2017 | jumpcp | rxrx3-core] \
--output_file output_filename.h5|csv \
--img_dir /path/to/input_data
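As a reference for the fingerprint path, Morgan fingerprints are typically computed with RDKit; a minimal sketch (the radius and bit size here are illustrative, not necessarily the script's defaults):

import numpy as np
from rdkit import Chem
from rdkit.Chem import AllChem

def morgan_fp(smiles, radius=2, n_bits=1024):
    mol = Chem.MolFromSmiles(smiles)
    fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits)
    return np.array(fp, dtype=np.uint8)

print(morgan_fp("CCO").sum())   # number of set bits for ethanol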
To train CellCLIP, execute the following command:
python main.py \
# === Dataset and Input Files ===
--dataset [bray2017 | jumpcp] \ # Dataset name
--img_dir /path/to/image_embeddings_or_images \ # Directory with image embeddings or images from Step 2
# === Image Preprocessing (Optional) ===
--image_resolution_train 224 \ # Resolution of training image inputs
--image_resolution_val 224 \ # Resolution of validation image inputs
--molecule_path /path/to/perturbation_descriptions \ # Path to the molecule or text input generated in Step 2
--unique \ # Treat perturbations as unique (multi-instance mode)
# === Model Configuration ===
--model_type [milcellclip | cloome | molphenix] \ # Type of model architecture
--input_dim [768 | 1024 | 1536] \ # Input feature dimensionality (depends on embedding source)
--loss [cwcl | clip | cloob] \ # Contrastive loss function
# === Optimization Hyperparameters ===
--epochs 50 \ # Number of training epochs
--batch_size 512 \ # Batch size
# === Learning Rate and Scheduler ===
--lr 5e-4 \ # Learning rate
--lr_scheduler [cosine | const | const-cooldown] \ # LR scheduler type
--warmup 1000 \ # Number of warmup steps
--num_cycles 5 \ # Number of cosine cycles for LR scheduler
# === Checkpointing & Logging ===
--ckpt_freq 1000 \ # Frequency (in steps) to save checkpoints
--keep_all_ckpts \ # Save all checkpoints (not just latest)
--log_freq 20 \ # Log every N steps
--eval_freq 500 # Evaluate every N steps
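The interplay of --lr, --warmup, and --num_cycles can be pictured with the simplified schedule below (a sketch with illustrative defaults; the repository's scheduler may differ):

import math

def lr_at(step, base_lr=5e-4, warmup=1000, total_steps=50_000, num_cycles=5):
    # Linear warmup, then cosine decay with num_cycles hard restarts.
    if step < warmup:
        return base_lr * step / warmup
    progress = (step - warmup) / max(1, total_steps - warmup)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * ((num_cycles * progress) % 1.0)))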
To enable distributed training across multiple GPUs, use accelerate:
accelerate launch --config_file configs/your_config.yaml main.py ...
Note: on a setup with 8 × RTX 6000 GPUs, a maximum batch size of 512 has been tested successfully. Below is an example command for training CellCLIP with accelerate:
accelerate launch \
--config_file configs/ddp_config.yaml main.py \
--split 1 \
--is_train \
--resume \
--batch_size 512 \
--epochs 50 \
--model_type mil_cell_clip \
--input_dim 1536 \
--dataset bray2017 \
--img_dir path_to_embeddings \
--unique \
--molecule_path path_to_molecules \
--loss_type cloob \
--lr_scheduler cosine-restarts \
--num_cycles 4 \
--wd 0.1 \
--init-inv-tau 14.3 \
--learnable-inv-tau \
--warmup 1000 \
--ckpt_freq 500 \
--eval_freq 100 \
--opt_seed 42 \
--lr 0.0001
This section describes how to evaluate the trained model on both cross-modal and intra-modal tasks.
Evaluate the alignment between Cell Painting images and perturbation-level text embeddings:
python retrieval.py \
--embedding_type /path/to/eval_embeddings \ # Path to aggregated embeddings
--model_type [milcellclip | cloome | molphenix] \ # Model architecture
--input_dim [768 | 1024 | 1536] \ # Embedding dimensionality
--loss [cwcl | clip | cloob] \ # Loss used during training
--ckpt_path /path/to/trained_model.pt \ # Path to model checkpoint
--unique \ # Use multi-instance mode if applicable
--image_resolution_train 224 \ # Resolution used for training
--image_resolution_val 224 # Resolution used for evaluation
For models trained on individual instances and evaluated on pooled profiles, use retrieval_whole.py instead.
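Conceptually, both scripts rank perturbation texts against image embeddings (and vice versa) by cosine similarity; a minimal recall@k sketch, not the repository's exact metric code:

import torch
import torch.nn.functional as F

def recall_at_k(img_emb, txt_emb, k=5):
    # Fraction of images whose paired text (same row index) is in the top-k neighbors.
    sims = F.normalize(img_emb, dim=-1) @ F.normalize(txt_emb, dim=-1).t()
    topk = sims.topk(k, dim=1).indices
    targets = torch.arange(len(sims), device=sims.device).unsqueeze(1)
    return (topk == targets).any(dim=1).float().mean().item()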
Use the following script to generate instance-level embeddings for RxRx3-core evaluation:
python preprocessing/convert_emb_to_ind_rxrx3core_emb.py \
--ckpt_path /path/to/trained_model.pt \
--model_type milcellclip \
--loss_type cwcl \
--input_dim 1536 \
--output_file output_embeddings.npz \
--img_dir /path/to/test_embeddings
Run the zero-shot gene–gene relationship recovery evaluation on RxRx3-Core:
python rxrx3-core_efaar_eval.py --filepath [path_to_precomputed_embeddings] # e.g., embeddings from a trained CellCLIP model
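At a high level, this evaluation aggregates embeddings per gene and asks whether highly similar gene pairs correspond to known relationships; a simplified sketch under those assumptions:

import numpy as np

def gene_similarity(emb, genes):
    # Mean-aggregate instance embeddings per gene, then compute gene-gene cosine similarity.
    names = np.unique(genes)
    profiles = np.stack([emb[genes == g].mean(axis=0) for g in names])
    profiles /= np.linalg.norm(profiles, axis=1, keepdims=True)
    return names, profiles @ profiles.T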
The CP-JUMP1 evaluation tests the model’s ability to:
- Detect biological replicates (same perturbation, different images)
- Match sister perturbations that target the same biological pathway
Use the following script to generate instance-level embeddings from a trained model:
python preprocessing/convert_emb_to_cellclip_emb.py \
--ckpt_path [path_to_trained_CellCLIP_ckpt] \
--model_type mil_cell_clip \
--loss_type cwcl \
--input_dim 1536 \
--pretrained_emb [name_of_the_pretrained_embeddings] \
--img_dir [path_to_testing_data_embeddings]
Run the CP-JUMP1 evaluation:
python cpjump1_matching_eval.py \
--kernel poly \ # Kernel for batch correction (e.g., poly)
--feature_type [profile | emb] \ # Whether to use raw profiles or embeddings
--batch_correction # Enable batch effect correction
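For intuition, replicate detection can be framed as nearest-neighbor matching: a sample counts as a hit if its nearest neighbor shares the same perturbation label. A simplified sketch (the script's actual metrics may differ):

import numpy as np

def nn_replicate_accuracy(emb, labels):
    # Fraction of samples whose cosine nearest neighbor shares the same perturbation label.
    x = emb / np.linalg.norm(emb, axis=1, keepdims=True)
    sims = x @ x.T
    np.fill_diagonal(sims, -np.inf)      # exclude self-matches
    nn = sims.argmax(axis=1)
    return float((labels[nn] == labels).mean())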