This project is a collaboration with Jilong Yi, a talented graduate student.
The goal of the project is to achieve high-accuracy forward prediction from images and/or structural parameters to spectral data, with the aim of augmenting the database and saving the substantial time otherwise spent on FDTD simulations.
Parameters:
| Shape | Params |
| --- | --- |
| circle | a, b, φ, Px, Py |
| rec | W, L, φ, Px, Py |
| double_rec | W1, L1, W2, L2, φ, Px, Py |
| double_ellipse | a, b, θ, φ, Px, Py |
| ring | R, r, θ, φ, Px, Py |
| cross | W1, L1, W2, L2, offset, φ, Px, Py |
| lack_rec | W, L, α, β, γ, φ, Px, Py |
Each image is 900 px × 900 px with a white background, where yellow represents the period size and red represents the structure size.
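For reference, below is a minimal sketch of how such an image could be rendered with Pillow for a single rotated rectangle; the repo's actual drawing code, color handling, and nm-to-pixel scaling may differ.

```python
# Sketch only: white 900x900 canvas, yellow outline for the period cell,
# red filled shape for the structure. The repo's real renderer may differ.
from PIL import Image, ImageDraw

def draw_rec(W, L, phi, Px, Py, canvas=900, nm_per_px=1.0):
    img = Image.new("RGB", (canvas, canvas), "white")
    draw = ImageDraw.Draw(img)
    cx = cy = canvas / 2

    # Yellow rectangle marking the period (Px x Py), centered on the canvas.
    draw.rectangle(
        [cx - Px / (2 * nm_per_px), cy - Py / (2 * nm_per_px),
         cx + Px / (2 * nm_per_px), cy + Py / (2 * nm_per_px)],
        outline="yellow", width=3)

    # Red rectangle for the structure (W x L), rotated by phi degrees.
    overlay = Image.new("RGBA", (canvas, canvas), (0, 0, 0, 0))
    odraw = ImageDraw.Draw(overlay)
    odraw.rectangle(
        [cx - W / (2 * nm_per_px), cy - L / (2 * nm_per_px),
         cx + W / (2 * nm_per_px), cy + L / (2 * nm_per_px)],
        fill="red")
    overlay = overlay.rotate(phi, center=(cx, cy))
    img.paste(overlay, (0, 0), overlay)
    return img

# Example: a 100 x 200 nm rectangle rotated 30 degrees in a 356 x 421 nm cell.
draw_rec(W=100, L=200, phi=30, Px=356, Py=421).save("rec_example.png")
```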
The input consists of 14 parameters covering all structure types. Px, Py, and phi are common attributes shared by every structure; the remaining attributes are filled in when a structure has them and set to 0 otherwise, as shown in the table below:
| W1 | L1 | W2 | L2 | L | W | alpha | beta | gamma | A | B | phi | Px | Py |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 100.0 | 104.0 | 214.0 | 219.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 72 | 356 | 421 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
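For illustration, here is a minimal sketch of packing one structure's parameters into this 14-element vector, with absent attributes left at 0; the column order follows the table above, but the repo's actual CSV/tensor layout may differ.

```python
import numpy as np

# Fixed column order of the 14-parameter input vector (from the table above).
PARAM_COLS = ["W1", "L1", "W2", "L2", "L", "W",
              "alpha", "beta", "gamma", "A", "B", "phi", "Px", "Py"]

def encode_params(shape_params: dict) -> np.ndarray:
    """Pack one structure's parameters into a 14-d vector.

    Attributes a shape does not have are left at 0, as described above.
    (Sketch only: the repo's real layout may differ.)
    """
    vec = np.zeros(len(PARAM_COLS), dtype=np.float32)
    for name, value in shape_params.items():
        vec[PARAM_COLS.index(name)] = value
    return vec

# Example: the double_rec row shown in the table.
x = encode_params({"W1": 100.0, "L1": 104.0, "W2": 214.0, "L2": 219.0,
                   "phi": 72, "Px": 356, "Py": 421})
```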
The spectral data are generated by FDTD simulation, covering the wavelength range 400–1800 nm with 500 data points. The wavelength points (indices) of the spectrum are listed in data/light_source.txt.
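As a sketch of the resampling step, the snippet below interpolates a raw FDTD spectrum onto a 500-point grid; the real wavelength grid should be read from data/light_source.txt, and a uniform 400–1800 nm grid is assumed here only for illustration.

```python
import numpy as np

# The reference wavelength grid should be loaded from data/light_source.txt;
# a uniform 400-1800 nm grid with 500 points is assumed here for illustration.
target_wl = np.linspace(400.0, 1800.0, 500)

def resample_spectrum(raw_wl, raw_values):
    """Interpolate a raw FDTD spectrum onto the 500-point target grid."""
    return np.interp(target_wl, raw_wl, raw_values)

# Example with synthetic data standing in for an FDTD output.
raw_wl = np.linspace(400.0, 1800.0, 1201)
raw_t = np.exp(-((raw_wl - 900.0) / 120.0) ** 2)   # fake resonance peak
spectrum_500 = resample_spectrum(raw_wl, raw_t)     # shape (500,)
```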
- dataset.py: defines multiple PyTorch Dataset classes (ResnetData, ResnetData2, MLP_Data, MixData) to handle different types of data (e.g., image, spectral, periodic data, CSV data) with preprocessing (e.g., scaling) and train/val/test splitting, along with corresponding functions to generate DataLoaders for model training.
- loss.py: defines two custom loss functions for spectral data in PyTorch, where CustomLoss combines MSE with peak error penalties and CustomLoss_2 splits the spectrum to apply different loss calculations (MSE + max value deviation).
- model.py: defines a variety of PyTorch neural network architectures, including modified ResNets, RNNs, MLPs, and hybrid models (e.g., cross_model, MIXModel), designed to handle diverse inputs like images, periodic data, and parameters.
- plotting_utils.py: contains functions for visualizing and analyzing predicted vs. raw data (e.g., parameters, CIE color values) using various plots (scatter plots, boxplots, histograms, etc.) and metrics such as R² score and absolute error.
- preprocess.py: provides utility functions for converting images and spectral data into PyTorch tensors, including preprocessing steps like resizing, converting to grayscale, and interpolating spectra, with example usage for generating and saving these tensors.
- scaler.py: defines a custom SpectrumNormalizer class that scales the 400–800 nm and 800–1800 nm ranges of the spectral data separately using MinMaxScaler (see the sketch after this list).
- train.py: implements the training and evaluation pipeline for models (e.g., MIXModel), including training/validation loops, optimizer and scheduler configuration, checkpoint saving, command-line argument handling, and logging with wandb.
- utils.py: provides a collection of utility functions supporting various workflow stages, including data splitting (train/valid/test), model evaluation, parameter counting, code backup, spectral index retrieval, folder creation, and spectrum interpolation.
- check_results_final.ipynb: evaluates the performance of a pre-trained ResNet model on spectral data, including calculating R² scores, analyzing absolute errors for visible and infrared bands, and visualizing comparisons between actual and predicted spectra.
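As an illustration of the two-band scaling idea in scaler.py, the following is a minimal sketch assuming a uniform 500-point grid and a split near 800 nm; the actual SpectrumNormalizer API and split index may differ.

```python
# Minimal sketch of a two-band spectrum scaler in the spirit of scaler.py's
# SpectrumNormalizer (assumption: the real split index and API may differ).
import numpy as np
from sklearn.preprocessing import MinMaxScaler

class TwoBandScaler:
    def __init__(self, n_points=500, wl_min=400.0, wl_max=1800.0, split_wl=800.0):
        wl = np.linspace(wl_min, wl_max, n_points)
        self.split = int(np.searchsorted(wl, split_wl))  # index nearest 800 nm
        self.vis_scaler = MinMaxScaler()   # 400-800 nm band
        self.ir_scaler = MinMaxScaler()    # 800-1800 nm band

    def fit(self, spectra):                # spectra: (n_samples, n_points)
        self.vis_scaler.fit(spectra[:, :self.split])
        self.ir_scaler.fit(spectra[:, self.split:])
        return self

    def transform(self, spectra):
        return np.hstack([self.vis_scaler.transform(spectra[:, :self.split]),
                          self.ir_scaler.transform(spectra[:, self.split:])])

# Usage: scaler = TwoBandScaler().fit(train_spectra); y = scaler.transform(spectra)
```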
- Resnet50: uses ResNet50 as the image feature extractor, followed by an external linear layer, to realize forward prediction from images to spectra.
- ResNetRNN: leverages the method proposed in *Finding the optical properties of plasmonic structures by image processing using a combination of convolutional neural networks and recurrent neural networks*.
- MLP: uses simple linear layers to realize the prediction.
- MLP-CNN: adopts a combination of MLP and CNN for prediction, leveraging the method proposed in *Deep learning for accelerated all-dielectric metasurface design*.
- Mix_Model: following the method in *Deep learning modeling approach for metasurfaces with high degrees of freedom*, we use a CNN to extract image features and employ an NTN (Neural Tensor Network) together with spatial tiling to extract and expand the dimensionality of the parameters.
- Mix_Model_with_Resnet: builds upon the Mix_Model above, using ResNet18 for image feature extraction (a simplified sketch follows this list).
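To make the hybrid idea concrete, below is a simplified sketch of an image-plus-parameter model in the spirit of Mix_Model_with_Resnet; the NTN and spatial-tiling components of the actual MIXModel are omitted, and all layer sizes and input shapes here are illustrative.

```python
# Simplified sketch of a hybrid image + parameter model (assumption: the
# repo's MIXModel uses an NTN and spatial tiling, omitted here for brevity).
import torch
import torch.nn as nn
from torchvision.models import resnet18

class HybridSpectrumModel(nn.Module):
    def __init__(self, n_params=14, n_spectrum=500):
        super().__init__()
        backbone = resnet18(weights=None)
        backbone.fc = nn.Identity()          # 512-d image feature
        self.image_branch = backbone
        self.param_branch = nn.Sequential(   # expand the 14 parameters to 128-d
            nn.Linear(n_params, 128), nn.ReLU(),
            nn.Linear(128, 128), nn.ReLU())
        self.head = nn.Sequential(           # fuse features and predict the spectrum
            nn.Linear(512 + 128, 512), nn.ReLU(),
            nn.Linear(512, n_spectrum))

    def forward(self, image, params):
        feat = torch.cat([self.image_branch(image), self.param_branch(params)], dim=1)
        return self.head(feat)

# Example forward pass with dummy tensors.
model = HybridSpectrumModel()
y = model(torch.randn(2, 3, 224, 224), torch.randn(2, 14))   # shape (2, 500)
```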
The training performance of the model is shown below. Its loss on the validation set is lower than that of the MLP and ResNet models:
Its R² score is also close to 0.9:
It can also accurately predict most spectra:
The training performance of the model is shown below. Like the Mix_Model, it reaches its minimum validation loss after about 2000 epochs of training and then begins to overfit slightly:
The training performance is comparable to the previous Mix_Model, with a slight improvement: the R² score has increased by 0.1:
The spectral prediction results are as follows:
- Ensure that the required dependency libraries are installed.
- Prepare the data and store it as .pt files.
- Modify the parameters in train.py or pass them via command-line arguments as needed.
- Run train.py to start training the model.
- Use check_results_final.ipynb to evaluate and analyze the trained model (see the sketch below).
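As a rough illustration of the kind of evaluation done in that notebook, the sketch below computes an overall R² score and band-wise absolute errors; all names and the visible/infrared split index are illustrative, not the notebook's actual code.

```python
# Illustrative sketch of the evaluation done in check_results_final.ipynb
# (names are hypothetical; adapt the data loading to the actual repo).
import numpy as np
from sklearn.metrics import r2_score

def evaluate_spectra(y_true, y_pred, n_vis=None):
    """y_true, y_pred: (n_samples, 500) arrays of actual / predicted spectra."""
    r2 = r2_score(y_true.reshape(-1), y_pred.reshape(-1))
    abs_err = np.abs(y_true - y_pred)
    report = {"r2": r2, "mean_abs_err": abs_err.mean()}
    if n_vis is not None:                     # split visible / infrared bands
        report["mae_visible"] = abs_err[:, :n_vis].mean()
        report["mae_infrared"] = abs_err[:, n_vis:].mean()
    return report

# Example with random placeholders standing in for model outputs.
y_true = np.random.rand(10, 500)
y_pred = y_true + 0.01 * np.random.randn(10, 500)
print(evaluate_spectra(y_true, y_pred, n_vis=143))  # ~143 points cover 400-800 nm
```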