Enhancing GANs with MMD Neural Architecture Search, PMish Activation Function, and Adaptive Rank Decomposition (MMD-PMish-NAS-GAN)
This repository contains the code for our 2024 IEEE Access journal paper, "Enhancing GANs with MMD Neural Architecture Search, PMish Activation Function, and Adaptive Rank Decomposition," by Prasanna Reddy Pulakurthi, Mahsa Mozaffari, Sohail Dianat, Jamison Heard, Raghuveer Rao, and Majid Rabbani. [PDF] [Paper] [Website]
Keywords: Generative Adversarial Network (GAN), Maximum Mean Discrepancy (MMD), Activation functions, Parametric Mish (PMish), Neural Architecture Search (NAS), Tensor Decomposition.
This project presents MMD-PMish-NAS, a framework that enhances GAN performance by integrating:

- the Parametric Mish (PMish) activation function,
- a modified MMD-GAN repulsive loss, and
- Neural Architecture Search (NAS) with Adaptive Rank Decomposition (ARD).

Together, these components deliver:

- improved convergence stability and training efficiency, and
- state-of-the-art FID scores across the CIFAR-100, STL-10, and CelebA-64 datasets.
Read the full paper or view the project website.
Example images generated by MMD-PMish-NAS across different datasets.
Graphical Abstract: Overview of MMD-PMish-NAS pipeline integrating activation optimization, loss modification, and architecture compression.
This is a PyTorch implementation of the PMish activation function. It combines the Tanh and Softplus functions with a learnable parameter beta:

`PMish(x) = x * tanh(softplus(beta * x) / beta)`
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class PMishActivation(nn.Module):
    def __init__(self):
        super(PMishActivation, self).__init__()
        # Learnable shape parameter beta, initialized to 1.
        # Registered as an nn.Parameter so it trains with the model and
        # moves to the right device via .to(device) / .cuda().
        self.beta = nn.Parameter(torch.ones(1))

    def forward(self, x):
        beta_x = self.beta * x
        return x * torch.tanh(F.softplus(beta_x) / self.beta)
```
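As a quick sanity check (a minimal usage sketch, not part of the repository), PMish with beta = 1 reduces to Mish, so its output should match PyTorch's built-in `F.mish`:

```python
x = torch.randn(4, 8)
act = PMishActivation()  # beta is initialized to 1

# PMish(x) = x * tanh(softplus(x)) = Mish(x) when beta = 1.
assert torch.allclose(act(x), F.mish(x), atol=1e-6)
```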
- Clone this repository.

  ```bash
  git clone https://github.com/PrasannaPulakurthi/MMD-PMish-NAS.git
  cd MMD-PMish-NAS
  ```
- Install the requirements using Python 3.9.

  ```bash
  conda create -n mmd-nas python=3.9
  conda activate mmd-nas
  pip install -r requirements.txt
  ```
- Install PyTorch and TensorFlow with CUDA.

  ```bash
  pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116
  ```

  To install a different PyTorch build compatible with your CUDA version, see Install PyTorch.
Files can be found in Google Drive.
- Download the pre-calculated statistics used to compute the FID to `./fid_stat` from here.
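  For context, these statistics are the mean and covariance of Inception features on the reference dataset; FID compares them with the same statistics of generated samples (lower is better):

  $$\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2 + \operatorname{Tr}\!\left(\Sigma_r + \Sigma_g - 2\,(\Sigma_r \Sigma_g)^{1/2}\right)$$

  where $(\mu_r, \Sigma_r)$ and $(\mu_g, \Sigma_g)$ are the real and generated feature statistics.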
- Download the pre-trained models to `./exps` from the `exps` folder found here.
- Download the trained generative models from here to `./exps/train/pmishact_large_cifar10_xx/Model`.

  ```bash
  mkdir -p exps/train/pmishact_large_cifar10_xx/Model
  ```
- To test the trained model, run the command found in `scripts/test_arch.sh`.

  ```bash
  python MGPU_test_arch.py --random_seed 33333 --gpu_ids 0 --num_workers 1 --dataset cifar10 --bottom_width 4 --img_size 32 --arch arch_cifar10 --draw_arch False --checkpoint train/pmishact_large_cifar10_33333_2024_04_18_19_59_27 --genotypes_exp arch_cifar10 --latent_dim 120 --gf_dim 256 --num_eval_imgs 50000 --exp_name test/pmishact_large_cifar10 --act pmishact
  ```
- Train the weights of the generative model with the searched architecture (the architecture is saved in `./exps/arch_cifar10/Genotypes/latest_G.npy`). Run the command found in `scripts/train_arch.sh`; a sketch of the MMD objective it optimizes follows the command.

  ```bash
  python MGPU_train_arch.py --gpu_ids 0 --num_workers 1 --gen_bs 128 --dis_bs 128 --dataset cifar10 --bottom_width 4 --img_size 32 --max_epoch_G 500 --n_critic 1 --arch arch_cifar10 --draw_arch False --genotypes_exp arch_cifar10 --latent_dim 120 --gf_dim 256 --df_dim 512 --g_lr 0.0002 --d_lr 0.0002 --beta1 0.0 --beta2 0.9 --init_type xavier_uniform --val_freq 5 --num_eval_imgs 50000 --exp_name train/arch_train_cifar10_large --act pmishact --modified_mmd True
  ```
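  Here, `--modified_mmd True` enables the paper's modified repulsive MMD-GAN loss. As background (a minimal sketch of the standard estimator, not the repository's modified variant), the unbiased MMD² between real and generated feature batches with a Gaussian kernel looks like:

  ```python
  import torch

  def gaussian_kernel(a, b, sigma=1.0):
      # k(a_i, b_j) = exp(-||a_i - b_j||^2 / (2 * sigma^2))
      sq_dists = torch.cdist(a, b, p=2).pow(2)
      return torch.exp(-sq_dists / (2 * sigma ** 2))

  def mmd2_unbiased(x, y, sigma=1.0):
      """Unbiased MMD^2 estimate between samples x: (m, d) and y: (n, d)."""
      m, n = x.size(0), y.size(0)
      k_xx = gaussian_kernel(x, x, sigma)
      k_yy = gaussian_kernel(y, y, sigma)
      k_xy = gaussian_kernel(x, y, sigma)
      # Drop diagonal terms so the within-sample averages are unbiased.
      sum_xx = (k_xx.sum() - k_xx.diagonal().sum()) / (m * (m - 1))
      sum_yy = (k_yy.sum() - k_yy.diagonal().sum()) / (n * (n - 1))
      return sum_xx + sum_yy - 2.0 * k_xy.mean()
  ```

  The repulsive formulation rearranges these kernel terms in the discriminator loss; see the paper for the exact modified objective.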
- To use AdversarialNAS to search for the best architecture, run the command found in `scripts/search_arch_cifar10.sh`; a sketch of the Gumbel-Softmax operation mixing follows the command.

  ```bash
  python MGPU_search_arch.py --gpu_ids 0 --gen_bs 128 --dis_bs 128 --dataset cifar10 --bottom_width 4 --img_size 32 --max_epoch_G 25 --arch search_both_cifar10 --latent_dim 120 --gf_dim 160 --df_dim 80 --g_spectral_norm False --d_spectral_norm True --g_lr 0.0002 --d_lr 0.0002 --beta1 0.0 --beta2 0.9 --init_type xavier_uniform --n_critic 5 --val_freq 5 --derive_freq 1 --derive_per_epoch 16 --draw_arch False --exp_name search/bs120-dim160 --num_workers 8 --gumbel_softmax True
  ```
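  The `--gumbel_softmax True` flag relaxes the discrete choice among candidate operations into a differentiable one. A minimal sketch of the idea (hypothetical class and names, not the repository's implementation):

  ```python
  import torch
  import torch.nn as nn
  import torch.nn.functional as F

  class MixedOp(nn.Module):
      """Mixes candidate ops with a Gumbel-Softmax sample over architecture logits."""
      def __init__(self, ops):
          super().__init__()
          self.ops = nn.ModuleList(ops)
          self.alpha = nn.Parameter(torch.zeros(len(ops)))  # architecture logits

      def forward(self, x, tau=1.0):
          # Differentiable, near-one-hot weights; lower tau sharpens the sample.
          weights = F.gumbel_softmax(self.alpha, tau=tau, hard=False)
          return sum(w * op(x) for w, op in zip(weights, self.ops))
  ```

  After the search, the operation with the largest logit on each mixed edge is kept to form the final derived architecture.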
- Apply ARD to find the best rank for each layer.

  To find the FID score for each layer and candidate rank, run:

  ```bash
  python scripts/findrank.py
  ```

  Then run the `ARD.ipynb` Jupyter notebook to find the optimal ranks. The selected ranks per layer are listed below (`nc` marks a layer left uncompressed); a minimal decomposition sketch follows the table.
  | PF   | Layer 1 | Layer 2 | Layer 3 | Layer 4 | Layer 5 | Layer 6 | Layer 7 | Layer 8 | Layer 9 | Layer 10 | Layer 11 | Layer 12 | Layer 13 | l1 | l2 | l3 |
  |------|-----|-----|-----|-----|-----|-----|-----|-----|-----|-----|-----|-----|-----|----|----|----|
  | 1/1  | 128 | 128 | 256 | 128 | 256 | 128 | 128 | 128 | 512 | 128 | 128 | 128 | 768 | nc | 2 | 2 |
  | 1/5  | 128 | 128 | 768 | 128 | 768 | 128 | 128 | 256 | 512 | 256 | 128 | 128 | 768 | nc | 2 | 2 |
  | 1/10 | 128 | 128 | 768 | 128 | 768 | 128 | 128 | 256 | 768 | 256 | 768 | 768 | nc | nc | 2 | 2 |
  | 1/15 | 128 | 128 | 768 | 128 | 768 | 128 | 128 | 512 | nc | 256 | 768 | 768 | nc | nc | 2 | 2 |
  | 1/20 | 128 | 128 | 768 | 128 | nc | 128 | 128 | nc | nc | 512 | 768 | 768 | nc | nc | 2 | 2 |
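  As background (a minimal sketch assuming a standard 2D convolution with `groups=1`, not the repository's ARD code), a convolutional layer can be compressed at a chosen rank `r` by a truncated SVD of its flattened weight, producing a k×k convolution into `r` channels followed by a 1×1 convolution:

  ```python
  import torch
  import torch.nn as nn

  def lowrank_conv(conv: nn.Conv2d, r: int) -> nn.Sequential:
      """Approximate `conv` with rank-r factors (assumes groups == 1)."""
      out_c, in_c, kh, kw = conv.weight.shape
      W = conv.weight.detach().reshape(out_c, in_c * kh * kw)  # flatten each filter
      # W ~= U[:, :r] @ diag(S[:r]) @ Vh[:r] after truncation to rank r.
      U, S, Vh = torch.linalg.svd(W, full_matrices=False)
      first = nn.Conv2d(in_c, r, (kh, kw), stride=conv.stride,
                        padding=conv.padding, bias=False)
      second = nn.Conv2d(r, out_c, 1, bias=conv.bias is not None)
      first.weight.data = Vh[:r].reshape(r, in_c, kh, kw)
      second.weight.data = (U[:, :r] * S[:r]).reshape(out_c, r, 1, 1)
      if conv.bias is not None:
          second.bias.data = conv.bias.detach().clone()
      return nn.Sequential(first, second)
  ```

  Fine-tuning after this replacement recovers most of the lost fidelity, which is what the compression step below does.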
- Compress and fine-tune all the convolutional layers according to ARD.

  ```bash
  python MGPU_cpcompress_arch.py --gpu_ids 0 --num_workers 1 --dataset cifar10 --bottom_width 4 --img_size 32 --arch arch_cifar10 --draw_arch False --freeze_before_compressed --freeze_layers l1 l2 l3 --genotypes_exp arch_cifar10 --latent_dim 120 --gf_dim 128 --df_dim 512 --num_eval_imgs 50000 --checkpoint train/pmishact_large_cifar10_33333_2024_04_18_19_59_27 --exp_name compress/rp_1 --val_freq 1 --gen_bs 128 --dis_bs 128 --beta1 0.0 --beta2 0.9 --byrank --rank 128 128 768 128 768 128 128 512 nc 256 768 768 nc nc 2 2 --layers cell1.c0.ops.0.op.1 cell1.c1.ops.0.op.1 cell1.c2.ops.0.op.1 cell1.c3.ops.0.op.1 cell2.c0.ops.0.op.1 cell2.c2.ops.0.op.1 cell2.c3.ops.0.op.1 cell2.c4.ops.0.op.1 cell3.c0.ops.0.op.1 cell3.c1.ops.0.op.1 cell3.c2.ops.0.op.1 cell3.c3.ops.0.op.1 cell3.c4.ops.0.op.1 l1 l2 l3 --max_epoch_G 300 --act pmishact
  ```
- To test the compressed network, download the compressed model from here to `./exps/compress/cifar10_small_1by15_xx/Model`.

  ```bash
  python MGPU_test_cpcompress.py --gpu_ids 0 --num_workers 1 --dataset cifar10 --bottom_width 4 --img_size 32 --arch arch_cifar10 --draw_arch False --checkpoint compress/cifar10_small_1by15_2024_05_05_02_38_59 --genotypes_exp arch_cifar10 --latent_dim 120 --gf_dim 128 --num_eval_imgs 50000 --eval_batch_size 100 --exp_name test/compress_cifar10_small --act pmishact --byrank
  ```
Please consider citing our papers in your publications if they help your research. The following are BibTeX references.
```bibtex
@INPROCEEDINGS{10446488,
  author={Pulakurthi, Prasanna Reddy and Mozaffari, Mahsa and Dianat, Sohail A. and Rabbani, Majid and Heard, Jamison and Rao, Raghuveer},
  booktitle={ICASSP 2024 - 2024 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
  title={Enhancing GAN Performance Through Neural Architecture Search and Tensor Decomposition},
  year={2024},
  pages={7280-7284},
  keywords={Training;Performance evaluation;Tensors;Image coding;Image synthesis;Image edge detection;Computer architecture;Neural Architecture Search;Maximum Mean Discrepancy;Generative Adversarial Networks},
  doi={10.1109/ICASSP48485.2024.10446488}}

@ARTICLE{10732016,
  author={Pulakurthi, Prasanna Reddy and Mozaffari, Mahsa and Dianat, Sohail and Heard, Jamison and Rao, Raghuveer and Rabbani, Majid},
  journal={IEEE Access},
  title={Enhancing GANs With MMD Neural Architecture Search, PMish Activation Function, and Adaptive Rank Decomposition},
  year={2024},
  volume={12},
  pages={174222-174244},
  keywords={Generative adversarial networks;Training;Generators;Image coding;Acute respiratory distress syndrome;Tensors;Standards;Neural networks;Image synthesis;Adaptive systems;Activation function;generative adversarial network;maximum mean discrepancy;neural architecture search;tensor decomposition},
  doi={10.1109/ACCESS.2024.3485557}}
```
The codebase builds on MMD-AdversarialNAS-GAN, AdversarialNAS, and Tensorly.