This project explores Federated Learning (FL), a decentralized machine learning approach that enables multiple clients to collaboratively train a global model without sharing their raw data. The primary focus is on handling non-IID data distributions across clients and on personalizing the global model to each client's local data.
Traditional centralized training requires collecting data from all users on a central server, raising privacy concerns and scalability issues. Federated Learning addresses this by keeping data local and sharing only model updates. However, FL struggles when clients have non-identically distributed (non-IID) data, which often leads to poor generalization for some clients. This project addresses the challenge with a personalization phase in which each client fine-tunes the global model on its own data after aggregation.
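The "model updates" shared in this scheme are combined via federated averaging (FedAvg), the aggregation rule this project uses. In the canonical formulation (McMahan et al., 2017), the new global weights are the data-size-weighted average of the sampled clients' locally trained weights:

$$w_{t+1} = \sum_{k \in S_t} \frac{n_k}{\sum_{j \in S_t} n_j}\, w_k^{t+1}$$

where $S_t$ is the set of clients sampled in round $t$, $n_k$ is the size of client $k$'s local dataset, and $w_k^{t+1}$ are client $k$'s weights after local training. (A plain unweighted mean is a common simplification; whether this project weights by dataset size is an implementation detail not stated here.)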
The system simulates an FL environment using the MNIST dataset with artificially created non-IID splits among clients. Each client trains a shared CNN locally and contributes to the global model through federated averaging. The training loop includes the following steps (a minimal code sketch of one round follows the list):
- Random client participation per round
- Local training with per-client learning rates and local epoch counts
- Aggregation of client weights into a new global model
- Local personalization: clients fine-tune the global model on their own data
- Evaluation of both global and personalized models per client
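To make these steps concrete, here is a minimal sketch of one federated round, assuming a PyTorch-style model. The `Client` helpers (`local_train`, `fine_tune`, `evaluate`, `train_data`, `log_metrics`) are illustrative names, not necessarily those used in clients.py and server.py:

```python
# Minimal sketch of one federated round; helper names are hypothetical.
import copy
import random

def run_round(global_model, clients, participation_rate=0.5):
    # 1. Random client participation per round
    sampled = random.sample(clients, max(1, int(participation_rate * len(clients))))

    client_states, client_sizes = [], []
    for client in sampled:
        # 2. Local training on a copy of the global model, using this
        #    client's own learning rate and number of local epochs
        local_model = copy.deepcopy(global_model)
        client.local_train(local_model)            # hypothetical local SGD step
        client_states.append(local_model.state_dict())
        client_sizes.append(len(client.train_data))

    # 3. FedAvg aggregation: data-weighted average of client weights
    total = sum(client_sizes)
    new_state = copy.deepcopy(client_states[0])
    for key in new_state:
        new_state[key] = sum(
            (n / total) * state[key].float()
            for state, n in zip(client_states, client_sizes)
        )
    global_model.load_state_dict(new_state)

    # 4. Personalization: each client fine-tunes the new global model locally;
    # 5. both the global and personalized models are then evaluated per client
    for client in clients:
        personal_model = copy.deepcopy(global_model)
        client.fine_tune(personal_model)           # hypothetical fine-tuning step
        client.log_metrics(
            global_acc=client.evaluate(global_model),
            personal_acc=client.evaluate(personal_model),
        )
```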
In addition, performance metrics such as global accuracy, validation loss, per-client accuracy, and a fairness indicator (the standard deviation of per-client accuracy) are tracked and visualized through a Streamlit dashboard.
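The fairness indicator is straightforward to state precisely: it is the spread of accuracy across clients in a given round, with a lower value meaning more uniform performance. A minimal sketch (the function name and example values are illustrative; the project may use the sample rather than the population variant):

```python
# Fairness indicator: standard deviation of per-client accuracy in one round.
import statistics

def fairness(per_client_accuracies: list[float]) -> float:
    # Population std-dev over clients; a single-client round has zero spread.
    if len(per_client_accuracies) < 2:
        return 0.0
    return statistics.pstdev(per_client_accuracies)

# Example: personalization narrowing the gap between clients
print(fairness([0.91, 0.62, 0.78]))  # global-only evaluation, large spread
print(fairness([0.93, 0.88, 0.90]))  # after personalization, small spread
```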
- clients.py: Contains the Client class, which manages local training, evaluation, and personalization logic for each client.
- server.py: Defines the Server class, responsible for aggregating model updates and evaluating the global model.
- data.py: Includes data loading utilities and the non-IID data splitting logic for the MNIST dataset (a sketch of one common splitting strategy follows this list).
- models.py: Holds the CNN architecture used by both the server and clients.
- results.py: Handles saving of training outputs, including the metric CSVs and generated plots used by the dashboard, to the results/ directory.
- dashboard.py: A fully interactive Streamlit dashboard that visualizes training performance, fairness trends, and per-client personalization improvements.
- main.py: The central script that runs the full federated training process, logging results to disk.
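For context on the non-IID splitting mentioned for data.py, a common approach is label sharding: sort the dataset by label, cut it into shards, and deal each client a small number of shards so every client sees only a few digit classes. The sketch below illustrates the idea under that assumption; the shard counts and torchvision usage are not necessarily the project's exact implementation:

```python
# Sketch of label-shard non-IID splitting for MNIST (illustrative, not
# necessarily the exact logic in data.py).
import numpy as np
from torchvision import datasets, transforms

def non_iid_split(num_clients=10, shards_per_client=2, seed=0):
    mnist = datasets.MNIST("data", train=True, download=True,
                           transform=transforms.ToTensor())
    labels = np.array(mnist.targets)

    # Sort example indices by label, then cut into equal-sized shards
    sorted_idx = np.argsort(labels)
    num_shards = num_clients * shards_per_client
    shards = np.array_split(sorted_idx, num_shards)

    # Deal shards to clients at random: each client ends up with examples
    # drawn from only a few digit classes, i.e. a non-IID local distribution
    rng = np.random.default_rng(seed)
    order = rng.permutation(num_shards)
    return [
        np.concatenate([shards[s] for s in order[c::num_clients]])
        for c in range(num_clients)
    ]
```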
The training process was run over multiple federated rounds, and the following trends were observed:
- The global model improved consistently across rounds, achieving strong validation accuracy on centralized test data.
- Personalization significantly enhanced the performance of clients with diverse data distributions, especially those underperforming in global-only evaluation.
- The fairness plot (standard deviation of accuracy across clients) showed that personalization reduced performance gaps between clients.
- A comparison bar chart showed the advantage of personalized models over the global model per client in the final round.
All results are visualized in the included dashboard and stored as reusable CSVs and PNGs in the results/ directory.
This project demonstrates the effectiveness of personalization in Federated Learning environments, especially under non-IID data conditions. It also provides a reproducible framework for simulating FL systems, tracking detailed performance metrics, and analyzing fairness across clients using a well-organized dashboard.