TCR Epitope Generation Model

A deep learning model for generating epitope sequences from T-cell receptor (TCR) features using cross-attention mechanisms and protein language models.

Overview

This model combines a frozen TCR encoder with a trainable epitope decoder to predict peptide epitopes that TCRs may recognize. It leverages:

  • ESM2-t6-8M protein language model for sequence embeddings (320D)
  • Cross-attention architecture for TCR-epitope feature integration
  • Autoregressive transformer decoder for sequence generation
  • Gene embeddings (TRAV, TRBV, TRAJ, TRBJ)
  • MHC allele embeddings for pMHC context

Model Architecture

Input Features:
├── TRA CDR3 sequence → ESM2 embeddings (320D)
├── TRB CDR3 sequence → ESM2 embeddings (320D)
├── TRAV gene → Embedding (32D)
├── TRBV gene → Embedding (32D)
├── TRAJ gene → Embedding (32D)
├── TRBJ gene → Embedding (32D)
└── MHC allele → Embedding (64D)

TCR Encoder (Frozen):
└── Cross-attention + Feedforward layers → TCR representation (256D)

Epitope Decoder (Trainable):
├── Autoregressive Transformer (6 layers, 8 heads)
├── Cross-attention to TCR representation
└── Token prediction → Epitope sequence
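
The trainable decoder side of this diagram can be sketched in PyTorch along the following lines. This is an illustrative sketch, not the repository's actual module: the class name, constructor arguments, and the assumption of 25 tokens (20 amino acids plus specials) are placeholders, while the 256D width, 6 layers, and 8 heads follow the diagram above.

import torch
import torch.nn as nn

class EpitopeDecoderSketch(nn.Module):
    """Illustrative sketch: an autoregressive transformer decoder that
    cross-attends to the frozen TCR encoder's 256D representation."""

    def __init__(self, vocab_size=25, d_model=256, n_heads=8, n_layers=6, max_len=30):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_len, d_model)
        layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=n_heads, batch_first=True)
        self.decoder = nn.TransformerDecoder(layer, num_layers=n_layers)
        self.lm_head = nn.Linear(d_model, vocab_size)

    def forward(self, epitope_ids, tcr_repr):
        # epitope_ids: (batch, seq_len) token ids of the (partial) epitope
        # tcr_repr:    (batch, mem_len, 256) memory from the frozen TCR encoder
        seq_len = epitope_ids.size(1)
        pos = torch.arange(seq_len, device=epitope_ids.device)
        x = self.token_emb(epitope_ids) + self.pos_emb(pos)
        # Causal mask: True entries are masked, so each position only sees earlier tokens
        causal = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=epitope_ids.device),
            diagonal=1,
        )
        h = self.decoder(tgt=x, memory=tcr_repr, tgt_mask=causal)
        return self.lm_head(h)  # (batch, seq_len, vocab_size) logits over amino-acid tokens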

Performance Metrics

Top-K Retrieval Evaluation (Validation Set)

Beam Search (beam_size=50):

  • Recall@1: 52.47%
  • Recall@5: 52.47%
  • Recall@10: 52.47%
  • MRR: 0.5247
  • Coverage: 510/972 samples

Sampling (T=1.5, 50 samples):

  • Recall@1: 32.61%
  • Recall@5: 48.46% (+15.8%)
  • Recall@10: 51.34% (+18.7%)
  • Recall@50: 53.50% (+20.9%)
  • MRR: 0.3868
  • Coverage: 520/972 samples

Greedy Decoding:

  • Recall@1: 4.63%
  • Coverage: 223/972 samples

Key Findings

  • Beam collapse: the model's output distribution is highly peaked, so beam search returns only one unique sequence per sample, which is why Recall@1, @5, and @10 are identical (see the diversity check below)
  • Best for retrieval: sampling with temperature=1.5 gives the best diversity and top-K coverage
  • Challenge: 47-52% of validation samples have no matching epitope among the top-50 predictions
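
The diversity check mentioned in the first finding can be run with the generate API from the Usage section below; a minimal sketch, assuming sequences are returned as lists of epitope strings and that generator and tcr_batch are set up as shown there:

# Count unique candidates returned by beam search vs. temperature sampling
beam_seqs, _, _, _ = generator.generate(
    tcr_batch, method='beam_search', return_top_k=50,
    return_scores=True, beam_size=50
)
sampled_seqs, _, _, _ = generator.generate(
    tcr_batch, method='sampling', return_top_k=50,
    return_scores=True, temperature=1.5, num_samples=50
)

print(f"Unique beam candidates:    {len(set(beam_seqs))}")     # collapses to 1 for most samples
print(f"Unique sampled candidates: {len(set(sampled_seqs))}")  # substantially more diverse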

Installation

Requirements

pip install "torch>=2.0.0"
pip install "transformers>=4.30.0"
pip install pandas
pip install scikit-learn
pip install matplotlib
pip install seaborn

Model Files

Large model checkpoints are tracked with Git LFS:

  • model/tcr_epitope_generator_best.pt (82 MB) - Full generator checkpoint
  • model/best_tcr_epitope_mhc_matcher.pth (26 MB) - TCR encoder checkpoint

Make sure to install Git LFS before cloning:

git lfs install
git clone <repository-url>

Usage

Basic Prediction

from tcr_epitope_generator import TCREpitopeGenerator
import torch

# Load model (map_location lets the checkpoint load on CPU-only machines)
checkpoint = torch.load('model/tcr_epitope_generator_best.pt', map_location='cpu')
generator = TCREpitopeGenerator(...)  # instantiate with the same hyperparameters used in training
generator.load_state_dict(checkpoint['model_state_dict'])
generator.eval()  # inference mode

# Prepare TCR input
tcr_batch = {
    'tra_cdr3_ids': torch.tensor([[...]]),
    'trb_cdr3_ids': torch.tensor([[...]]),
    'trav': torch.tensor([14]),
    'trbv': torch.tensor([29]),
    'traj': torch.tensor([5]),
    'trbj': torch.tensor([2]),
    'mhc_allele': torch.tensor([10])
}

# Generate single epitope (greedy)
epitope = generator.generate(tcr_batch, method='greedy')
print(f"Predicted epitope: {epitope}")

Top-K Prediction with Probabilities

# Generate top-5 epitopes with beam search
sequences, probs, log_probs, metadata = generator.generate(
    tcr_batch,
    method='beam_search',
    return_top_k=5,
    return_scores=True,
    beam_size=50,
    length_penalty=1.0
)

for rank, (seq, prob) in enumerate(zip(sequences, probs), 1):
    print(f"{rank}. {seq:15s} - {prob*100:.2f}%")

# Generate diverse samples with temperature sampling
sequences, probs, log_probs, metadata = generator.generate(
    tcr_batch,
    method='sampling',
    return_top_k=10,
    return_scores=True,
    temperature=1.5,
    num_samples=50
)

Generation Methods

1. Greedy Decoding (Fastest, Deterministic)

epitope = generator.generate(tcr_batch, method='greedy')

2. Beam Search (Balanced, Approximates the Highest-Probability Sequence)

sequences, probs, _, _ = generator.generate(
    tcr_batch,
    method='beam_search',
    return_top_k=5,
    return_scores=True,
    beam_size=50,
    length_penalty=1.0
)

3. Sampling (Most Diverse, Exploration)

sequences, probs, _, _ = generator.generate(
    tcr_batch,
    method='sampling',
    return_top_k=10,
    return_scores=True,
    temperature=1.5,  # Higher = more diverse
    num_samples=50
)

Training

The model was trained using:

  • Dataset: 146 pMHC complexes with TCR sequences
  • Loss: Cross-entropy with teacher forcing
  • Optimizer: AdamW
  • Training strategy: Frozen encoder + trainable decoder
  • Epochs: 50 with early stopping

See the main notebook for the full training pipeline:

TCR_model_modeling - Epitope_MHC_MetricLearning_epitopeGen - CrossAttention - fundation_ESM2_8M.ipynb
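
A minimal sketch of the training step implied by the bullets above (frozen encoder, trainable decoder, teacher-forced cross-entropy, AdamW). Module names, the decoder's forward signature (reused from the architecture sketch), and the padding id are assumptions rather than the notebook's actual code:

import torch
import torch.nn as nn

PAD_ID = 0  # assumed padding token id

# Freeze the TCR encoder; only the epitope decoder receives gradient updates
for p in tcr_encoder.parameters():
    p.requires_grad = False

optimizer = torch.optim.AdamW(epitope_decoder.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss(ignore_index=PAD_ID)

def train_step(tcr_batch, epitope_ids):
    """One teacher-forced step: the gold epitope (shifted right) is fed in,
    and the decoder predicts the next token at every position."""
    with torch.no_grad():
        tcr_repr = tcr_encoder(tcr_batch)                     # (B, mem_len, 256)
    logits = epitope_decoder(epitope_ids[:, :-1], tcr_repr)   # inputs: all but last token
    loss = criterion(
        logits.reshape(-1, logits.size(-1)),                  # (B*(L-1), vocab)
        epitope_ids[:, 1:].reshape(-1),                       # targets: all but first token
    )
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()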

Evaluation

The notebook includes comprehensive evaluation metrics:

Section 7.7-7.8: Sequence-level metrics

  • Exact match accuracy
  • BLEU scores
  • Edit distance
  • Per-position accuracy

Section 7.9: Top-K retrieval metrics

  • Recall@K (k=1,3,5,10,20,30,50)
  • Mean Reciprocal Rank (MRR)
  • Rank distribution
  • Coverage analysis
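
For reference, Recall@K, MRR, and coverage as reported above can be computed from the rank of the true epitope among each sample's candidates; a small self-contained sketch (non-covered samples contribute zero to both recall and MRR):

def retrieval_metrics(ranks, total, ks=(1, 5, 10, 50)):
    """ranks: 1-based rank of the true epitope for each covered sample
    (samples whose true epitope never appears among the candidates are omitted).
    total: total number of evaluation samples."""
    metrics = {f"recall@{k}": sum(r <= k for r in ranks) / total for k in ks}
    metrics["mrr"] = sum(1.0 / r for r in ranks) / total
    metrics["coverage"] = f"{len(ranks)}/{total}"
    return metrics

# Toy example: 3 of 4 samples have the true epitope somewhere in the candidate list
print(retrieval_metrics([1, 3, 12], total=4))
# recall@1=0.25, recall@5=0.5, recall@10=0.5, recall@50=0.75, mrr≈0.354, coverage='3/4'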

Files

  • tcr_epitope_generator.py - Main generator class with top-k prediction
  • epitope_decoder.py - Autoregressive transformer decoder
  • amino_acid_tokenizer.py - Amino acid tokenization utilities
  • comprehensive_evaluation_metrics.py - Full evaluation suite
  • data/full_training_set_146pmhc_For_classification.csv - Training data
  • model/ - Trained model checkpoints (Git LFS)

Citation

If you use this model in your research, please cite:

@software{tcr_epitope_generator,
  title = {TCR Epitope Generation Model},
  author = {Your Name},
  year = {2024},
  url = {https://github.com/yourusername/tcr-epitope-generation}
}

License

[Specify your license here]

Acknowledgments

  • ESM2 protein language model by Meta AI
  • Built with PyTorch and HuggingFace Transformers