This repository contains code for training and evaluating a gated sensor fusion neural network for GPS position estimation. The project leverages PyTorch, Weights & Biases (wandb), and custom data loaders for handling sensor data.
- Custom PyTorch datasets for loading and batching sensor fusion data
- Gated sensor fusion neural network model
- Training and validation loops with early stopping
- Inference pipeline for evaluating model performance
- Integration with Weights & Biases for experiment tracking
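
The early stopping mentioned in the list above can be sketched as follows. This is a minimal illustration, not the logic in `train_loop.py`: the class name `EarlyStopping` and the `patience` and `min_delta` parameters are hypothetical, and the loss values are made up.

```python
class EarlyStopping:
    """Hypothetical helper: stop when validation loss stops improving."""

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience      # epochs to wait after the last improvement
        self.min_delta = min_delta    # smallest decrease that counts as improvement
        self.best_loss = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return True if training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Made-up validation losses, for illustration only.
stopper = EarlyStopping(patience=2)
for epoch, loss in enumerate([0.90, 0.70, 0.72, 0.71, 0.69]):
    if stopper.step(loss):
        print(f"stopping after epoch {epoch}")
        break
```

In a real loop, the model checkpoint would usually be saved whenever `best_loss` improves, so that the best weights survive the extra epochs spent waiting.
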
Files in this repository:
- `dataloaders.py`: Custom PyTorch `Dataset` classes for training and inference
- `gated_fusion_net.py`: Implementation of the Gated Sensor Fusion neural network
- `train_loop.py`: Training script with argument parsing, early stopping, and wandb logging
- `Infererence.py`: Inference script for evaluating the model on test data
- `utils.py`: Utility functions (e.g., plotting, reading folder names)
- `lists/`: Text files listing data folders for training, validation, and testing
- Install dependencies:

      pip install -r requirements.txt

  Ensure you have `torch`, `pandas`, `numpy`, `wandb`, `python-dotenv`, and other required packages.
- Set up Weights & Biases:
  - Create a `.env` file with your WANDB API key: `WANDB_API_KEY=your_wandb_api_key_here`
- Prepare data:
  - Place your data folders and CSV files as referenced in the `lists/` directory.
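
To confirm that the `.env` file is readable before starting a run, a small check like the one below may help. It is a stdlib-only stand-in written for illustration: the function name `read_env_file` is hypothetical, and the repository presumably relies on `python-dotenv` (listed among the dependencies), which handles more edge cases.

```python
import tempfile
from pathlib import Path


def read_env_file(path):
    """Parse simple KEY=VALUE lines, skipping blank lines and # comments."""
    values = {}
    for raw in Path(path).read_text().splitlines():
        line = raw.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, _, value = line.partition("=")
        # Drop surrounding whitespace and optional quotes around the value.
        values[key.strip()] = value.strip().strip('"').strip("'")
    return values


# Demo against a throwaway file so the sketch runs anywhere.
with tempfile.TemporaryDirectory() as tmp:
    env_path = Path(tmp) / ".env"
    env_path.write_text("# wandb credentials\nWANDB_API_KEY=your_wandb_api_key_here\n")
    env = read_env_file(env_path)
    print("WANDB_API_KEY" in env)
```
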
Run the training loop with default or custom arguments:
python train_loop.py --train_list lists/train.txt --val_list lists/val.txt --seen_data_dir ./seen_subjects_test_set --output_dir ./no_drp_out_checkpoint
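
For orientation, the flags in the command above could be declared with `argparse` roughly as shown here. Only the flag names come from this README; the help strings are inferred from those names, the `None` defaults are placeholders, and the actual script may define further arguments.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the flags used in the training command."""
    parser = argparse.ArgumentParser(description="Train the gated sensor fusion model")
    # Defaults are placeholders; the real defaults are not documented here.
    parser.add_argument("--train_list", default=None, help="text file listing training folders")
    parser.add_argument("--val_list", default=None, help="text file listing validation folders")
    parser.add_argument("--seen_data_dir", default=None, help="directory holding the seen-subjects test set")
    parser.add_argument("--output_dir", default=None, help="directory for checkpoints")
    return parser


# Parse the same arguments as the example command.
args = build_parser().parse_args([
    "--train_list", "lists/train.txt",
    "--val_list", "lists/val.txt",
    "--seen_data_dir", "./seen_subjects_test_set",
    "--output_dir", "./no_drp_out_checkpoint",
])
print(args.train_list, args.output_dir)
```
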
Run the inference script to evaluate the trained model:
python Infererence.py
- The model expects CSV files named `synced_gps_gt_roin.csv` in each data folder.
- Training and validation folder names should be listed in the corresponding text files in the `lists/` directory.
- Results and logs are tracked using wandb.
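
A quick pre-flight check of the data layout described above might look like the sketch below. It assumes each list file holds one folder name per line and that folders live under a common data root; both are assumptions, and the helper names `read_folder_list` and `find_missing_csvs` are hypothetical rather than functions from `utils.py`.

```python
import tempfile
from pathlib import Path

CSV_NAME = "synced_gps_gt_roin.csv"  # file name the model expects in each folder


def read_folder_list(list_file):
    """Return the non-empty lines of a list file (assumed: one folder name per line)."""
    return [ln.strip() for ln in Path(list_file).read_text().splitlines() if ln.strip()]


def find_missing_csvs(list_file, data_root):
    """Return the listed folders under data_root that lack the expected CSV."""
    root = Path(data_root)
    return [name for name in read_folder_list(list_file)
            if not (root / name / CSV_NAME).is_file()]


# Demo on a temporary tree: run_a has the CSV, run_b does not.
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "run_a").mkdir()
    (root / "run_a" / CSV_NAME).write_text("placeholder\n")
    (root / "run_b").mkdir()
    list_file = root / "train.txt"
    list_file.write_text("run_a\nrun_b\n")
    missing = find_missing_csvs(list_file, root)
    print(missing)
```

Running such a check before training surfaces missing files up front instead of partway through an epoch.
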
This project is for educational and research purposes.