ODAM: A data augmentation model to reproduce optical data (e.g., spectra) and mitigate the time-consuming simulation process


Data augmentation model

Overview

This project is a collaboration with Jilong Yi, a talented graduate student.

The goal of the project is to achieve high-accuracy forward prediction from images and/or structural parameters to spectral data, with the aim of augmenting the database and saving the substantial time otherwise spent on FDTD simulations.

Introduction to the Dataset

Images

(Example structure images 1–7.)

Parameters

| Shape          | Params                            |
| -------------- | --------------------------------- |
| circle         | a, b, φ, Px, Py                   |
| rec            | W, L, φ, Px, Py                   |
| double_rec     | W1, L1, W2, L2, φ, Px, Py         |
| double_ellipse | a, b, θ, φ, Px, Py                |
| ring           | R, r, θ, φ, Px, Py                |
| cross          | W1, L1, W2, L2, offset, φ, Px, Py |
| lack_rec       | W, L, α, β, γ, φ, Px, Py          |


The image size is 900 px × 900 px with a white background, where yellow represents the period size and red represents the structure size.
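As a minimal sketch of how such an image could be turned into a model-ready tensor (using torchvision transforms; the resize target and file names are assumptions for illustration, not values taken from preprocess.py):

```python
import torch
from PIL import Image
from torchvision import transforms

# Sketch: convert one 900x900 structure image into a grayscale tensor.
# The 224x224 resize target is an assumption for illustration only.
to_tensor = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),  # scales pixel values to [0, 1]
])

img = Image.open("structure_0001.png")  # hypothetical file name
x = to_tensor(img)                      # shape: (1, 224, 224)
torch.save(x, "structure_0001.pt")
```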

Structural Parameters

The input consists of 14 structural parameters. Three of them, Px, Py, and phi (φ), are common attributes shared by all structures; the remaining attributes are filled in when a structure has them and set to 0 when it does not, as shown in the table below:

| W1    | L1    | W2    | L2    | L   | W   | alpha | beta | gama | A   | B   | phi | Px  | Py  |
| ----- | ----- | ----- | ----- | --- | --- | ----- | ---- | ---- | --- | --- | --- | --- | --- |
| 100.0 | 104.0 | 214.0 | 219.0 | 0.0 | 0.0 | 0.0   | 0.0  | 0.0  | 0.0 | 0.0 | 72  | 356 | 421 |
| ...   | ...   | ...   | ...   | ... | ... | ...   | ...  | ...  | ... | ... | ... | ... | ... |
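As a minimal sketch of this zero-filling convention (the helper below is purely illustrative; column order and spelling follow the table above):

```python
import numpy as np

# Column order and spelling follow the parameter table above.
COLUMNS = ["W1", "L1", "W2", "L2", "L", "W", "alpha", "beta", "gama",
           "A", "B", "phi", "Px", "Py"]

def to_param_vector(params: dict) -> np.ndarray:
    """Build a 14-element vector; attributes the shape lacks stay 0."""
    vec = np.zeros(len(COLUMNS), dtype=np.float32)
    for name, value in params.items():
        vec[COLUMNS.index(name)] = value
    return vec

# Example: the sample row above, which matches a double_rec structure
# (W1, L1, W2, L2, phi, Px, Py set; everything else 0).
row = to_param_vector({"W1": 100.0, "L1": 104.0, "W2": 214.0, "L2": 219.0,
                       "phi": 72, "Px": 356, "Py": 421})
```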

Spectra

The spectral data is generated by FDTD simulation, covering the wavelength range of 400–1800 nm with 500 data points. The wavelength indices of the spectra are listed in data/light_source.txt.
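The following sketch shows how a raw simulated spectrum could be interpolated onto a common 500-point grid (the raw arrays here are placeholders; the authoritative wavelength indices live in data/light_source.txt):

```python
import numpy as np

# 500 evenly spaced points over 400-1800 nm (a sketch; the actual grid
# should be read from data/light_source.txt).
wavelengths = np.linspace(400.0, 1800.0, 500)

# Interpolate a raw FDTD spectrum (placeholder arrays) onto the grid.
raw_wl = np.linspace(380.0, 1850.0, 731)    # hypothetical raw sampling
raw_spectrum = np.random.rand(raw_wl.size)  # placeholder values
spectrum = np.interp(wavelengths, raw_wl, raw_spectrum)  # shape: (500,)
```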

Code Structure

  • dataset.py: defines multiple PyTorch Dataset classes (ResnetData, ResnetData2, MLP_Data, MixData) to handle different types of data (e.g., image, spectral, periodic data, CSV data) with preprocessing (e.g., scaling) and train/val/test splitting, along with corresponding functions to generate DataLoaders for model training.
  • loss.py: defines two custom loss functions for spectral data in PyTorch: CustomLoss combines MSE with peak error penalties, and CustomLoss_2 splits the spectrum to apply different loss calculations (MSE + max-value deviation); a sketch of the peak-penalty idea appears after this list.
  • model.py: defines a variety of PyTorch neural network architectures, including modified ResNets, RNNs, MLPs, and hybrid models (e.g., cross_model, MIXModel), designed to handle diverse inputs like images, periodic data, and parameters.
  • plotting_utils.py: contains functions for visualizing and analyzing predicted vs. raw data (e.g., parameters, CIE color values) using various plots (scatter plots, boxplots, histograms, etc.) and metrics such as R² score and absolute error.
  • preprocess.py: provides utility functions for converting images and spectral data into PyTorch tensors, including preprocessing steps like resizing, converting to grayscale, and interpolating spectra, with example usage for generating and saving these tensors.
  • scaler.py: defines a custom SpectrumNormalizer class that separately scales spectral data in the 400–800 nm and 800–1800 nm ranges using MinMaxScaler.
  • train.py: implements the training and evaluation pipeline for models (e.g., MIXModel), including training/validation loops, optimizer and scheduler configuration, checkpoint saving, command-line argument handling, and logging with wandb.
  • utils.py: provides a collection of utility functions supporting various workflow stages, including data splitting (train/valid/test), model evaluation, parameter counting, code backup, spectral index retrieval, folder creation, and spectrum interpolation.
  • check_results_final.ipynb: evaluates the performance of a pre-trained ResNet model on spectral data, including calculating R² scores, analyzing absolute errors for visible and infrared bands, and visualizing comparisons between actual and predicted spectra.
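As referenced in the loss.py item above, the snippet below sketches a spectral loss that adds a peak-error penalty to plain MSE. It only illustrates the idea; it is not the repository's CustomLoss implementation, and peak_weight is an assumed hyperparameter:

```python
import torch
import torch.nn as nn

class PeakAwareSpectralLoss(nn.Module):
    """Sketch: MSE over the spectrum plus a penalty on peak deviation."""

    def __init__(self, peak_weight: float = 1.0):
        super().__init__()
        self.peak_weight = peak_weight  # assumed knob, not from loss.py
        self.mse = nn.MSELoss()

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        base = self.mse(pred, target)
        # Penalize the gap between predicted and true spectral maxima
        # (dim=-1 is the wavelength axis).
        peak_err = (pred.max(dim=-1).values - target.max(dim=-1).values).abs().mean()
        return base + self.peak_weight * peak_err

# Usage sketch with dummy batches of 500-point spectra.
criterion = PeakAwareSpectralLoss(peak_weight=0.5)
loss = criterion(torch.rand(8, 500), torch.rand(8, 500))
```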

Model Types

1. Images as input

2. Parameters as input

3. Mix input (images + parameters)

  • Mix_Model: Following the methods described in the paper "Deep learning modeling approach for metasurfaces with high degrees of freedom", we use a CNN to extract image features and employ an NTN (Neural Tensor Network) together with spatial tiling to extract the parameter features and expand their dimensionality (a hedged architecture sketch follows this list).
  • Mix_Model_with_Resnet: Building on the Mix_Model above, this variant uses ResNet18 for image feature extraction.
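As mentioned in the Mix_Model item above, the sketch below illustrates the mixed-input idea: a CNN extracts image features, an NTN-style bilinear layer embeds the parameters, the parameter features are spatially tiled and concatenated with the image features, and a decoder outputs a 500-point spectrum. All layer sizes are assumptions for illustration and do not reproduce model.py:

```python
import torch
import torch.nn as nn

class MixModelSketch(nn.Module):
    """Illustrative mixed-input model; sizes are assumed, not from model.py."""

    def __init__(self, n_params: int = 14, n_spectrum: int = 500):
        super().__init__()
        # Small CNN producing a 32-channel 8x8 feature map (assumed sizes).
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 16, 3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(16, 32, 3, stride=2, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d((8, 8)),
        )
        # NTN-style bilinear interaction, projecting the 14 parameters
        # into a 32-dim feature (a simplification of a full NTN).
        self.ntn = nn.Bilinear(n_params, n_params, 32)
        # Decoder over concatenated image + tiled parameter features.
        self.decoder = nn.Sequential(
            nn.Flatten(),
            nn.Linear((32 + 32) * 8 * 8, 512), nn.ReLU(),
            nn.Linear(512, n_spectrum),
        )

    def forward(self, image: torch.Tensor, params: torch.Tensor) -> torch.Tensor:
        feat_img = self.cnn(image)                       # (B, 32, 8, 8)
        feat_par = torch.relu(self.ntn(params, params))  # (B, 32)
        # Spatial tiling: broadcast the parameter feature over the 8x8 grid.
        tiled = feat_par[:, :, None, None].expand(-1, -1, 8, 8)
        return self.decoder(torch.cat([feat_img, tiled], dim=1))

# Usage sketch: a batch of grayscale images and 14-parameter vectors.
model = MixModelSketch()
spectra = model(torch.rand(4, 1, 224, 224), torch.rand(4, 14))  # (4, 500)
```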

Visualizations (selected)

Mix_Model:

The model's training curves are shown below; its validation loss is lower than that of the MLP and ResNet models:

(Figure: training and validation loss curves.)

Its R² score is also close to 0.9:

(Figure: R² score.)

It can also accurately predict most spectra:

(Figures: examples of predicted vs. actual spectra.)

Mix_Model_with_Resnet:

The training curves are shown below. Like the Mix_Model, it reaches its minimum validation loss after about 2000 epochs of training and then starts to overfit slightly:

(Figure: training and validation loss curves.)

Training performance is comparable to the previous Mix_Model, with a slight improvement; the R² score has increased by 0.1:

(Figure: R² score.)

The spectral prediction results are as follows:

(Figures: examples of predicted vs. actual spectra.)

Usage

  1. Ensure that the required dependency libraries are installed.
  2. Prepare the data and store it as .pt files (a sketch appears after this list).
  3. Modify the parameters in train.py, or pass them via command-line arguments, as needed.
  4. Run train.py to start training the model.
  5. Use check_results_final.ipynb to evaluate and analyze the model.
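A minimal sketch of step 2, assuming image tensors, parameter vectors, and spectra are stacked and saved with torch.save (the shapes and file names are placeholders and should match whatever dataset.py and train.py are configured to load):

```python
import torch

# Placeholder tensors: N samples of grayscale images, 14-parameter
# vectors, and 500-point spectra. Real data would come from the image
# preprocessing and FDTD simulations described above.
images = torch.rand(1000, 1, 224, 224)
params = torch.rand(1000, 14)
spectra = torch.rand(1000, 500)

# File names are hypothetical.
torch.save(images, "images.pt")
torch.save(params, "params.pt")
torch.save(spectra, "spectra.pt")
```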
