Skip to content

chen0040/mxnet-gan

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

11 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

mxnet-gan

Some GAN models that I studied while trying to learn MXNet.

  • Deep Convolution GAN
  • Pixel-to-Pixel GAN that performs image-to-image translation

Currently this repository is just a bunch of algorithms reproduced from http://gluon.mxnet.io, with a simple wrapper to make them modular.

Usage

Deep Convolution GAN

To run DCGan using the LFW dataset, run the following command:

python demo/dcgan_train.py

The demo/dcgan_train.py sample codes are shown below:

import os
import sys
import mxnet as mx


def patch_path(path):
    """Resolve *path* relative to the directory containing this script."""
    script_dir = os.path.dirname(__file__)
    return os.path.join(script_dir, path)


def main():
    """Train a DCGan on the LFW face dataset and save the models under demo/models."""
    # The package root must be on sys.path before mxnet_gan can be imported.
    sys.path.append(patch_path('..'))

    from mxnet_gan.library.dcgan import DCGan
    from mxnet_gan.data.lfw_data_set import download_lfw_dataset_if_not_exists

    dataset_dir = patch_path('data/lfw_dataset')
    models_dir = patch_path('models')

    # Fetches the LFW images only if they are not already present locally.
    download_lfw_dataset_if_not_exists(dataset_dir)

    gan = DCGan(model_ctx=mx.gpu(0), data_ctx=mx.gpu(0))
    gan.fit(data_dir_path=dataset_dir, model_dir_path=models_dir)


if __name__ == '__main__':
    main()

The trained models will be saved into demo/models folder with prefix "dcgan-*"

To run the trained models to generate new images:

python demo/dcgan_generate.py

The demo/dcgan_generate.py sample codes are shown below:

import os
import sys
import mxnet as mx


def patch_path(path):
    """Resolve *path* relative to the directory containing this script."""
    script_dir = os.path.dirname(__file__)
    return os.path.join(script_dir, path)


def main():
    """Load a trained DCGan and write 8 freshly generated images to demo/output."""
    # Make the package root importable before pulling in mxnet_gan modules.
    sys.path.append(patch_path('..'))

    from mxnet_gan.library.dcgan import DCGan

    saved_models_dir = patch_path('models')
    generated_dir = patch_path('output')

    gan = DCGan(model_ctx=mx.gpu(0), data_ctx=mx.gpu(0))
    gan.load_model(saved_models_dir)
    gan.generate(num_images=8, output_dir_path=generated_dir)


if __name__ == '__main__':
    main()

Pixel-to-Pixel GAN

To run Pixel2PixelGan using the facades dataset, run the following command:

python demo/pixel2pixel_gan_train.py

The demo/pixel2pixel_gan_train.py sample codes are shown below:

import os
import sys
import mxnet as mx


def patch_path(path):
    """Resolve *path* relative to the directory containing this script."""
    script_dir = os.path.dirname(__file__)
    return os.path.join(script_dir, path)


def main():
    """Train a Pixel2PixelGan on the facades dataset and save the models under demo/models."""
    # The package root must be on sys.path before mxnet_gan can be imported.
    sys.path.append(patch_path('..'))

    from mxnet_gan.library.pixel2pixel import Pixel2PixelGan
    from mxnet_gan.data.facades_data_set import load_image_pairs

    models_dir = patch_path('models')
    pairs = load_image_pairs(patch_path('data/facades'))

    gan = Pixel2PixelGan(model_ctx=mx.gpu(0), data_ctx=mx.gpu(0))
    # Shrink the network from its defaults (256x256 images, 8 down-samplings)
    # so training fits into limited graphics-card memory.
    gan.img_width = 64
    gan.img_height = 64
    gan.num_down_sampling = 5

    gan.fit(image_pairs=pairs, model_dir_path=models_dir)


if __name__ == '__main__':
    main()

The trained models will be saved into demo/models folder with prefix "pixel-2-pixel-gan-*"

To run the trained models to generate new images:

python demo/pixel2pixel_gan_generate.py

The demo/pixel2pixel_gan_generate.py sample codes are shown below:

import os
import sys
import mxnet as mx
from random import shuffle
import numpy as np


def patch_path(path):
    """Resolve *path* relative to the directory containing this script."""
    script_dir = os.path.dirname(__file__)
    return os.path.join(script_dir, path)


def main():
    """Run a trained Pixel2PixelGan on 20 random facade images and display the results."""
    # Make the package root importable before pulling in mxnet_gan modules.
    sys.path.append(patch_path('..'))

    from mxnet_gan.library.pixel2pixel import Pixel2PixelGan
    from mxnet_gan.data.facades_data_set import load_image_pairs
    from mxnet_gan.library.image_utils import load_image, visualize

    saved_models_dir = patch_path('models')

    gan = Pixel2PixelGan(model_ctx=mx.gpu(0), data_ctx=mx.gpu(0))
    gan.load_model(saved_models_dir)

    pairs = load_image_pairs(patch_path('data/facades'))
    shuffle(pairs)

    # Translate a random sample of 20 source images and show each one
    # side by side (concatenated along the width axis) with its output.
    for src_path, _ in pairs[:20]:
        src = load_image(src_path, gan.img_width, gan.img_height)
        generated = gan.generate(source_image=src)
        side_by_side = mx.nd.concat(src.as_in_context(gan.model_ctx), generated, dim=2)
        visualize(side_by_side)


if __name__ == '__main__':
    main()

Below are some of the generated output images:

| | | | | | | | | | | |

Note

Training with GPU

Note that the default training scripts in the demo folder use the GPU for training; therefore, you must configure your graphics card for this (or remove the "model_ctx=mx.gpu(0)" argument in the training scripts).

About

My collection of GAN implemented using MXNet

Topics

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages