A deep learning model for generating epitope sequences from T-cell receptor (TCR) features using cross-attention mechanisms and protein language models.
This model combines a frozen TCR encoder with a trainable epitope decoder to predict peptide epitopes that TCRs may recognize. It leverages:
- ESM2-t6-8M protein language model for sequence embeddings (320D; see the embedding sketch below)
- Cross-attention architecture for TCR-epitope feature integration
- Autoregressive transformer decoder for sequence generation
- Gene embeddings (TRAV, TRBV, TRAJ, TRBJ)
- MHC allele embeddings for pMHC context
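The CDR3 sequence embeddings come from ESM2-t6-8M. Below is a minimal sketch of extracting per-residue embeddings with HuggingFace Transformers; the checkpoint name facebook/esm2_t6_8M_UR50D and the example CDR3 string are illustrative assumptions, and the model's actual preprocessing may differ.

import torch
from transformers import AutoTokenizer, AutoModel

# ESM2-t6-8M produces 320-dimensional per-residue embeddings.
# Checkpoint name and example sequence are assumptions for illustration.
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
esm2 = AutoModel.from_pretrained("facebook/esm2_t6_8M_UR50D")
esm2.eval()

cdr3 = "CASSLGQAYEQYF"  # example TRB CDR3 sequence
inputs = tokenizer(cdr3, return_tensors="pt")
with torch.no_grad():
    outputs = esm2(**inputs)

residue_embeddings = outputs.last_hidden_state  # shape: [1, seq_len, 320]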
Input Features:
├── TRA CDR3 sequence → ESM2 embeddings (320D)
├── TRB CDR3 sequence → ESM2 embeddings (320D)
├── TRAV gene → Embedding (32D)
├── TRBV gene → Embedding (32D)
├── TRAJ gene → Embedding (32D)
├── TRBJ gene → Embedding (32D)
└── MHC allele → Embedding (64D)
TCR Encoder (Frozen):
└── Cross-attention + Feedforward layers → TCR representation (256D)
Epitope Decoder (Trainable):
├── Autoregressive Transformer (6 layers, 8 heads)
├── Cross-attention to TCR representation
└── Token prediction → Epitope sequence
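The two stages above can be pictured with standard PyTorch building blocks. The sketch below is an assumption-laden outline (nn.MultiheadAttention for the cross-attention fusion, nn.TransformerDecoder for the autoregressive decoder, and a vocabulary of 25 tokens covering the 20 amino acids plus special tokens); the actual classes live in tcr_epitope_generator.py and epitope_decoder.py and may differ in detail.

import torch
import torch.nn as nn

class TCREncoderSketch(nn.Module):
    # Frozen encoder: fuses CDR3 embeddings with gene/MHC embeddings via cross-attention.
    def __init__(self, esm_dim=320, gene_dim=32, mhc_dim=64, hidden_dim=256, num_heads=8):
        super().__init__()
        self.cdr3_proj = nn.Linear(esm_dim, hidden_dim)
        self.gene_proj = nn.Linear(gene_dim, hidden_dim)
        self.mhc_proj = nn.Linear(mhc_dim, hidden_dim)
        self.cross_attn = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
        self.ffn = nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
                                 nn.Linear(hidden_dim, hidden_dim))

    def forward(self, cdr3_emb, gene_emb, mhc_emb):
        # cdr3_emb: [B, L, 320], gene_emb: [B, 4, 32], mhc_emb: [B, 1, 64]
        query = torch.cat([self.gene_proj(gene_emb), self.mhc_proj(mhc_emb)], dim=1)
        keys = self.cdr3_proj(cdr3_emb)
        fused, _ = self.cross_attn(query, keys, keys)  # queries attend over CDR3 tokens
        return self.ffn(fused)                         # TCR representation: [B, 5, 256]

class EpitopeDecoderSketch(nn.Module):
    # Trainable autoregressive decoder with cross-attention to the TCR representation.
    def __init__(self, vocab_size=25, hidden_dim=256, num_layers=6, num_heads=8):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, hidden_dim)
        layer = nn.TransformerDecoderLayer(hidden_dim, num_heads, batch_first=True)
        self.decoder = nn.TransformerDecoder(layer, num_layers)
        self.lm_head = nn.Linear(hidden_dim, vocab_size)

    def forward(self, epitope_ids, tcr_repr, tgt_mask=None):
        x = self.token_emb(epitope_ids)
        x = self.decoder(x, memory=tcr_repr, tgt_mask=tgt_mask)
        return self.lm_head(x)  # per-position amino-acid logits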
Validation results (972 samples):

Beam Search (beam_size=50):
- Recall@1: 52.47%
- Recall@5: 52.47%
- Recall@10: 52.47%
- MRR: 0.5247
- Coverage: 510/972 samples
Sampling (T=1.5, 50 samples; gains shown relative to Recall@1):
- Recall@1: 32.61%
- Recall@5: 48.46% (+15.8%)
- Recall@10: 51.34% (+18.7%)
- Recall@50: 53.50% (+20.9%)
- MRR: 0.3868
- Coverage: 520/972 samples
Greedy Decoding:
- Recall@1: 4.63%
- Coverage: 223/972 samples
- Beam collapse: The model exhibits high confidence, with beam search generating only 1 unique sequence per sample
- Best for retrieval: Sampling with temperature=1.5 provides the best diversity and top-K coverage (see the temperature illustration below)
- Challenge: 47-52% of validation samples have no matching epitope in top-50 predictions
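Temperature sampling divides the decoder logits by T before the softmax, so T > 1 flattens the next-token distribution and encourages diverse candidates, while T < 1 sharpens it toward the greedy choice. A minimal, self-contained illustration (the logits are made up):

import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 1.0, 0.5])     # made-up next-token logits
for T in (0.5, 1.0, 1.5):
    probs = F.softmax(logits / T, dim=-1)  # higher T -> flatter distribution
    print(f"T={T}: {probs.tolist()}")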
Installation:

pip install torch>=2.0.0
pip install transformers>=4.30.0
pip install pandas
pip install scikit-learn
pip install matplotlib
pip install seaborn

Large model checkpoints are tracked with Git LFS:
- model/tcr_epitope_generator_best.pt (82 MB) - Full generator checkpoint
- model/best_tcr_epitope_mhc_matcher.pth (26 MB) - TCR encoder checkpoint
Make sure to install Git LFS before cloning:
git lfs install
git clone <repository-url>

from tcr_epitope_generator import TCREpitopeGenerator
import torch
# Load model
checkpoint = torch.load('model/tcr_epitope_generator_best.pt')
generator = TCREpitopeGenerator(...)
generator.load_state_dict(checkpoint['model_state_dict'])
# Prepare TCR input
tcr_batch = {
'tra_cdr3_ids': torch.tensor([[...]]),
'trb_cdr3_ids': torch.tensor([[...]]),
'trav': torch.tensor([14]),
'trbv': torch.tensor([29]),
'traj': torch.tensor([5]),
'trbj': torch.tensor([2]),
'mhc_allele': torch.tensor([10])
}
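# Note: 'tra_cdr3_ids' / 'trb_cdr3_ids' hold amino-acid token IDs (see
# amino_acid_tokenizer.py); the gene and MHC entries are indices into the
# model's gene/allele vocabularies. The integer values above are example indices.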
# Generate single epitope (greedy)
epitope = generator.generate(tcr_batch, method='greedy')
print(f"Predicted epitope: {epitope}")# Generate top-5 epitopes with beam search
sequences, probs, log_probs, metadata = generator.generate(
tcr_batch,
method='beam_search',
return_top_k=5,
return_scores=True,
beam_size=50,
length_penalty=1.0
)
for rank, (seq, prob) in enumerate(zip(sequences, probs), 1):
print(f"{rank}. {seq:15s} - {prob*100:.2f}%")
# Generate diverse samples with temperature sampling
sequences, probs, log_probs, metadata = generator.generate(
tcr_batch,
method='sampling',
return_top_k=10,
return_scores=True,
temperature=1.5,
num_samples=50
)

Decoding strategies:

1. Greedy Decoding (Fastest, Deterministic)
epitope = generator.generate(tcr_batch, method='greedy')

2. Beam Search (Balanced, Approximates the Most Likely Sequences)
sequences, probs, _, _ = generator.generate(
tcr_batch,
method='beam_search',
return_top_k=5,
return_scores=True,
beam_size=50,
length_penalty=1.0
)

3. Sampling (Most Diverse, Exploration)
sequences, probs, _, _ = generator.generate(
tcr_batch,
method='sampling',
return_top_k=10,
return_scores=True,
temperature=1.5, # Higher = more diverse
num_samples=50
)

The model was trained using:
- Dataset: 146 pMHC complexes with TCR sequences
- Loss: Cross-entropy with teacher forcing
- Optimizer: AdamW
- Training strategy: Frozen encoder + trainable decoder
- Epochs: 50 with early stopping
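A condensed sketch of this setup is shown below; the names encoder, decoder, train_loader, and PAD_ID, as well as the learning rate, are assumptions for illustration, and the authoritative pipeline is in the notebook referenced next.

import torch
import torch.nn as nn

# Freeze the TCR encoder; only the epitope decoder receives gradient updates.
for p in encoder.parameters():          # encoder/decoder/train_loader/PAD_ID are assumed names
    p.requires_grad = False

optimizer = torch.optim.AdamW(decoder.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss(ignore_index=PAD_ID)

for epoch in range(50):                 # early stopping omitted for brevity
    for batch in train_loader:
        with torch.no_grad():
            tcr_repr = encoder(batch)   # frozen forward pass
        # Teacher forcing: feed the ground-truth epitope shifted right and predict
        # the next token at every position (causal masking assumed inside the decoder).
        logits = decoder(batch["epitope_ids"][:, :-1], tcr_repr)
        loss = criterion(logits.reshape(-1, logits.size(-1)),
                         batch["epitope_ids"][:, 1:].reshape(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()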
See the main notebook for the full training pipeline:
TCR_model_modeling - Epitope_MHC_MetricLearning_epitopeGen - CrossAttention - fundation_ESM2_8M.ipynb
The notebook includes comprehensive evaluation metrics:
Section 7.7-7.8: Sequence-level metrics
- Exact match accuracy
- BLEU scores
- Edit distance
- Per-position accuracy
Section 7.9: Top-K retrieval metrics
- Recall@K (k=1,3,5,10,20,30,50)
- Mean Reciprocal Rank (MRR)
- Rank distribution
- Coverage analysis
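For reference, Recall@K counts a sample as a hit when the true epitope appears among the top-K generated sequences, and MRR averages the reciprocal rank of the first hit (0 when it is absent). A minimal sketch, assuming predictions is a list of ranked candidate lists and targets holds the matching ground-truth epitopes:

def recall_at_k(predictions, targets, k):
    # Fraction of samples whose true epitope appears in the top-k candidates.
    hits = sum(target in preds[:k] for preds, target in zip(predictions, targets))
    return hits / len(targets)

def mean_reciprocal_rank(predictions, targets):
    # Mean of 1/rank of the first correct candidate; 0 when it never appears.
    reciprocal_ranks = [
        1.0 / (preds.index(target) + 1) if target in preds else 0.0
        for preds, target in zip(predictions, targets)
    ]
    return sum(reciprocal_ranks) / len(reciprocal_ranks)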
- tcr_epitope_generator.py - Main generator class with top-k prediction
- epitope_decoder.py - Autoregressive transformer decoder
- amino_acid_tokenizer.py - Amino acid tokenization utilities
- comprehensive_evaluation_metrics.py - Full evaluation suite
- data/full_training_set_146pmhc_For_classification.csv - Training data
- model/ - Trained model checkpoints (Git LFS)
If you use this model in your research, please cite:
@software{tcr_epitope_generator,
title = {TCR Epitope Generation Model},
author = {Your Name},
year = {2024},
url = {https://github.com/yourusername/tcr-epitope-generation}
}

[Specify your license here]
- ESM2 protein language model by Meta AI
- Built with PyTorch and HuggingFace Transformers