- 📁 Set up paths
  - Defined all directory and file paths for images, masks, and CSVs.
- 🗂️ Copy images for classification
  - Wrote a helper function to copy training and testing images into dedicated folders for classification.
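
A minimal sketch of such a copy helper, using only the standard library. The function name `copy_images` and the extension filter are assumptions for illustration, not the project's actual code.

```python
import shutil
from pathlib import Path

# Assumed image extensions; adjust to the files actually present.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".tif", ".tiff"}


def copy_images(src_dir, dst_dir, extensions=IMAGE_EXTENSIONS):
    """Copy image files (non-recursively) from src_dir into dst_dir.

    Creates dst_dir if needed and returns the number of files copied.
    """
    src, dst = Path(src_dir), Path(dst_dir)
    dst.mkdir(parents=True, exist_ok=True)
    copied = 0
    for path in sorted(src.iterdir()):
        if path.is_file() and path.suffix.lower() in extensions:
            shutil.copy2(path, dst / path.name)  # copy2 also keeps file timestamps
            copied += 1
    return copied
```

For example, `copy_images(train_img_dir, cls_train_dir)` would work with hypothetical path variables defined in the setup step.
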
- 🏷️ Create custom Dataset class
  - Implemented `IDRiDClassificationDataset` to load images and labels from CSVs.
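
`IDRiDClassificationDataset` itself is not shown. The sketch below is a framework-agnostic stand-in that follows the same `__len__`/`__getitem__` protocol as `torch.utils.data.Dataset`. It illustrates how CSV rows map to samples without requiring PyTorch. The class name `GradingDataset`, the column names `Image name` and `Retinopathy grade`, and the `.jpg` extension are assumptions about the CSV layout. Check them against the actual files.

```python
import csv
from pathlib import Path


class GradingDataset:
    """Framework-agnostic stand-in for a torch Dataset over a grading CSV.

    Implements __len__/__getitem__ like torch.utils.data.Dataset. The default
    column names and file extension are assumptions about the CSV layout.
    """

    def __init__(self, img_dir, csv_path, transform=None, loader=None,
                 name_col="Image name", label_col="Retinopathy grade", ext=".jpg"):
        self.img_dir = Path(img_dir)
        self.transform = transform
        # Default loader returns the path unchanged; pass a real image loader in practice.
        self.loader = loader or (lambda path: path)
        self.samples = []
        # utf-8-sig silently drops a byte-order mark if the CSV has one.
        with open(csv_path, newline="", encoding="utf-8-sig") as f:
            for row in csv.DictReader(f):
                name = (row.get(name_col) or "").strip()
                if not name:
                    continue  # skip blank or padding rows
                label = int(row[label_col])
                self.samples.append((self.img_dir / f"{name}{ext}", label))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, label = self.samples[idx]
        image = self.loader(path)
        if self.transform is not None:
            image = self.transform(image)
        return image, label
```

In a PyTorch pipeline the class would subclass `torch.utils.data.Dataset`. It would also receive a loader such as `lambda p: Image.open(p).convert("RGB")` (PIL), so the torchvision transforms get an image rather than a path.
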
- 🖼️ Define image transforms
  - Used torchvision transforms for data augmentation (train only) and normalization (train and test).
- 🗺️ Map labels and create datasets
  - Created `train_dataset` and `test_dataset` using the custom dataset class and transforms.
- 📦 Create DataLoaders
  - Wrapped the datasets in PyTorch DataLoaders for batching and shuffling.
- 🏗️ Build the model
  - Defined a function to build a ResNet18 classifier with a custom output layer for 5 classes.
- ⚖️ Compute class weights
  - Used `sklearn` to calculate class weights for handling class imbalance.
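
scikit-learn's `compute_class_weight(class_weight="balanced", ...)` gives each class the weight `n_samples / (n_classes * count_c)`, so rare grades contribute more to the loss. The dependency-free sketch below reproduces that formula. `balanced_class_weights` is a hypothetical name.

```python
from collections import Counter


def balanced_class_weights(labels, num_classes):
    """Per-class weights n_samples / (n_classes * count_c), i.e. the formula
    behind scikit-learn's class_weight="balanced"."""
    counts = Counter(labels)
    n = len(labels)
    missing = [c for c in range(num_classes) if counts[c] == 0]
    if missing:
        # An absent class would need division by zero; fail loudly instead.
        raise ValueError(f"classes with no samples: {missing}")
    return [n / (num_classes * counts[c]) for c in range(num_classes)]
```

The resulting list would typically be converted with `torch.tensor(weights, dtype=torch.float)` and passed as `weight=` to `nn.CrossEntropyLoss`. That loss also accepts `label_smoothing=` in PyTorch 1.10 and later.
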
- 🔁 Training loop
  - Implemented a training loop with early stopping, label smoothing, and class weights.
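
The early-stopping part of the loop can be isolated in a small helper. This sketch assumes training stops on validation loss with a fixed patience. The class name, `patience`, and `min_delta` are illustrative choices, not taken from the original code.

```python
class EarlyStopping:
    """Signal a stop once validation loss has failed to improve by at least
    `min_delta` for `patience` consecutive epochs."""

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_epochs = 0
        self.improved = False  # True if the most recent step set a new best

    def step(self, val_loss):
        """Record one epoch's validation loss; return True to stop training."""
        self.improved = val_loss < self.best - self.min_delta
        if self.improved:
            self.best = val_loss
            self.bad_epochs = 0  # improvement resets the counter
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience
```

Inside the epoch loop, call `stop = stopper.step(val_loss)`. Save a checkpoint (for instance with `torch.save(model.state_dict(), path)`) whenever `stopper.improved` is true, and `break` when `stop` is true.
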
- 💾 Save the best model
  - Saved the trained model weights to disk.
- 🧪 Evaluate on test set
  - Evaluated the model on the test set and computed accuracy.
- 📊 Plot metrics
  - Plotted training/validation loss and accuracy curves.
- 🧮 Confusion matrix
  - Plotted a confusion matrix to visualize prediction performance across classes.
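
Underneath the plot, a confusion matrix is a table of counts. The minimal pure-Python version below uses the same orientation as `sklearn.metrics.confusion_matrix`: rows are true classes and columns are predicted classes. The function name is hypothetical.

```python
def confusion_matrix_counts(y_true, y_pred, num_classes):
    """Count matrix where cm[t][p] is the number of samples with true class t
    predicted as class p."""
    cm = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        cm[t][p] += 1
    return cm
```

Accuracy is the diagonal sum divided by the number of samples. For plotting, `sklearn.metrics.ConfusionMatrixDisplay(confusion_matrix=numpy.array(cm)).plot()` renders a matrix like this one.
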
- 📝 Classification report
  - Printed a detailed classification report (precision, recall, F1-score) for each class.
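
The per-class rows of `sklearn.metrics.classification_report` come from three counts per class: true positives, false positives, and false negatives. The sketch below computes them directly. Where a ratio is undefined it reports 0.0, which matches scikit-learn's default behavior apart from the warning scikit-learn emits. `per_class_prf` is a hypothetical name.

```python
def per_class_prf(y_true, y_pred, num_classes):
    """Precision, recall, F1 and support for each class label 0..num_classes-1."""
    report = {}
    pairs = list(zip(y_true, y_pred))
    for c in range(num_classes):
        tp = sum(1 for t, p in pairs if t == c and p == c)
        fp = sum(1 for t, p in pairs if t != c and p == c)
        fn = sum(1 for t, p in pairs if t == c and p != c)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = (2 * precision * recall / (precision + recall)
              if precision + recall else 0.0)
        report[c] = {"precision": precision, "recall": recall,
                     "f1": f1, "support": tp + fn}  # support = true count of class c
    return report
```
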
- 📄 Save predictions
  - Saved true and predicted labels for the test set to a CSV file.
- 👁️ Visualize predictions
  - Displayed sample test images with predicted and true labels for qualitative inspection.
- 🗒️ Log results
  - Logged model performance metrics and notes to a CSV file for future reference.
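
Here is a sketch of an append-only results log built on the standard `csv` module. It assumes every call passes a dict with the same keys. `log_results` and the example field names are hypothetical.

```python
import csv
from pathlib import Path


def log_results(csv_path, row):
    """Append one result row (a dict) to csv_path, writing the header only
    when the file is new or empty."""
    path = Path(csv_path)
    write_header = not path.exists() or path.stat().st_size == 0
    with path.open("a", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=list(row.keys()))
        if write_header:
            writer.writeheader()
        writer.writerow(row)
```

For example: `log_results("results_log.csv", {"model": "resnet18", "test_accuracy": test_acc, "notes": "..."})`, where `test_acc` is whatever the evaluation step computed.
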
- 📁 Project Setup & Paths
  - Define all dataset and directory paths for segmentation images, masks, grading images, and CSV labels.
- 🧪 Data Transformations
  - Use `albumentations` for image augmentation and normalization for both training and testing.
- 📦 Multi-Task Dataset
  - Implement a custom `IDRiDMultiTaskDataset` class:
    - Loads fundus images, segmentation masks (multi-channel), and disease grading labels.
    - Handles mask stacking for multiple lesions and applies transformations.
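
Mask stacking turns one binary mask file per lesion type into a single multi-channel target. The NumPy sketch below assumes the masks are already loaded as 2-D arrays. It also assumes some images lack a mask for some lesion types, in which case that channel stays all-zero. The function name and the lesion abbreviations in the comments are assumptions.

```python
import numpy as np


def stack_lesion_masks(masks, lesion_types, height, width):
    """Stack per-lesion masks into a (C, H, W) float32 target.

    masks: dict mapping lesion name -> 2-D array (nonzero pixel = lesion);
           a missing key or None means no mask exists for that lesion.
    lesion_types: fixed channel order, e.g. ["MA", "HE", "EX", "SE"] (assumed names).
    """
    out = np.zeros((len(lesion_types), height, width), dtype=np.float32)
    for c, name in enumerate(lesion_types):
        m = masks.get(name)
        if m is not None:
            out[c] = (np.asarray(m) > 0).astype(np.float32)  # binarize 0/255 masks
    return out
```

Keeping the channel order fixed in `lesion_types` means channel `c` always refers to the same lesion. Both the Dice loss and the per-lesion overlays rely on that.
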
- 🔍 Data Exploration
  - Visualize images and their corresponding lesion masks (both individual and combined overlays).
- 🏗️ Model Architecture
  - Define `MultiTaskResNetWithRouting`:
    - Shared ResNet18 encoder.
    - Routing layer (soft task gate) to dynamically weight classification and segmentation tasks.
    - Separate expert heads for classification (disease grading) and segmentation (multi-channel lesion masks).
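
The text does not specify how the routing layer computes its weights. One plausible reading is a softmax gate over globally pooled encoder features, sketched below in NumPy purely as an illustration. The function `soft_task_gate` and the gate parameters `w` and `b` are hypothetical. A real model would implement the gate as a learnable `nn.Module` inside `MultiTaskResNetWithRouting`.

```python
import numpy as np


def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def soft_task_gate(features, w, b):
    """Hypothetical soft task gate.

    features: (N, C, H, W) output of the shared encoder.
    w: (C, 2) and b: (2,) are the gate's linear-layer parameters.
    Returns per-sample task weights (N, 2) that sum to 1, and the feature
    maps scaled by each task's weight for the classification and
    segmentation heads.
    """
    pooled = features.mean(axis=(2, 3))        # global average pool -> (N, C)
    gates = softmax(pooled @ w + b, axis=1)    # (N, 2): [classification, segmentation]
    cls_in = features * gates[:, 0, None, None, None]
    seg_in = features * gates[:, 1, None, None, None]
    return gates, cls_in, seg_in
```

ResNet18's last stage outputs 512 channels, so in this reading `w` would have shape `(512, 2)`, i.e. an `nn.Linear(512, 2)` in PyTorch.
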
- ⚖️ Loss Functions
  - Implement `MultiChannelDiceLoss` for segmentation.
  - Use `CrossEntropyLoss` (with class weights) for classification.
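
`MultiChannelDiceLoss` is not shown. A common formulation computes a soft Dice coefficient per channel and averages over channels. The NumPy sketch below follows that formulation. The smoothing term `eps` and the choice to pool over the whole batch are assumptions.

```python
import numpy as np


def multichannel_dice_loss(probs, targets, eps=1e-6):
    """Soft Dice loss averaged over channels.

    probs, targets: arrays of shape (N, C, H, W); probs are per-pixel
    probabilities (e.g. sigmoid outputs), targets are binary masks.
    """
    dims = (0, 2, 3)  # sum over batch and pixels, keeping one value per channel
    intersection = (probs * targets).sum(axis=dims)
    denom = probs.sum(axis=dims) + targets.sum(axis=dims)
    dice_per_channel = (2 * intersection + eps) / (denom + eps)
    return 1.0 - dice_per_channel.mean()
```

A PyTorch version would apply the same arithmetic to tensors, after `torch.sigmoid` on the segmentation logits, so the loss stays differentiable.
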
- 💻 Device Setup
  - Specify computation device (CPU/GPU).
- 📊 DataLoader Preparation
  - Split dataset into training and validation sets.
  - Create PyTorch DataLoaders for efficient batching.
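
A reproducible train/validation split only needs a seeded shuffle of indices. In the sketch below, `train_val_split`, the 20% default, and the seed are all assumptions.

```python
import random


def train_val_split(n_samples, val_fraction=0.2, seed=42):
    """Return disjoint (train_indices, val_indices) after a seeded shuffle."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)  # local RNG: reproducible, no global side effects
    n_val = int(round(n_samples * val_fraction))
    return indices[n_val:], indices[:n_val]
```

The index lists can be wrapped with `torch.utils.data.Subset(dataset, train_idx)` and passed to `DataLoader(..., shuffle=True)`. `torch.utils.data.random_split` is a built-in alternative.
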
- 🔁 Training Loop
  - Train with early stopping:
    - Jointly optimize classification and segmentation losses, weighted by α and β.
    - Track and display training loss and accuracy.
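
Assume α scales the classification term and β the segmentation term; the text lists them in that order but does not state the mapping. Written out, the per-batch objective is then:

$$
\mathcal{L}_{\text{total}} = \alpha\,\mathcal{L}_{\text{cls}} + \beta\,\mathcal{L}_{\text{seg}}
$$

Here $\mathcal{L}_{\text{cls}}$ is the class-weighted cross-entropy and $\mathcal{L}_{\text{seg}}$ is the multi-channel Dice loss. With plain SGD, multiplying α and β by the same factor is equivalent to rescaling the learning rate. It is therefore mainly the ratio β/α that shifts the balance between the two tasks.
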
- 🔍 Hyperparameter Search
  - Run experiments over different α and β values.
  - Compute class weights to handle label imbalance.
  - Store and display experiment results in a live table.
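
The search itself is a loop over the Cartesian product of candidate values. In the sketch below, `train_and_evaluate` is a hypothetical stand-in for one full training run that returns the best validation accuracy. Results come back sorted best-first, which is also the order a summary table needs.

```python
from itertools import product


def grid_search(alphas, betas, train_and_evaluate):
    """Run one experiment per (alpha, beta) pair.

    train_and_evaluate: hypothetical callable (alpha, beta) -> best accuracy.
    Returns a list of result dicts sorted by best_acc, highest first.
    """
    results = []
    for alpha, beta in product(alphas, betas):
        acc = train_and_evaluate(alpha, beta)
        results.append({"alpha": alpha, "beta": beta, "best_acc": acc})
    results.sort(key=lambda r: r["best_acc"], reverse=True)
    return results
```

For example, `grid_search([0.5, 1.0], [0.1, 0.5, 1.0], run_experiment)` would launch six runs. These candidate values are placeholders, not the ones used in the experiments.
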
- 💾 Model Saving
  - Save the best-performing model checkpoint.
- 📈 Results Visualization
  - Plot training loss and accuracy curves for each experiment.
  - Visualize the effect of β on best accuracy.
- 📋 Summary Table
  - Highlight and sort experiment results for easy comparison, and save the summary as a CSV file.
- 🧪 Testing & Evaluation
  - Load the best model.
  - Prepare the test dataset and DataLoader.
  - Evaluate classification and segmentation performance (classification report, average Dice score).
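
For evaluation, Dice is usually computed on thresholded predictions rather than soft probabilities. The sketch below averages hard Dice over every image and channel. It rests on two assumptions: a 0.5 threshold, and a score of 1.0 for any channel that is empty in both prediction and ground truth. These choices noticeably affect the average on images missing some lesion types.

```python
import numpy as np


def mean_dice_score(probs, targets, threshold=0.5):
    """Average hard Dice over all images and channels.

    probs, targets: (N, C, H, W). Predictions are binarized at `threshold`;
    a channel empty in both prediction and ground truth scores 1.0.
    """
    preds = probs >= threshold
    gts = targets > 0.5
    inter = np.logical_and(preds, gts).sum(axis=(2, 3))      # (N, C)
    total = preds.sum(axis=(2, 3)) + gts.sum(axis=(2, 3))    # (N, C)
    # np.maximum avoids division by zero; those entries are replaced by 1.0 anyway.
    dice = np.where(total == 0, 1.0, 2.0 * inter / np.maximum(total, 1))
    return float(dice.mean())
```
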
- 🖼️ Qualitative Visualization
  - Overlay predicted and ground-truth lesion masks on fundus images for visual inspection.
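
An overlay can be produced by alpha-blending a solid color into the image wherever the mask is set. Here is a NumPy sketch; the function name, default color, and `alpha` are assumptions.

```python
import numpy as np


def overlay_mask(image, mask, color=(255, 0, 0), alpha=0.5):
    """Alpha-blend `color` onto `image` wherever `mask` is nonzero.

    image: (H, W, 3) uint8 RGB image; mask: (H, W) binary mask.
    Returns a new uint8 image; pixels outside the mask are unchanged.
    """
    out = image.astype(np.float32)  # astype returns a copy, so `image` is untouched
    m = mask > 0
    out[m] = (1 - alpha) * out[m] + alpha * np.asarray(color, dtype=np.float32)
    return np.clip(np.rint(out), 0, 255).astype(np.uint8)
```

Calling it once with the ground-truth mask and once with the thresholded prediction gives a side-by-side comparison. Both results can then be shown with `matplotlib.pyplot.imshow`.
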
- 🧮 Confusion Matrix
  - Plot a confusion matrix for the classification results.
- ⚔️ Performance Comparison
  - Compare single-task and multi-task model performance using classification reports.