Federated Learning with Model Personalization

Introduction

This project explores Federated Learning (FL), a decentralized machine learning approach that enables multiple clients to collaboratively train a global model without sharing their raw data. The primary focus is on handling non-IID data distributions across clients and on improving per-client performance through local personalization.

Background

Traditional centralized training requires collecting data from all users on a central server, raising both privacy concerns and scalability issues. Federated Learning addresses this by keeping data local and sharing only model updates. However, FL struggles when clients hold non-identically distributed (non-IID) data, which often leads to poor generalization for some clients. This project addresses that challenge with a personalization phase in which each client fine-tunes the global model on its own data after aggregation.
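
The model-update sharing mentioned above is made concrete in the next section via federated averaging (FedAvg). In its canonical weighted form, with $S_t$ the set of clients sampled in round $t$, $w^{k}_{t+1}$ the weights client $k$ returns after local training, and $n_k$ its local sample count:

$$
w_{t+1} \;=\; \sum_{k \in S_t} \frac{n_k}{\sum_{j \in S_t} n_j}\, w^{k}_{t+1}
$$

A plain unweighted mean is the special case where all clients hold equally many samples.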

Project Description

The system simulates an FL environment on the MNIST dataset with artificially created non-IID splits among clients. Each client trains a shared CNN model locally and contributes to the global model through federated averaging. Each training round includes (see the sketch after this list):

  • Random client participation per round
  • Local training with per-client learning rates and numbers of local epochs
  • Aggregation of client weights into a new global model
  • Local personalization: clients fine-tune the global model on their own data
  • Evaluation of both global and personalized models per client
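
A minimal sketch of one such round, assuming PyTorch models and hypothetical Client/Server interfaces (the real logic lives in clients.py and server.py; method names like local_train, fine_tune, and evaluate are illustrative assumptions, not the project's actual API):

```python
import copy
import random


def run_round(server, clients, participation=0.5):
    """One federated round: sample clients, train locally, FedAvg, personalize."""
    # 1. Random client participation per round.
    sampled = random.sample(clients, max(1, int(participation * len(clients))))

    # 2. Local training: each client trains a copy of the global model,
    #    possibly with its own learning rate and number of local epochs.
    states, sizes = [], []
    for client in sampled:
        local = copy.deepcopy(server.global_model)
        client.local_train(local)           # hypothetical method
        states.append(local.state_dict())
        sizes.append(client.num_samples)    # hypothetical attribute

    # 3. Aggregation (FedAvg): average client weights, weighted by data size
    #    (floating-point parameters assumed).
    total = sum(sizes)
    avg = {
        key: sum((n / total) * s[key].float() for s, n in zip(states, sizes))
        for key in states[0]
    }
    server.global_model.load_state_dict(avg)

    # 4. Personalization: each client fine-tunes a copy of the new global model.
    # 5. Evaluation of both the global and the personalized model per client.
    for client in clients:
        global_acc = client.evaluate(server.global_model)
        personal = copy.deepcopy(server.global_model)
        client.fine_tune(personal)
        personal_acc = client.evaluate(personal)
        print(f"client: global={global_acc:.3f} personal={personal_acc:.3f}")
```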

In addition, performance metrics such as global accuracy, validation loss, per-client accuracy, and a fairness indicator (the standard deviation of accuracy across clients) are tracked and visualized through a Streamlit dashboard.
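
The fairness indicator is simply the spread of per-client accuracy. A minimal sketch of how such metrics might be appended to a CSV each round (the file name and column layout are assumptions, not necessarily what results/ contains):

```python
import csv
import os
from statistics import mean, stdev


def log_round(round_id, global_acc, val_loss, client_accs,
              path="results/metrics.csv"):
    """Append one round's metrics; fairness = std dev of client accuracies."""
    fairness = stdev(client_accs) if len(client_accs) > 1 else 0.0
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "a", newline="") as f:
        csv.writer(f).writerow(
            [round_id, global_acc, val_loss, mean(client_accs), fairness])


# A tighter spread (lower std dev) across clients indicates a fairer model.
log_round(1, 0.97, 0.11, [0.95, 0.91, 0.88, 0.96])
```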

Project Structure

  • clients.py: Contains the Client class, which manages local training, evaluation, and personalization logic for each client.
  • server.py: Defines the Server class, responsible for aggregating model updates and evaluating the global model.
  • data.py: Includes data loading utilities and the non-IID data splitting logic for the MNIST dataset (a sketch of a typical split follows this list).
  • models.py: Holds the CNN architecture used by both the server and clients.
  • results/: Stores all training outputs, including CSV files of metrics and the generated plots used by the dashboard.
  • dashboard.py: A fully interactive Streamlit dashboard that visualizes training performance, fairness trends, and per-client personalization improvements.
  • main.py: The central script that runs the full federated training process, logging results to disk.
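
For the split in data.py, one common recipe for non-IID MNIST (in the spirit of the original FedAvg paper) is label-based sharding: sort examples by label, slice them into shards, and deal each client a few shards so that most clients see only a subset of digits. A sketch, with illustrative shard counts:

```python
import numpy as np
from torchvision import datasets


def non_iid_split(labels, num_clients=10, shards_per_client=2, seed=0):
    """Sort indices by label, slice into shards, deal shards to clients."""
    rng = np.random.default_rng(seed)
    order = np.argsort(labels)  # indices grouped by digit label
    shards = np.array_split(order, num_clients * shards_per_client)
    dealt = rng.permutation(len(shards))
    return [
        np.concatenate([shards[s] for s in
                        dealt[c * shards_per_client:(c + 1) * shards_per_client]])
        for c in range(num_clients)
    ]


mnist = datasets.MNIST("data", train=True, download=True)
parts = non_iid_split(mnist.targets.numpy())
print([len(p) for p in parts])  # roughly equal sizes, heavily skewed labels
```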

Results

The training process was run over multiple federated rounds, and the following trends were observed:

  • The global model improved consistently across rounds, achieving strong validation accuracy on centralized test data.
  • Personalization significantly enhanced the performance of clients with diverse data distributions, especially those underperforming in global-only evaluation.
  • The fairness plot (standard deviation of accuracy across clients) revealed how personalization helped reduce performance gaps.
  • A comparison bar chart showed the advantage of personalized models over the global model per client in the final round.

All results are visualized in the included dashboard and stored as reusable CSVs and PNGs in the results/ directory.
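
As a flavor of how dashboard.py might consume those CSVs, here is a minimal Streamlit sketch, reusing the hypothetical metrics.csv layout from the earlier logging example:

```python
import pandas as pd
import streamlit as st

st.title("Federated Learning with Model Personalization")

# Assumed columns: round, global_acc, val_loss, mean_client_acc, fairness_std.
df = pd.read_csv(
    "results/metrics.csv",
    names=["round", "global_acc", "val_loss",
           "mean_client_acc", "fairness_std"],
).set_index("round")

st.line_chart(df[["global_acc", "mean_client_acc"]])  # accuracy across rounds
st.line_chart(df["fairness_std"])  # fairness trend: lower is more uniform
```

Run with `streamlit run` on the script to serve the charts locally.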

Conclusion

This project demonstrates the effectiveness of personalization in Federated Learning environments, especially under non-IID data conditions. It also provides a reproducible framework for simulating FL systems, tracking detailed performance metrics, and analyzing fairness across clients using a well-organized dashboard.
