A graph neural network implementation for predicting molecular properties from the QM9 dataset, enhanced with Topological Polar Surface Area (TPSA) as an additional target property. The model leverages PyTorch Geometric and supports multiple graph convolutional architectures.
- Features
- Installation
- Quick Start
- Configuration
- Usage
- Model Architecture
- Dataset
- Results
- Contributing
- License
- Acknowledgements
- Citation
- Multi-target Property Prediction: Predict multiple molecular properties simultaneously
- TPSA Integration: Automatically calculates and includes Topological Polar Surface Area as an additional target
- Flexible Architecture: Support for various GNN layers (GCN, GAT, SAGE, GIN, GraphConv, TransformerConv)
- Comprehensive Training Pipeline: Includes early stopping, learning rate scheduling, and regularization
- Extensible Design: Modular codebase for easy experimentation and extension
- Python 3.8 or higher
- CUDA-compatible GPU (optional but recommended)
- Clone the repository:

  ```bash
  git clone https://github.com/shahram-boshra/qm9_tpsa.git
  cd qm9_tpsa
  ```
- Create a virtual environment:

  ```bash
  python -m venv venv
  source venv/bin/activate  # On Windows: venv\Scripts\activate
  ```
- Install PyTorch and PyTorch Geometric:

  ```bash
  # For CUDA 11.8 (adjust according to your setup)
  pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

  # Install PyTorch Geometric
  pip install torch-geometric
  ```
- Install the remaining dependencies. Create a `requirements.txt` file with the following content:

  ```text
  rdkit-pypi>=2022.9.1
  scikit-learn>=1.0.2
  matplotlib>=3.5.0
  pyyaml>=6.0
  numpy>=1.21.0
  pandas>=1.3.0
  tqdm>=4.62.0
  ```

  Then install it:

  ```bash
  pip install -r requirements.txt
  ```
- Configure the project:

  ```bash
  cp config.yaml.example config.yaml
  # Edit config.yaml with your desired settings
  ```
- Run training:

  ```bash
  python main.py
  ```
- Monitor training progress: the training script outputs progress to the console and saves model checkpoints automatically.
The project uses a YAML configuration file for easy parameter management:
```yaml
data:
  root_dir: "./data/qm9"                  # Root directory for the QM9 dataset
  target_indices: [0, 3, 5]               # Indices of target properties to predict
  use_cache: true                         # Cache processed data for faster loading
  train_split: 0.8                        # Training set proportion
  valid_split: 0.1                        # Validation set proportion
  subset_size: null                       # Optional: limit dataset size for testing

model:
  batch_size: 256
  learning_rate: 0.007
  weight_decay: 1.09e-05
  step_size: 50                           # LR scheduler step size
  gamma: 0.5                              # LR scheduler decay factor
  reduce_lr_factor: 0.5                   # Factor for ReduceLROnPlateau
  reduce_lr_patience: 10                  # Patience for ReduceLROnPlateau
  early_stopping_patience: 20             # Early stopping patience
  early_stopping_delta: 0.001             # Minimum improvement threshold
  l1_regularization_lambda: 0.006         # L1 regularization strength
  first_layer_type: "transformer_conv"    # First GNN layer type
  hidden_channels: 512                    # Hidden dimension size
  second_layer_type: "transformer_conv"   # Second GNN layer type
  dropout_rate: 0.176                     # Dropout rate
```
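For reference, a configuration file like this can be read with PyYAML; the snippet below is a minimal sketch, and the project's own `config_loader.load_config` may add validation and defaults on top of it:

```python
import yaml

# Minimal sketch: read config.yaml with PyYAML (listed in requirements.txt).
with open("config.yaml") as f:
    config = yaml.safe_load(f)

print(config["data"]["target_indices"])    # e.g. [0, 3, 5]
print(config["model"]["hidden_channels"])  # e.g. 512
```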
The QM9 dataset includes 19 molecular properties. Common target indices include:
- `0`: Dipole moment
- `1`: Isotropic polarizability
- `2`: HOMO energy
- `3`: LUMO energy
- `4`: HOMO-LUMO gap
- `5`: Electronic spatial extent
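For orientation, PyTorch Geometric's built-in `QM9` dataset stores all 19 targets per molecule in `data.y` (shape `[1, 19]`), so selecting a subset of targets is a column slice; the sketch below is illustrative and independent of this project's dataset wrapper:

```python
from torch_geometric.datasets import QM9

dataset = QM9(root="./data/qm9")      # downloaded automatically on first use
target_indices = [0, 3, 5]            # dipole moment, LUMO energy, electronic spatial extent

data = dataset[0]
print(data.y.shape)                   # torch.Size([1, 19])
print(data.y[:, target_indices])      # the three selected target values
```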
```python
from config_loader import load_config
from dataset import QM9Dataset
from models import MGModel
from training_utils import Trainer

# Load configuration
config = load_config('config.yaml')

# Initialize dataset
dataset = QM9Dataset(config['data'])

# Create model
model = MGModel(config['model'])

# Train model
trainer = Trainer(model, dataset, config)
trainer.train()
```
```python
# Example: using different GNN layers
config['model']['first_layer_type'] = 'gcn'
config['model']['second_layer_type'] = 'gat'
config['model']['hidden_channels'] = 256
```
The `MGModel` class implements a flexible GNN architecture with:
- Input Processing: Node and edge feature encoding
- Graph Convolution Layers: Support for multiple GNN architectures
- Global Pooling: Graph-level representation learning
- Output Head: Multi-target property prediction including TPSA
- GCN: Graph Convolutional Network
- GAT: Graph Attention Network
- SAGE: GraphSAGE
- GIN: Graph Isomorphism Network
- GraphConv: Graph Convolutional Layer
- TransformerConv: Graph Transformer
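To illustrate how a configured layer type can be mapped onto a PyTorch Geometric convolution, here is a hypothetical two-layer model with global pooling. It is a sketch, not the project's actual `MGModel`; GIN is omitted because `GINConv` takes an MLP rather than channel sizes:

```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import (GCNConv, GATConv, SAGEConv, GraphConv,
                                TransformerConv, global_mean_pool)

# Hypothetical mapping from config strings to PyG convolution classes.
CONV_LAYERS = {
    "gcn": GCNConv,
    "gat": GATConv,
    "sage": SAGEConv,
    "graph_conv": GraphConv,
    "transformer_conv": TransformerConv,
}


class TwoLayerGNN(torch.nn.Module):
    """Minimal two-layer GNN with global mean pooling (illustrative only)."""

    def __init__(self, in_channels, hidden_channels, num_targets,
                 first_layer_type="transformer_conv",
                 second_layer_type="transformer_conv", dropout_rate=0.176):
        super().__init__()
        self.conv1 = CONV_LAYERS[first_layer_type](in_channels, hidden_channels)
        self.conv2 = CONV_LAYERS[second_layer_type](hidden_channels, hidden_channels)
        self.head = torch.nn.Linear(hidden_channels, num_targets)
        self.dropout_rate = dropout_rate

    def forward(self, x, edge_index, batch):
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=self.dropout_rate, training=self.training)
        x = F.relu(self.conv2(x, edge_index))
        x = global_mean_pool(x, batch)   # graph-level representation
        return self.head(x)              # multi-target prediction (incl. TPSA)
```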
The model uses the QM9 dataset, which contains:
- 130,831 small organic molecules
- Up to 9 heavy atoms (C, N, O, F)
- 19 molecular properties computed using DFT
- Additional TPSA calculation using RDKit
The dataset is automatically downloaded and processed on first run.
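RDKit exposes TPSA as a single descriptor call, so the extra target can be computed per molecule from its SMILES; a minimal sketch (the project's actual preprocessing pipeline may differ):

```python
from rdkit import Chem
from rdkit.Chem import rdMolDescriptors

# Compute Topological Polar Surface Area for one molecule given its SMILES.
mol = Chem.MolFromSmiles("CC(=O)Nc1ccc(O)cc1")   # example molecule (paracetamol)
tpsa = rdMolDescriptors.CalcTPSA(mol)
print(f"TPSA: {tpsa:.2f}")
```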
Training metrics and model performance are automatically logged and visualized. The framework tracks:
- Training and validation loss
- Mean Absolute Error (MAE) for each target property
- Learning rate schedules
- Early stopping criteria
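For reference, per-target MAE reduces to a column-wise mean of absolute errors; a small sketch assuming prediction and target tensors of shape `[num_samples, num_targets]`:

```python
import torch

# Hypothetical prediction and target tensors: e.g. 3 QM9 targets + TPSA per sample.
preds = torch.randn(1000, 4)
targets = torch.randn(1000, 4)

# Mean Absolute Error computed separately for each target property (column-wise).
mae_per_target = (preds - targets).abs().mean(dim=0)
print(mae_per_target)   # one MAE value per target
```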
We welcome contributions! Please follow these steps:
- Fork the repository
- Create a feature branch: `git checkout -b feature/amazing-feature`
- Make your changes and add tests
- Commit your changes: `git commit -m 'Add some amazing feature'`
- Push to the branch: `git push origin feature/amazing-feature`
- Open a Pull Request
```bash
# Install development dependencies
pip install -r requirements-dev.txt

# Run tests
python -m pytest tests/

# Run linting and formatting
flake8 src/
black src/
```
This project is licensed under the MIT License - see the LICENSE file for details.
This project builds upon several excellent open-source libraries:
- PyTorch - Deep learning framework
- PyTorch Geometric - Graph neural network library
- RDKit - Cheminformatics toolkit
- QM9 Dataset - Molecular property dataset
- scikit-learn - Machine learning utilities
If you use this code in your research, please cite:
```bibtex
@software{tpsa_augmented_gnn,
  title  = {TPSA-Augmented GNN for Molecular Property Prediction on QM9},
  author = {Your Name},
  year   = {2025},
  url    = {https://github.com/shahram-boshra/qm9_tpsa}
}
```
For questions or support, please:
- Open an issue on GitHub
- Contact: a.boshra@gmail.com
Note: Replace `yourusername`, `yourname@email.com`, and other placeholder information with your actual details before uploading to GitHub.