TPSA-Augmented GNN for Molecular Property Prediction on QM9


A graph neural network implementation for predicting molecular properties from the QM9 dataset, enhanced with Topological Polar Surface Area (TPSA) as an additional target property. The model leverages PyTorch Geometric and supports multiple graph convolutional architectures.

📋 Table of Contents

  • Features
  • Installation
  • Quick Start
  • Configuration
  • Usage
  • Model Architecture
  • Dataset
  • Results
  • Contributing
  • License
  • Acknowledgements
  • Citation
  • Contact

✨ Features

  • Multi-target Property Prediction: Predict multiple molecular properties simultaneously
  • TPSA Integration: Automatically calculates and includes Topological Polar Surface Area as an additional target
  • Flexible Architecture: Support for various GNN layers (GCN, GAT, SAGE, GIN, GraphConv, TransformerConv)
  • Comprehensive Training Pipeline: Includes early stopping, learning rate scheduling, and regularization
  • Extensible Design: Modular codebase for easy experimentation and extension

🚀 Installation

Prerequisites

  • Python 3.8 or higher
  • CUDA-compatible GPU (optional but recommended)

Install Dependencies

  1. Clone the repository:

    git clone https://github.com/shahram-boshra/qm9_tpsa.git
    cd qm9_tpsa
  2. Create a virtual environment:

    python -m venv venv
    source venv/bin/activate  # On Windows: venv\Scripts\activate
  3. Install PyTorch and PyTorch Geometric:

    # For CUDA 11.8 (adjust according to your setup)
    pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
    
    # Install PyTorch Geometric
    pip install torch-geometric
  4. Install remaining dependencies:

    pip install -r requirements.txt
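
Optionally, you can confirm that the core dependencies import correctly with a short Python check:

# Quick sanity check of the installation (run inside the virtual environment)
import torch
import torch_geometric
import rdkit

print("PyTorch:", torch.__version__)
print("PyTorch Geometric:", torch_geometric.__version__)
print("RDKit:", rdkit.__version__)
print("CUDA available:", torch.cuda.is_available())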

Requirements File

Create a requirements.txt file with the following content:

rdkit-pypi>=2022.9.1
scikit-learn>=1.0.2
matplotlib>=3.5.0
pyyaml>=6.0
numpy>=1.21.0
pandas>=1.3.0
tqdm>=4.62.0

🎯 Quick Start

  1. Configure the project:

    cp config.yaml.example config.yaml
    # Edit config.yaml with your desired settings
  2. Run training:

    python main.py
  3. Monitor training progress: The training script will output progress to the console and save model checkpoints automatically.

βš™οΈ Configuration

The project uses a YAML configuration file for easy parameter management:

data:
  root_dir: "./data/qm9"           # Root directory for the QM9 dataset
  target_indices: [0, 3, 5]       # Indices of target properties to predict
  use_cache: true                  # Cache processed data for faster loading
  train_split: 0.8                 # Training set proportion
  valid_split: 0.1                 # Validation set proportion
  subset_size: null                # Optional: Limit dataset size for testing

model:
  batch_size: 256
  learning_rate: 0.007
  weight_decay: 1.09e-05
  step_size: 50                    # LR scheduler step size
  gamma: 0.5                       # LR scheduler decay factor
  reduce_lr_factor: 0.5           # Factor for ReduceLROnPlateau
  reduce_lr_patience: 10          # Patience for ReduceLROnPlateau
  early_stopping_patience: 20     # Early stopping patience
  early_stopping_delta: 0.001     # Minimum improvement threshold
  l1_regularization_lambda: 0.006 # L1 regularization strength
  first_layer_type: "transformer_conv"  # First GNN layer type
  hidden_channels: 512            # Hidden dimension size
  second_layer_type: "transformer_conv" # Second GNN layer type
  dropout_rate: 0.176             # Dropout rate
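
For reference, the optimizer and scheduler settings above map onto standard PyTorch components. The snippet below is a minimal, self-contained sketch of that wiring using a stand-in linear model; it is illustrative only and does not reproduce the actual classes in training_utils.py.

import torch

# Stand-in model; the real project builds MGModel from config['model']
model = torch.nn.Linear(16, 4)

# Hyperparameters copied from the model section of config.yaml
optimizer = torch.optim.Adam(model.parameters(), lr=0.007, weight_decay=1.09e-05)

# StepLR: multiply the learning rate by `gamma` every `step_size` epochs
step_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)

# ReduceLROnPlateau: shrink the learning rate when the validation loss plateaus
plateau_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, factor=0.5, patience=10
)

# One illustrative step: task loss plus L1 regularization (lambda from config.yaml)
x, y = torch.randn(8, 16), torch.randn(8, 4)
task_loss = torch.nn.functional.mse_loss(model(x), y)
l1_penalty = 0.006 * sum(p.abs().sum() for p in model.parameters())
loss = task_loss + l1_penalty
loss.backward()
optimizer.step()

# Early stopping (patience/delta from config.yaml) would wrap the epoch loop and
# halt training after `early_stopping_patience` epochs without an improvement
# larger than `early_stopping_delta` in validation loss.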

Target Properties

The QM9 dataset includes 19 molecular properties. Common target indices include:

  • 0: Dipole moment
  • 1: Isotropic polarizability
  • 2: HOMO energy
  • 3: LUMO energy
  • 4: HOMO-LUMO gap
  • 5: Electronic spatial extent
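
For illustration, target_indices selects columns from each molecule's 19-dimensional target vector (data.y in PyTorch Geometric). The transform below sketches that idea and is not the project's actual dataset code:

import torch
from torch_geometric.datasets import QM9

target_indices = [0, 3, 5]  # dipole moment, LUMO energy, electronic spatial extent

def select_targets(data):
    # data.y has shape [1, 19]; keep only the configured columns
    data.y = data.y[:, target_indices]
    return data

dataset = QM9(root="./data/qm9", transform=select_targets)
print(dataset[0].y.shape)  # torch.Size([1, 3])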

📖 Usage

Basic Training

from config_loader import load_config
from dataset import QM9Dataset
from models import MGModel
from training_utils import Trainer

# Load configuration
config = load_config('config.yaml')

# Initialize dataset
dataset = QM9Dataset(config['data'])

# Create model
model = MGModel(config['model'])

# Train model
trainer = Trainer(model, dataset, config)
trainer.train()

Custom Model Configuration

# Example: Using different GNN layers
config['model']['first_layer_type'] = 'gcn'
config['model']['second_layer_type'] = 'gat'
config['model']['hidden_channels'] = 256

πŸ—οΈ Model Architecture

The MGModel class implements a flexible GNN architecture with:

  1. Input Processing: Node and edge feature encoding
  2. Graph Convolution Layers: Support for multiple GNN architectures
  3. Global Pooling: Graph-level representation learning
  4. Output Head: Multi-target property prediction including TPSA

Supported GNN Layers

  • GCN: Graph Convolutional Network
  • GAT: Graph Attention Network
  • SAGE: GraphSAGE
  • GIN: Graph Isomorphism Network
  • GraphConv: Graph Convolutional Layer
  • TransformerConv: Graph Transformer
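
The exact construction lives in models.py; the sketch below only illustrates how a layer-type string such as "transformer_conv" could be mapped to the corresponding PyTorch Geometric convolution class. The class names are the real PyG classes, but the mapping function and the accepted strings are assumptions for illustration.

import torch
from torch_geometric.nn import (
    GCNConv, GATConv, SAGEConv, GINConv, GraphConv, TransformerConv,
)

def build_conv(layer_type: str, in_channels: int, out_channels: int):
    """Map a config string to a PyTorch Geometric convolution layer."""
    if layer_type == "gcn":
        return GCNConv(in_channels, out_channels)
    if layer_type == "gat":
        return GATConv(in_channels, out_channels)
    if layer_type == "sage":
        return SAGEConv(in_channels, out_channels)
    if layer_type == "gin":
        # GIN wraps a small MLP
        mlp = torch.nn.Sequential(
            torch.nn.Linear(in_channels, out_channels),
            torch.nn.ReLU(),
            torch.nn.Linear(out_channels, out_channels),
        )
        return GINConv(mlp)
    if layer_type == "graph_conv":
        return GraphConv(in_channels, out_channels)
    if layer_type == "transformer_conv":
        return TransformerConv(in_channels, out_channels)
    raise ValueError(f"Unknown layer type: {layer_type}")

# QM9 node features in PyG have 11 dimensions; hidden size from config.yaml
conv1 = build_conv("transformer_conv", in_channels=11, out_channels=512)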

📊 Dataset

The model uses the QM9 dataset, which contains:

  • 130,831 small organic molecules
  • Up to 9 heavy atoms (C, N, O, F)
  • 19 molecular properties computed using DFT
  • Additional TPSA calculation using RDKit

The dataset is automatically downloaded and processed on first run.
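
For reference, TPSA can be computed with RDKit from a molecule's SMILES string and appended as an extra target column. The pre_transform below is only a sketch of that idea and may differ from the project's actual dataset code; it assumes a recent PyTorch Geometric version, where each QM9 Data object carries its SMILES string.

import torch
from rdkit import Chem
from rdkit.Chem import rdMolDescriptors
from torch_geometric.datasets import QM9

def add_tpsa(data):
    # PyG's QM9 exposes the SMILES string on each Data object
    mol = Chem.MolFromSmiles(data.smiles)
    # A handful of SMILES may fail to parse; fall back to 0.0 in this sketch
    tpsa = rdMolDescriptors.CalcTPSA(mol) if mol is not None else 0.0
    # Append TPSA as an extra target column (shape [1, 19] -> [1, 20])
    data.y = torch.cat([data.y, torch.tensor([[tpsa]])], dim=1)
    return data

dataset = QM9(root="./data/qm9", pre_transform=add_tpsa)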

📈 Results

Training metrics and model performance are automatically logged and visualized. The framework tracks:

  • Training and validation loss
  • Mean Absolute Error (MAE) for each target property
  • Learning rate schedules
  • Early stopping criteria
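
For illustration, per-target MAE is the absolute error averaged over a data loader separately for each output column. The helper below sketches that computation; it assumes the model can be called on a batch and returns one row of predictions per graph, and it is not the project's actual logging code.

import torch

@torch.no_grad()
def per_target_mae(model, loader, device="cpu"):
    """Mean absolute error for each target column over a data loader."""
    abs_err_sum, n_graphs = None, 0
    for batch in loader:
        batch = batch.to(device)
        pred = model(batch)                      # shape [num_graphs, num_targets]
        err = (pred - batch.y).abs().sum(dim=0)  # sum over graphs in the batch
        abs_err_sum = err if abs_err_sum is None else abs_err_sum + err
        n_graphs += batch.num_graphs
    return abs_err_sum / n_graphs                # one MAE value per target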

🤝 Contributing

We welcome contributions! Please follow these steps:

  1. Fork the repository
  2. Create a feature branch:
    git checkout -b feature/amazing-feature
  3. Make your changes and add tests
  4. Commit your changes:
    git commit -m 'Add some amazing feature'
  5. Push to the branch:
    git push origin feature/amazing-feature
  6. Open a Pull Request

Development Setup

# Install development dependencies
pip install -r requirements-dev.txt

# Run tests
python -m pytest tests/

# Run linting
flake8 src/
black src/

πŸ“ License

This project is licensed under the MIT License - see the LICENSE file for details.

πŸ™ Acknowledgements

This project builds upon several excellent open-source libraries:

📚 Citation

If you use this code in your research, please cite:

@software{tpsa_augmented_gnn,
  title={TPSA-Augmented GNN for Molecular Property Prediction on QM9},
  author={Boshra, Shahram},
  year={2025},
  url={https://github.com/shahram-boshra/qm9_tpsa}
}

📧 Contact

For questions or support, please open an issue on the GitHub repository: https://github.com/shahram-boshra/qm9_tpsa/issues