From b8e7b28d7ff1d0fffc91e496dd773674304512a2 Mon Sep 17 00:00:00 2001 From: Vedanta Jha Date: Sat, 31 Aug 2019 22:47:03 +0530 Subject: [PATCH 1/5] Add files via upload Update README.md Add files via upload Delete epoch132_gt.png Delete epoch132_img.png Delete epoch132_pred.png Add files via upload Update README.md Update README.md Update README.md Update README.md Update README.md Delete image_pool.py Delete network.py Delete train.py Delete base_options.py Delete train_options.py Delete pred6.png Delete epoch138_pred.png Delete epoch132_gt.png Delete epoch132_img.png Delete epoch132_pred.png Delete epoch134_gt.png Delete epoch134_img.png Delete epoch134_pred.png Delete epoch138_gt.png Delete epoch138_img.png Delete pred5.png Delete gt1.png Delete gt2.png Delete gt3.png Delete gt4.png Delete gt5.png Delete gt6.png Delete img1.png Delete img2.png Delete img3.png Delete img4.png Delete img5.png Delete img6.png Delete pred1.png Delete pred2.png Delete pred3.png Delete pred4.png Delete README.md Create train.py Add files via upload Delete single_dataset.py Add files via upload Delete test_model.py Add files via upload Update README.md Update networks.py Delete deeplabv3.py Update custom_dataset_dataloader.py Create README.md Update README.md Update custom_dataset_dataloader.py Update models.py Update custom_dataset_dataloader.py Update test.py Update test.py Update test.py Rename modules/ximgproc/src/alphagan_matting/utils/image_pool.py to modules/ximgproc/src/alphagan_matting/util/image_pool.py Rename modules/ximgproc/src/alphagan_matting/utils/utils.py to modules/ximgproc/src/alphagan_matting/util/util.py Update networks.py Update test_dataset.py Add files via upload Add files via upload Update README.md Update README.md Update README.md Update README.md Update README.md Update README.md Update README.md Update README.md Update README.md Update README.md Update README.md Update README.md Update README.md Update base_data_loader.py Removing whitespaces Removing whitespaces Removing whitespaces Removing whitespaces Removing whitespaces Removing whitespaces --- modules/ximgproc/README.md | 1 + .../ximgproc/src/alphagan_matting/README.md | 93 +++ .../alphagan_matting/data/base_data_loader.py | 9 + .../src/alphagan_matting/data/base_dataset.py | 45 ++ .../data/custom_dataset_dataloader.py | 45 ++ .../src/alphagan_matting/data/data_loader.py | 6 + .../data/generated_dataset_simple.py | 189 ++++++ .../src/alphagan_matting/data/image_folder.py | 68 ++ .../src/alphagan_matting/data/test_dataset.py | 46 ++ .../src/alphagan_matting/models/basemodel.py | 60 ++ .../src/alphagan_matting/models/models.py | 19 + .../src/alphagan_matting/models/networks.py | 636 ++++++++++++++++++ .../src/alphagan_matting/models/simple_gan.py | 220 ++++++ .../src/alphagan_matting/models/test_model.py | 61 ++ .../alphagan_matting/options/base_options.py | 75 +++ .../alphagan_matting/options/test_options.py | 13 + .../alphagan_matting/options/train_options.py | 29 + modules/ximgproc/src/alphagan_matting/test.py | 36 + .../ximgproc/src/alphagan_matting/train.py | 62 ++ .../src/alphagan_matting/util/image_pool.py | 34 + .../src/alphagan_matting/util/util.py | 61 ++ 21 files changed, 1808 insertions(+) create mode 100644 modules/ximgproc/src/alphagan_matting/README.md create mode 100644 modules/ximgproc/src/alphagan_matting/data/base_data_loader.py create mode 100644 modules/ximgproc/src/alphagan_matting/data/base_dataset.py create mode 100644 modules/ximgproc/src/alphagan_matting/data/custom_dataset_dataloader.py create mode 
100644 modules/ximgproc/src/alphagan_matting/data/data_loader.py
 create mode 100644 modules/ximgproc/src/alphagan_matting/data/generated_dataset_simple.py
 create mode 100644 modules/ximgproc/src/alphagan_matting/data/image_folder.py
 create mode 100644 modules/ximgproc/src/alphagan_matting/data/test_dataset.py
 create mode 100644 modules/ximgproc/src/alphagan_matting/models/basemodel.py
 create mode 100644 modules/ximgproc/src/alphagan_matting/models/models.py
 create mode 100644 modules/ximgproc/src/alphagan_matting/models/networks.py
 create mode 100644 modules/ximgproc/src/alphagan_matting/models/simple_gan.py
 create mode 100644 modules/ximgproc/src/alphagan_matting/models/test_model.py
 create mode 100644 modules/ximgproc/src/alphagan_matting/options/base_options.py
 create mode 100644 modules/ximgproc/src/alphagan_matting/options/test_options.py
 create mode 100644 modules/ximgproc/src/alphagan_matting/options/train_options.py
 create mode 100644 modules/ximgproc/src/alphagan_matting/test.py
 create mode 100644 modules/ximgproc/src/alphagan_matting/train.py
 create mode 100644 modules/ximgproc/src/alphagan_matting/util/image_pool.py
 create mode 100644 modules/ximgproc/src/alphagan_matting/util/util.py

diff --git a/modules/ximgproc/README.md b/modules/ximgproc/README.md
index 59e744e0238..9148141fd69 100644
--- a/modules/ximgproc/README.md
+++ b/modules/ximgproc/README.md
@@ -16,3 +16,4 @@ Extended Image Processing
 - Pei&Lin Normalization
 - Ridge Detection Filter
 - Binary morphology on run-length encoded images
+- AlphaGAN matting

diff --git a/modules/ximgproc/src/alphagan_matting/README.md b/modules/ximgproc/src/alphagan_matting/README.md
new file mode 100644
index 00000000000..d42f85362b9
--- /dev/null
+++ b/modules/ximgproc/src/alphagan_matting/README.md
@@ -0,0 +1,93 @@
## Architecture ##

#### Generator ####
An encoder-decoder architecture is used.

There are two options for the generator encoder:

a. Resnet50 minus the last 2 layers
b. Resnet50 + ASPP module

The decoder of the generator has seven upsampling convolutional blocks.
Each upsampling convolutional block has an upsampling layer, followed by a convolutional layer, a batch normalization layer and a ReLU activation function.
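For concreteness, below is a minimal sketch of one such block; it mirrors the blocks built in `UNetDecoder` in `models/networks.py`, and the channel counts are illustrative:

```python
import torch.nn as nn

def upsampling_block(in_ch, out_ch):
    # Convolution + batch norm + ReLU, then a stride-2 transposed
    # convolution (the upsampling layer) that doubles the spatial size.
    return nn.Sequential(
        nn.Conv2d(in_ch, in_ch, kernel_size=5, padding=2),
        nn.BatchNorm2d(in_ch),
        nn.ReLU(True),
        nn.ConvTranspose2d(in_ch, out_ch, kernel_size=1, stride=2,
                           output_padding=1, bias=False),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(True),
    )
```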
#### Discriminator ####
The discriminator used here is the PatchGAN discriminator. The implementation is inspired by the CycleGAN implementation at
https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix

Again, two different types of discriminator are used:

a. N-layer PatchGAN discriminator, where the size of the patch is NxN; it is taken as 3x3 here
b. Pixel PatchGAN discriminator, which classifies every pixel

## How to use
Use the `dataroot` argument to point at the directory where you have stored the data.
Structure the data in the following way:

train
 -alpha
 -bg
 -fg

test
 -fg
 -trimap
The background images I have used here are from the MSCOCO dataset.
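For reference, this is roughly how the training loader (`data/generated_dataset_simple.py`) consumes that layout: alpha mattes and foregrounds are paired by sorted filename, and a background is drawn at random for every sample. A minimal sketch of the pairing logic (`pair_training_sample` is a hypothetical helper, for illustration only):

```python
import os
import random

def pair_training_sample(dataroot, index):
    # Mirrors GeneratedDatasetSimple.initialize()/__getitem__():
    # alpha/fg are paired by sorted order, bg is sampled at random.
    alpha_dir = os.path.join(dataroot, 'train', 'alpha')
    fg_dir = os.path.join(dataroot, 'train', 'fg')
    bg_dir = os.path.join(dataroot, 'train', 'bg')
    alphas = sorted(os.listdir(alpha_dir))
    fgs = sorted(os.listdir(fg_dir))
    bgs = os.listdir(bg_dir)
    i = index % len(alphas)
    return (os.path.join(alpha_dir, alphas[i]),
            os.path.join(fg_dir, fgs[i]),
            os.path.join(bg_dir, random.choice(bgs)))
```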
To train the model using Resnet50 without the ASPP module:

`python train.py --dataroot ./ --model simple --dataset_mode generated_simple --which_model_netG resnet50 --name resnet50`

To test the model using Resnet50 without the ASPP module:

`python test.py --dataroot ./ --dataset_mode testData --which_model_netG resnet50 --ntest 8 --model test --name resnet50`

To train the model using Resnet50 with the ASPP module:

`python train.py --dataroot ./ --model simple --dataset_mode generated_simple --which_model_netG resnet50ASPP --name resnet50ASPP`

To test the model using Resnet50 with the ASPP module:

`python test.py --dataroot ./ --dataset_mode testData --which_model_netG resnet50ASPP --ntest 8 --model test`

## Results

#### Input:
![input](https://raw.githubusercontent.com/Nerdyvedi/GSOC-Opencv-matting/master/2.png)

#### Trimap:
![Trimap](https://raw.githubusercontent.com/Nerdyvedi/GSOC-Opencv-matting/master/donkey_tri.png)

#### AlphaGAN matting:
##### Generator: Resnet50, Discriminator: N-layer PatchGAN
![Output2](https://raw.githubusercontent.com/Nerdyvedi/GSOC-Opencv-matting/master/donkey_resnet50.png)

##### Generator: Resnet50, Discriminator: Pixel PatchGAN
![Output3](https://raw.githubusercontent.com/Nerdyvedi/GSOC-Opencv-matting/master/donkey.png)

##### Generator: Resnet50 + ASPP module, Discriminator: N-layer PatchGAN
![Output4](https://raw.githubusercontent.com/Nerdyvedi/GSOC-Opencv-matting/master/donkey_deeplab.png)

### Comparing with the original implementation
(Average rank on alphamatting.com is shown; lower is better.)

| Error type                  | Original implementation | Resnet50 + N-layer | Resnet50 + Pixel | Resnet50 + ASPP module |
| --------------------------- | ----------------------- | ------------------ | ---------------- | ---------------------- |
| Sum of absolute differences | 11.7                    | 42.8               | 43.8             | 53                     |
| Mean square error           | 15                      | 45.8               | 45.6             | 54.2                   |
| Gradient error              | 14                      | 52.9               | 52.7             | 55                     |
| Connectivity error          | 29.6                    | 23.3               | 22.6             | 32.8                   |

### Training dataset used
I used a training dataset that I created myself with GIMP.
[Link to created dataset](https://drive.google.com/open?id=1zQbk2Cu7QOBwzg4vVGqCWJwHGTwGppFe)

### What could be wrong?

1. The dataset I created is not perfect. I tried to make it as accurate as possible, in some cases marking individual pixels, but it still is not perfect.

2. The output of the generator is 320x320, and the alpha matte is then resized to the original image size. Some detail may be lost in this resizing, and a better upsampling method might improve the outputs.

3. The main reason this implementation is not as good as the original is that the authors have not clearly specified how the skip connections are used. I am confident about the architecture of the encoder and decoder; the only thing I am unsure about is how the skip connections are wired.
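Two small operations underpin the training pipeline described above: compositing an image as C = alpha * F + (1 - alpha) * B from foreground F, background B and the alpha matte, and generating a trimap by dilating the uncertain band of the alpha. A minimal NumPy sketch of both follows, based on `composite()` in `models/simple_gan.py` and `generate_trimap()` in `data/generated_dataset_simple.py` (the training code samples a random dilation kernel where this sketch fixes one):

```python
import numpy as np
import scipy.ndimage

def composite(alpha, fg, bg):
    # C = alpha * F + (1 - alpha) * B; alpha in [0, 1] and broadcastable
    # against fg/bg, e.g. shape (H, W, 1) against (H, W, 3).
    return alpha * fg + (1.0 - alpha) * bg

def generate_trimap(alpha, kernel=10):
    # alpha here is an 8-bit matte. Mark partially transparent pixels
    # (0 < alpha < 255) as unknown (128), then dilate that band so the
    # unknown region gets a safety margin around the boundary.
    trimap = alpha.copy()
    grey = np.where(np.logical_and(alpha > 0, alpha < 255), 128, 0)
    grey = scipy.ndimage.grey_dilation(grey, size=(kernel, kernel))
    trimap[grey == 128] = 128
    return trimap
```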
diff --git a/modules/ximgproc/src/alphagan_matting/data/base_data_loader.py b/modules/ximgproc/src/alphagan_matting/data/base_data_loader.py
new file mode 100644
index 00000000000..d2041fc8c66
--- /dev/null
+++ b/modules/ximgproc/src/alphagan_matting/data/base_data_loader.py
@@ -0,0 +1,9 @@
class BaseDataLoader():
    def __init__(self):
        pass

    def initialize(self, opt):
        self.opt = opt

    def load_data(self):
        return None

diff --git a/modules/ximgproc/src/alphagan_matting/data/base_dataset.py b/modules/ximgproc/src/alphagan_matting/data/base_dataset.py
new file mode 100644
index 00000000000..a061a05edb0
--- /dev/null
+++ b/modules/ximgproc/src/alphagan_matting/data/base_dataset.py
@@ -0,0 +1,45 @@
import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms

class BaseDataset(data.Dataset):
    def __init__(self):
        super(BaseDataset, self).__init__()

    def name(self):
        return 'BaseDataset'

    def initialize(self, opt):
        pass

def get_transform(opt):
    transform_list = []
    if opt.resize_or_crop == 'resize_and_crop':
        osize = [opt.loadSize, opt.loadSize]
        transform_list.append(transforms.Resize(osize, Image.BICUBIC))
        transform_list.append(transforms.RandomCrop(opt.fineSize))
    elif opt.resize_or_crop == 'crop':
        transform_list.append(transforms.RandomCrop(opt.fineSize))
    elif opt.resize_or_crop == 'scale_width':
        transform_list.append(transforms.Lambda(
            lambda img: __scale_width(img, opt.fineSize)))
    elif opt.resize_or_crop == 'scale_width_and_crop':
        transform_list.append(transforms.Lambda(
            lambda img: __scale_width(img, opt.loadSize)))
        transform_list.append(transforms.RandomCrop(opt.fineSize))

    if opt.isTrain and not opt.no_flip:
        transform_list.append(transforms.RandomHorizontalFlip())

    transform_list += [transforms.ToTensor(),
                       transforms.Normalize((0.5, 0.5, 0.5),
                                            (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)

def __scale_width(img, target_width):
    ow, oh = img.size
    if (ow == target_width):
        return img
    w = target_width
    h = int(target_width * oh / ow)
    return img.resize((w, h), Image.BICUBIC)

diff --git a/modules/ximgproc/src/alphagan_matting/data/custom_dataset_dataloader.py b/modules/ximgproc/src/alphagan_matting/data/custom_dataset_dataloader.py
new file mode 100644
index 00000000000..5d243c6c57f
--- /dev/null
+++ b/modules/ximgproc/src/alphagan_matting/data/custom_dataset_dataloader.py
@@ -0,0 +1,45 @@
import torch.utils.data
from data.base_data_loader import BaseDataLoader


def CreateDataset(opt):
    dataset = None
    if opt.dataset_mode == 'testData':
        from data.test_dataset import TestDataset
        dataset = TestDataset()
    elif opt.dataset_mode == 'generated_simple':
        from data.generated_dataset_simple import GeneratedDatasetSimple
        dataset = GeneratedDatasetSimple()
    else:
        raise ValueError("Dataset [%s] not recognized." % opt.dataset_mode)
    print("dataset [%s] was created" % (dataset.name()))
    dataset.initialize(opt)
    return dataset


class CustomDatasetDataLoader(BaseDataLoader):
    def name(self):
        return 'CustomDatasetDataLoader'

    def initialize(self, opt):
        BaseDataLoader.initialize(self, opt)
        self.dataset = CreateDataset(opt)
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batchSize,
            shuffle=not opt.serial_batches,
            num_workers=int(opt.nThreads))

    def load_data(self):
        return self

    def __len__(self):
        return min(len(self.dataset), self.opt.max_dataset_size)

    def __iter__(self):
        for i, data in enumerate(self.dataloader):
            if i >= self.opt.max_dataset_size:
                break
            yield data

diff --git a/modules/ximgproc/src/alphagan_matting/data/data_loader.py b/modules/ximgproc/src/alphagan_matting/data/data_loader.py
new file mode 100644
index 00000000000..22b6a8f111b
--- /dev/null
+++ b/modules/ximgproc/src/alphagan_matting/data/data_loader.py
@@ -0,0 +1,6 @@
def CreateDataLoader(opt):
    from data.custom_dataset_dataloader import CustomDatasetDataLoader
    data_loader = CustomDatasetDataLoader()
    print(data_loader.name())
    data_loader.initialize(opt)
    return data_loader

diff --git a/modules/ximgproc/src/alphagan_matting/data/generated_dataset_simple.py b/modules/ximgproc/src/alphagan_matting/data/generated_dataset_simple.py
new file mode 100644
index 00000000000..bd35dcd94c4
--- /dev/null
+++ b/modules/ximgproc/src/alphagan_matting/data/generated_dataset_simple.py
@@ -0,0 +1,189 @@
import os.path
import torchvision.transforms as transforms
from data.base_dataset import BaseDataset, get_transform
from data.image_folder import make_dataset
from PIL import Image
import PIL
import random
import scipy.ndimage
import numpy as np
import math
# import pbcvt
# import colour_transfer

class GeneratedDatasetSimple(BaseDataset):
    def initialize(self, opt):
        self.opt = opt
        self.root = opt.dataroot
        self.dir_AB = os.path.join(opt.dataroot, opt.phase)
        self.dir_alpha = os.path.join(self.dir_AB, 'alpha')
        self.dir_fg = os.path.join(self.dir_AB, 'fg')
        self.dir_bg = os.path.join(self.dir_AB, 'bg')
        self.alpha_paths = sorted(make_dataset(self.dir_alpha))
        self.fg_paths = sorted(make_dataset(self.dir_fg))
        self.bg_paths = make_dataset(self.dir_bg)
        self.alpha_size = len(self.alpha_paths)
        self.bg_size = len(self.bg_paths)

    def __getitem__(self, index):
        index = index % self.alpha_size
        alpha_path = self.alpha_paths[index]
        fg_path = self.fg_paths[index]
        index_bg = random.randint(0, self.bg_size - 1)
        bg_path = self.bg_paths[index_bg]

        A_bg = Image.open(bg_path).convert('RGB')
        A_fg = Image.open(fg_path).convert('RGB')

        A_alpha = Image.open(alpha_path).convert('L')
        assert A_alpha.mode == 'L'

        A_trimap = self.generate_trimap(A_alpha)

        # A_bg = self.resize_bg(A_bg, A_fg)
        w_bg, h_bg = A_bg.size
        if w_bg < 321 or h_bg < 321:
            x = w_bg if w_bg < h_bg else h_bg
            ratio = 321/float(x)
            A_bg = A_bg.resize((int(np.ceil(w_bg*ratio)+1), int(np.ceil(h_bg*ratio)+1)), Image.BICUBIC)
            w_bg, h_bg = A_bg.size
        assert w_bg > 320 and h_bg > 320, '{} {}'.format(w_bg, h_bg)
        x = random.randint(0, w_bg-320-1)
        y = random.randint(0, h_bg-320-1)
        A_bg = A_bg.crop((x, y, x+320, y+320))

        crop_size = random.choice([320, 480, 640])
        # crop_size = random.choice([320,400,480,560,640,720])
        crop_center = self.find_crop_center(A_trimap)
        start_index_height = max(min(A_fg.size[1]-crop_size, crop_center[0] - int(crop_size/2) +
1), 0) + start_index_width = max(min(A_fg.size[0]-crop_size, crop_center[1] - int(crop_size/2) + 1), 0) + + bbox = ((start_index_width,start_index_height,start_index_width+crop_size,start_index_height+crop_size)) + + # A_bg = A_bg.crop(bbox) + A_fg = A_fg.crop(bbox) + A_alpha = A_alpha.crop(bbox) + A_trimap = A_trimap.crop(bbox) + + if self.opt.which_model_netG == 'unet_256': + A_bg = A_bg.resize((256,256)) + A_fg = A_fg.resize((256,256)) + A_alpha = A_alpha.resize((256,256)) + A_trimap = A_trimap.resize((256,256)) + assert A_alpha.mode == 'L' + else: + A_bg = A_bg.resize((320,320)) + A_fg = A_fg.resize((320,320)) + A_alpha = A_alpha.resize((320,320)) + A_trimap = A_trimap.resize((320,320)) + assert A_alpha.mode == 'L' + + if random.randint(0, 1): + A_bg = A_bg.transpose(PIL.Image.FLIP_LEFT_RIGHT) + + if random.randint(0, 1): + A_fg = A_fg.transpose(PIL.Image.FLIP_LEFT_RIGHT) + A_alpha = A_alpha.transpose(PIL.Image.FLIP_LEFT_RIGHT) + A_trimap = A_trimap.transpose(PIL.Image.FLIP_LEFT_RIGHT) + + ## COLOR TRANSFER ## + # if random.randint(0, 2) != 0: + # A_old = A_fg + # target = np.array(A_fg) + # palette = np.array(A_palette) + # recolor = colour_transfer.runCT(target, palette) + # A_fg = Image.fromarray(recolor) + + if self.opt.which_direction == 'BtoA': + input_nc = self.opt.output_nc + output_nc = self.opt.input_nc + else: + input_nc = self.opt.input_nc + output_nc = self.opt.output_nc + + A_bg = transforms.ToTensor()(A_bg) + A_fg = transforms.ToTensor()(A_fg) + A_alpha = transforms.ToTensor()(A_alpha) + A_trimap = transforms.ToTensor()(A_trimap) + + return {'A_bg': A_bg, + 'A_fg': A_fg, + 'A_alpha': A_alpha, + 'A_trimap': A_trimap, + 'A_paths': alpha_path} + + def resize_bg(self, bg, fg): + bbox = fg.size + w = bbox[0] + h = bbox[1] + bg_bbox = bg.size + bw = bg_bbox[0] + bh = bg_bbox[1] + wratio = w / float(bw) + hratio = h / float(bh) + ratio = wratio if wratio > hratio else hratio + if ratio > 1: + bg = bg.resize((int(np.ceil(bw*ratio)+1),int(np.ceil(bh*ratio)+1)), Image.BICUBIC) + bg = bg.crop((0,0,w,h)) + + return bg + + # def generate_trimap(self, alpha): + # trimap = np.array(alpha) + # kernel_sizes = [val for val in range(5,40)] + # kernel = random.choice(kernel_sizes) + # trimap[np.where((scipy.ndimage.grey_dilation(alpha,size=(kernel,kernel)) - alpha!=0))] = 128 + + # return Image.fromarray(trimap) + def generate_trimap(self, alpha): + trimap = np.array(alpha) + grey = np.zeros_like(trimap) + kernel_sizes = [val for val in range(2,20)] + kernel = random.choice(kernel_sizes) + # trimap[np.where((scipy.ndimage.grey_dilation(alpha,size=(kernel,kernel)) - alpha!=0))] = 128 + grey = np.where(np.logical_and(trimap>0, trimap<255), 128, 0) + grey = scipy.ndimage.grey_dilation(grey, size=(kernel,kernel)) + trimap[grey==128] = 128 + + return Image.fromarray(trimap) + + def find_crop_center(self, trimap): + t = np.array(trimap) + target = np.where(t==128) + index = random.choice([i for i in range(len(target[0]))]) + return np.array(target)[:,index][:2] + + def rotatedRectWithMaxArea(self, w, h, angle): + """ + Given a rectangle of size wxh that has been rotated by 'angle' (in + radians), computes the width and height of the largest possible + axis-aligned rectangle (maximal area) within the rotated rectangle. 
+ """ + if w <= 0 or h <= 0: + return 0,0 + width_is_longer = w >= h + side_long, side_short = (w,h) if width_is_longer else (h,w) + + # since the solutions for angle, -angle and 180-angle are all the same, + # if suffices to look at the first quadrant and the absolute values of sin,cos: + sin_a, cos_a = abs(math.sin(angle)), abs(math.cos(angle)) + if side_short <= 2.*sin_a*cos_a*side_long or abs(sin_a-cos_a) < 1e-10: + # half constrained case: two crop corners touch the longer side, + # the other two corners are on the mid-line parallel to the longer line + x = 0.5*side_short + wr,hr = (x/sin_a,x/cos_a) if width_is_longer else (x/cos_a,x/sin_a) + else: + # fully constrained case: crop touches all 4 sides + cos_2a = cos_a*cos_a - sin_a*sin_a + wr,hr = (w*cos_a - h*sin_a)/cos_2a, (h*cos_a - w*sin_a)/cos_2a + + return wr,hr + + def __len__(self): + return len(self.alpha_paths) + + def name(self): + return 'GeneratedDataset' diff --git a/modules/ximgproc/src/alphagan_matting/data/image_folder.py b/modules/ximgproc/src/alphagan_matting/data/image_folder.py new file mode 100644 index 00000000000..898200b2274 --- /dev/null +++ b/modules/ximgproc/src/alphagan_matting/data/image_folder.py @@ -0,0 +1,68 @@ +############################################################################### +# Code from +# https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py +# Modified the original code so that it also loads images from the current +# directory as well as the subdirectories +############################################################################### + +import torch.utils.data as data + +from PIL import Image +import os +import os.path + +IMG_EXTENSIONS = [ + '.jpg', '.JPG', '.jpeg', '.JPEG', + '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', +] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + + +def make_dataset(dir): + images = [] + assert os.path.isdir(dir), '%s is not a valid directory' % dir + + for root, _, fnames in sorted(os.walk(dir)): + for fname in fnames: + if is_image_file(fname): + path = os.path.join(root, fname) + images.append(path) + + return images + + +def default_loader(path): + return Image.open(path).convert('RGB') + + +class ImageFolder(data.Dataset): + + def __init__(self, root, transform=None, return_paths=False, + loader=default_loader): + imgs = make_dataset(root) + if len(imgs) == 0: + raise(RuntimeError("Found 0 images in: " + root + "\n" + "Supported image extensions are: " + + ",".join(IMG_EXTENSIONS))) + + self.root = root + self.imgs = imgs + self.transform = transform + self.return_paths = return_paths + self.loader = loader + + def __getitem__(self, index): + path = self.imgs[index] + img = self.loader(path) + if self.transform is not None: + img = self.transform(img) + if self.return_paths: + return img, path + else: + return img + + def __len__(self): + return len(self.imgs) diff --git a/modules/ximgproc/src/alphagan_matting/data/test_dataset.py b/modules/ximgproc/src/alphagan_matting/data/test_dataset.py new file mode 100644 index 00000000000..2caae08fbc0 --- /dev/null +++ b/modules/ximgproc/src/alphagan_matting/data/test_dataset.py @@ -0,0 +1,46 @@ +import os.path +import torchvision.transforms as transforms +from data.base_dataset import BaseDataset, get_transform +from data.image_folder import make_dataset +from PIL import Image + + +class TestDataset(BaseDataset): + def initialize(self, opt): + self.opt = opt + self.root = opt.dataroot + self.dir_A = 
os.path.join(opt.dataroot, opt.phase)
        # self.dir_AB = os.path.join(opt.dataroot, opt.phase)
        self.dir_trimap = os.path.join(self.dir_A, 'trimap')
        self.dir_fg = os.path.join(self.dir_A, 'fg')

        # self.A_paths = make_dataset(self.dir_A)
        self.fg_paths = make_dataset(self.dir_fg)
        self.trimap_paths = make_dataset(self.dir_trimap)
        self.fg_paths = sorted(self.fg_paths)
        self.trimap_paths = sorted(self.trimap_paths)
        self.transform = get_transform(opt)

    def __getitem__(self, index):
        fg_path = self.fg_paths[index]
        trimap_path = self.trimap_paths[index]
        A_fg = Image.open(fg_path).convert('RGB')
        A_trimap = Image.open(trimap_path).convert('L')
        # A_fg = self.transform(A_fg)
        # A_trimap = self.transform(A_trimap)

        A_fg = A_fg.resize((320, 320))
        A_trimap = A_trimap.resize((320, 320))
        A_fg = transforms.ToTensor()(A_fg)
        A_trimap = transforms.ToTensor()(A_trimap)

        input_nc = self.opt.input_nc

        return {'A_fg': A_fg, 'A_trimap': A_trimap, 'A_paths': fg_path}

    def __len__(self):
        return len(self.fg_paths)

    def name(self):
        return 'TestDataset'

diff --git a/modules/ximgproc/src/alphagan_matting/models/basemodel.py b/modules/ximgproc/src/alphagan_matting/models/basemodel.py
new file mode 100644
index 00000000000..9b55afea195
--- /dev/null
+++ b/modules/ximgproc/src/alphagan_matting/models/basemodel.py
@@ -0,0 +1,60 @@
import os
import torch


class BaseModel():
    def name(self):
        return 'BaseModel'

    def initialize(self, opt):
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)

    def set_input(self, input):
        self.input = input

    def forward(self):
        pass

    # used at test time, no backprop
    def test(self):
        pass

    def get_image_paths(self):
        pass

    def optimize_parameters(self):
        pass

    def get_current_visuals(self):
        return self.input

    def get_current_errors(self):
        return {}

    def save(self, label):
        pass

    # helper saving function that can be used by subclasses
    def save_network(self, network, network_label, epoch_label, gpu_ids):
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        torch.save(network.cpu().state_dict(), save_path)
        if len(gpu_ids) and torch.cuda.is_available():
            network.cuda(gpu_ids[0])

    # helper loading function that can be used by subclasses
    def load_network(self, network, network_label, epoch_label):
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        network.load_state_dict(torch.load(save_path))

    # update learning rate (called once every epoch)
    def update_learning_rate(self):
        for scheduler in self.schedulers:
            scheduler.step()
        lr = self.optimizers[0].param_groups[0]['lr']
        print('learning rate = %.7f' % lr)

diff --git a/modules/ximgproc/src/alphagan_matting/models/models.py b/modules/ximgproc/src/alphagan_matting/models/models.py
new file mode 100644
index 00000000000..b97c44fa163
--- /dev/null
+++ b/modules/ximgproc/src/alphagan_matting/models/models.py
@@ -0,0 +1,19 @@
import torch.nn as nn

def create_model(opt):
    model = None
    print(opt.model)
    if opt.model == 'simple':
        assert(opt.dataset_mode == 'generated_simple')
        from .simple_gan import SimpleModel
        model = SimpleModel()
    elif opt.model == 'test':
        # we only need the foreground and the trimap, hence a slightly different data loader
assert(opt.dataset_mode == 'testData') + from .test_model import TestModel + model = TestModel() + else: + raise ValueError("Model [%s] not recognized." % opt.model) + model.initialize(opt) + print("model [%s] was created" % (model.name())) + return model diff --git a/modules/ximgproc/src/alphagan_matting/models/networks.py b/modules/ximgproc/src/alphagan_matting/models/networks.py new file mode 100644 index 00000000000..f733a3083b9 --- /dev/null +++ b/modules/ximgproc/src/alphagan_matting/models/networks.py @@ -0,0 +1,636 @@ +import torch +import torch.nn as nn +from torch.nn import init +import torch.nn.functional as F +import functools +from torch.autograd import Variable +from torch.optim import lr_scheduler +import torchvision.models as models +import torch.utils.model_zoo as model_zoo +import numpy as np + +############################################################################### +# Functions +############################################################################### + + +def weights_init_normal(m): + classname = m.__class__.__name__ + # print(classname) + if classname.find('Conv') != -1: + init.normal(m.weight.data, 0.0, 0.02) + elif classname.find('Linear') != -1: + init.normal(m.weight.data, 0.0, 0.02) + elif classname.find('BatchNorm2d') != -1: + init.normal(m.weight.data, 1.0, 0.02) + init.constant(m.bias.data, 0.0) + + +def weights_init_xavier(m): + classname = m.__class__.__name__ + # print(classname) + if classname.find('Conv') != -1: + init.xavier_normal(m.weight.data, gain=0.02) + elif classname.find('Linear') != -1: + init.xavier_normal(m.weight.data, gain=0.02) + elif classname.find('BatchNorm2d') != -1: + init.normal(m.weight.data, 1.0, 0.02) + init.constant(m.bias.data, 0.0) + + +def weights_init_kaiming(m): + classname = m.__class__.__name__ + # print(classname) + if classname.find('Conv') != -1: + init.kaiming_normal(m.weight.data, a=0, mode='fan_in') + elif classname.find('Linear') != -1: + init.kaiming_normal(m.weight.data, a=0, mode='fan_in') + elif classname.find('BatchNorm2d') != -1: + init.normal(m.weight.data, 1.0, 0.02) + init.constant(m.bias.data, 0.0) + + +def weights_init_orthogonal(m): + classname = m.__class__.__name__ + print(classname) + if classname.find('Conv') != -1: + init.orthogonal(m.weight.data, gain=1) + elif classname.find('Linear') != -1: + init.orthogonal(m.weight.data, gain=1) + elif classname.find('BatchNorm2d') != -1: + init.normal(m.weight.data, 1.0, 0.02) + init.constant(m.bias.data, 0.0) + + +def init_weights(net, init_type='normal'): + """ + Initialize network weights. + Parameters: + net (network) -- network to be initialized + init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal + """ + + print('initialization method [%s]' % init_type) + if init_type == 'normal': + net.apply(weights_init_normal) + elif init_type == 'xavier': + net.apply(weights_init_xavier) + elif init_type == 'kaiming': + net.apply(weights_init_kaiming) + elif init_type == 'orthogonal': + net.apply(weights_init_orthogonal) + else: + raise NotImplementedError('initialization method [%s] is not implemented' % init_type) + + +def get_norm_layer(norm_type='instance'): + """ + Return a normalization layer + Parameters: + norm_type (str) -- the name of the normalization layer: batch | instance | none + For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev). + For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics. 
+ """ + if norm_type == 'batch': + norm_layer = functools.partial(nn.BatchNorm2d, affine=True) + elif norm_type == 'instance': + norm_layer = functools.partial(nn.InstanceNorm2d, affine=False) + elif norm_type == 'none': + norm_layer = None + else: + raise NotImplementedError('normalization layer [%s] is not found' % norm_type) + return norm_layer + + +def get_scheduler(optimizer, opt): + """ + Return a learning rate scheduler + Parameters: + optimizer -- the optimizer of the network + opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.  + opt.lr_policy is the name of learning rate policy: lambda | step | plateau + """ + if opt.lr_policy == 'lambda': + def lambda_rule(epoch): + lr_l = 1.0 - max(0, epoch + 1 + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1) + return lr_l + scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) + elif opt.lr_policy == 'step': + scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1) + elif opt.lr_policy == 'plateau': + scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) + else: + return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy) + return scheduler + + +def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropout=False, init_type='normal', gpu_ids=[], pretrain=True): + """ + Create a generator + Parameters: + input_nc (int) -- the number of channels in input images + output_nc (int) -- the number of channels in output images + ngf (int) -- the number of filters in the last conv layer + which_model_netG (str) -- the architecture's name: resnet50 | resnet50ASPP + norm (str) -- the name of normalization layers used in the network: batch | instance | none + use_dropout (bool) -- if use dropout layers. + init_type (str) -- the name of our initialization method. + gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 + Returns a generator + """ + netG = None + use_gpu = len(gpu_ids) > 0 + norm_layer = get_norm_layer(norm_type=norm) + + if use_gpu: + assert(torch.cuda.is_available()) + + + if which_model_netG == 'resnet50': + netG = ResnetX(id=50, gpu_ids=gpu_ids, pretrain=pretrain) + elif which_model_netG == 'resnet50ASPP': + netG = ResnetASPP(id=50,gpu_ids=gpu_ids,pretrain=pretrain) + else: + raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG) + if len(gpu_ids) > 0: + netG.cuda(gpu_ids[0]) + + print('Using pretrained weights') + + return netG + + +def define_D(input_nc, ndf, which_model_netD, + n_layers_D=3, norm='batch', use_sigmoid=False, init_type='normal', gpu_ids=[]): + """ + Create a discriminator + Parameters: + input_nc (int) -- the number of channels in input images + ndf (int) -- the number of filters in the first conv layer + which_model_netD (str) -- the architecture's name: basic | n_layers | pixel + n_layers_D (int) -- the number of conv layers in the discriminator; effective when netD=='n_layers' + norm (str) -- the type of normalization layers used in the network. 
        init_type (str)    -- the name of the initialization method
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
    Returns a discriminator
    """
    netD = None
    use_gpu = len(gpu_ids) > 0
    norm_layer = get_norm_layer(norm_type=norm)

    if use_gpu:
        assert(torch.cuda.is_available())
    if which_model_netD == 'basic':
        netD = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)
    elif which_model_netD == 'n_layers':
        netD = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)
    elif which_model_netD == 'pixel':
        netD = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)
    else:
        raise NotImplementedError('Discriminator model name [%s] is not recognized' %
                                  which_model_netD)
    if use_gpu:
        netD.cuda(gpu_ids[0])
    init_weights(netD, init_type=init_type)
    return netD

def print_network(net):
    num_params = 0
    for param in net.parameters():
        num_params += param.numel()
    print(net)
    print('Total number of parameters: %d' % num_params)


##############################################################################
# Classes
##############################################################################


# Defines the GAN loss which uses either LSGAN or the regular GAN.
# When LSGAN is used, it is basically the same as MSELoss,
# but it abstracts away the need to create the target label tensor
# that has the same size as the input.
class GANLoss(nn.Module):
    """
    L_GAN(G, D) = log D(x) + log(1 - D(C(G(x))))
    where x is a real input: an image composited from the ground-truth alpha and foreground, appended with the trimap.
    C(.) is a composition function that takes the alpha predicted by G and uses it to composite a fake image.
    G tries to generate alphas that are close to the ground-truth alpha, while D tries to
    distinguish real from fake composited images.
    G therefore tries to minimize L_GAN against the discriminator D, which tries to maximize it.
    """
    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
                 tensor=torch.FloatTensor):
        super(GANLoss, self).__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        self.real_label_var = None
        self.fake_label_var = None
        self.Tensor = tensor
        if use_lsgan:
            self.loss = nn.MSELoss()
        else:
            self.loss = nn.BCELoss()

    def get_target_tensor(self, input, target_is_real):
        target_tensor = None
        if target_is_real:
            create_label = ((self.real_label_var is None) or
                            (self.real_label_var.numel() != input.numel()))
            if create_label:
                real_tensor = self.Tensor(input.size()).fill_(self.real_label)
                self.real_label_var = Variable(real_tensor, requires_grad=False)
            target_tensor = self.real_label_var
        else:
            create_label = ((self.fake_label_var is None) or
                            (self.fake_label_var.numel() != input.numel()))
            if create_label:
                fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)
                self.fake_label_var = Variable(fake_tensor, requires_grad=False)
            target_tensor = self.fake_label_var
        return target_tensor

    def __call__(self, input, target_is_real):
        target_tensor = self.get_target_tensor(input, target_is_real)
        return self.loss(input, target_tensor)


class AlphaPredictionLoss(nn.Module):
    """
    The absolute difference between the ground-truth alpha values and the predicted alpha values at each pixel.
+ """ + def __init__(self): + super(AlphaPredictionLoss, self).__init__() + + def forward(self, input, target, trimap): + trimap_weights = torch.where(torch.eq(torch.ge(trimap, 0.4), torch.le(trimap, 0.6)), torch.ones_like(trimap), torch.zeros_like(trimap)) + unknown_region_size = trimap_weights.sum() + diff = torch.sqrt(torch.add(torch.pow(input - target, 2), 1e-12)) + return torch.mul(diff, trimap_weights).sum() / unknown_region_size + + +class CompLoss(nn.Module): + + """ + Compostion Loss : Absolute difference between the ground truth RGB colors and the predicted RGB colors composited + by the groundtruth foreground, the ground truth background and the predicted alpha mattes + """ + + def __init__(self): + super(CompLoss, self).__init__() + + def forward(self, input, target, trimap, fg, bg): + trimap_weights = torch.where(torch.eq(torch.ge(trimap, 0.4), torch.le(trimap, 0.6)), torch.ones_like(trimap), torch.zeros_like(trimap)) + unknown_region_size = trimap_weights.sum() + + + comp_target = torch.mul(target, fg) + torch.mul((1.0 - target), bg) + comp_input = torch.mul(input, fg) + torch.mul((1.0 - input), bg) + + diff = torch.sqrt(torch.add(torch.pow(comp_input - comp_target, 2), 1e-12)) + return torch.mul(diff, trimap_weights).sum() / unknown_region_size + + +class ResnetX(nn.Module): + def __init__(self, id=50, gpu_ids=[], pretrain=True): + super(ResnetX, self).__init__() + self.encoder = ResnetEncoder(id, gpu_ids, pretrain) + self.decoder = UNetDecoder(gpu_ids) + + def forward(self, input): + x, ind = self.encoder(input) + x = self.decoder(x, ind) + + return x + + +class ResnetEncoder(nn.Module): + """ + Encoder has the same structure as that of Resnet50,but the last 2 layers have been removed. + The shape of first channel has been changed, Resnet had 3 channels, but for this task we need 4 channels as we + are also adding the trimap + """ + + def __init__(self, id=50, pretrain=True, gpu_ids=[]): + super(ResnetEncoder, self).__init__() + print('Pretrain: {}'.format(pretrain)) + if id==50: + resnet = models.resnet50(pretrained=pretrain) + + modules = list(resnet.children())[:-2] # delete the last 2 layers. + for m in modules: + if 'MaxPool' in m.__class__.__name__: + m.return_indices = True + + + conv1 = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3, bias=False) + weights = torch.zeros(64, 4, 7, 7) + weights[:,:3,:,:] = modules[0].weight.data.view(64, 3, 7, 7) + conv1.weight.data.copy_(weights) + modules[0] = conv1 + + self.pool1 = nn.Sequential(*modules[: 4]) + self.resnet = nn.Sequential(*modules[4:]) + + + def forward(self, input): + x, ind = self.pool1(input) + + x = self.resnet(x) + + return x, ind + +class UNetDecoder(nn.Module): + + """ + The decoder network of the generator is same as that of the UNetDecoder. 
It has seven upsampling convolutional blocks. Each upsampling convolutional block has an
    upsampling layer followed by a convolutional layer, a batch normalization layer and a ReLU activation function.
    """
    def __init__(self, gpu_ids=[]):
        super(UNetDecoder, self).__init__()
        model = [nn.Conv2d(2048, 2048, kernel_size=1, padding=0),
                 nn.BatchNorm2d(2048),
                 nn.ReLU(True),
                 nn.ConvTranspose2d(2048, 1024, kernel_size=1, stride=2, output_padding=1, bias=False),
                 nn.BatchNorm2d(1024),
                 nn.ReLU(True)]
        model += [nn.Conv2d(1024, 1024, kernel_size=5, padding=2),
                  nn.BatchNorm2d(1024),
                  nn.ReLU(True),
                  nn.ConvTranspose2d(1024, 512, kernel_size=1, stride=2, output_padding=1, bias=False),
                  nn.BatchNorm2d(512),
                  nn.ReLU(True)]
        model += [nn.Conv2d(512, 512, kernel_size=5, padding=2),
                  nn.BatchNorm2d(512),
                  nn.ReLU(True),
                  nn.ConvTranspose2d(512, 256, kernel_size=1, stride=2, output_padding=1, bias=False),
                  # nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False),
                  nn.BatchNorm2d(256),
                  nn.ReLU(True)]
        model += [nn.Conv2d(256, 256, kernel_size=5, padding=2),
                  nn.BatchNorm2d(256),
                  nn.ReLU(True),
                  nn.Conv2d(256, 64, kernel_size=1, stride=1, bias=False),
                  nn.BatchNorm2d(64),
                  nn.ReLU(True)]
        model += [nn.Conv2d(64, 64, kernel_size=5, padding=2),
                  nn.BatchNorm2d(64),
                  nn.ReLU(True)]
        self.model1 = nn.Sequential(*model)
        self.unpool = nn.MaxUnpool2d(kernel_size=2, stride=2)
        model = [nn.Conv2d(64, 64, kernel_size=5, padding=2),
                 nn.BatchNorm2d(64),
                 nn.ReLU(True),
                 nn.ConvTranspose2d(64, 64, kernel_size=7, stride=2, padding=3, output_padding=1, bias=False),
                 nn.BatchNorm2d(64),
                 nn.ReLU(True)]
        model += [nn.Conv2d(64, 64, kernel_size=5, padding=2),
                  nn.BatchNorm2d(64),
                  nn.ReLU(True),
                  nn.Conv2d(64, 1, kernel_size=5, padding=2),
                  nn.Sigmoid()]
        self.model2 = nn.Sequential(*model)

        init_weights(self.model1, 'xavier')
        init_weights(self.model2, 'xavier')

    def forward(self, input, ind):
        x = self.model1(input)
        x = self.unpool(x, ind)
        x = self.model2(x)

        return x


class ASPP_Module(nn.Module):
    def __init__(self, input_maps, dilation_series, padding_series, output_maps):
        super(ASPP_Module, self).__init__()
        self.branches = nn.ModuleList()
        self.branches.append(nn.Sequential(nn.Conv2d(input_maps, output_maps, kernel_size=1, stride=1, bias=False),
                                           nn.BatchNorm2d(output_maps, affine=True)))

        for dilation, padding in zip(dilation_series, padding_series):
            self.branches.append(nn.Sequential(nn.Conv2d(input_maps, output_maps, kernel_size=3, stride=1, padding=padding, dilation=dilation, bias=True),
                                               nn.BatchNorm2d(output_maps, affine=True)))

        for m in self.branches:
            m[0].weight.data.normal_(0, 0.01)

        image_level_features = [nn.AdaptiveAvgPool2d(1),
                                nn.Conv2d(input_maps, output_maps, kernel_size=1, stride=1, bias=False),
                                nn.BatchNorm2d(output_maps, affine=True)]
        self.image_level_features = nn.Sequential(*image_level_features)
        self.conv1x1 = nn.Conv2d(output_maps*(len(dilation_series)+2), output_maps, kernel_size=1, stride=1, bias=False)
        self.bn1x1 = nn.BatchNorm2d(output_maps, affine=True)

    def forward(self, x):
        out = self.branches[0](x)
        for i in range(len(self.branches)-1):
            out = torch.cat([out, self.branches[i+1](x)], 1)

        image_features = nn.functional.interpolate(self.image_level_features(x), size=(out.shape[2], out.shape[3]), mode='bilinear', align_corners=False)
        out = torch.cat([out, image_features], 1)
        out = self.conv1x1(out)
        out = self.bn1x1(out)

        return out

class ResnetASPP(nn.Module):
    def __init__(self, id=50, gpu_ids=[], pretrain=True):
        super(ResnetASPP, self).__init__()
        self.encoder = ResnetASPPEncoder(id, gpu_ids, pretrain)
        self.decoder = UNetASPPDecoder(gpu_ids)

    def forward(self, input):
        x, ind = self.encoder(input)
        x = self.decoder(x, ind)

        return x


class ResnetASPPEncoder(nn.Module):
    """
    The encoder has the same structure as Resnet50, but the last 2 layers have been removed.
    The shape of the first layer has been changed: Resnet takes 3 input channels, but for this task we need 4 channels, as we
    are also adding the trimap. In this encoder, the ASPP module is appended at the end.
    """

    def __init__(self, id=50, pretrain=True, gpu_ids=[]):
        super(ResnetASPPEncoder, self).__init__()
        print('Pretrain: {}'.format(pretrain))
        if id == 50:
            resnet = models.resnet50(pretrained=pretrain)

        modules = list(resnet.children())[:-2]  # delete the last 2 layers
        for m in modules:
            if 'MaxPool' in m.__class__.__name__:
                m.return_indices = True

        conv1 = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3, bias=False)
        weights = torch.zeros(64, 4, 7, 7)
        weights[:, :3, :, :] = modules[0].weight.data.view(64, 3, 7, 7)
        conv1.weight.data.copy_(weights)
        modules[0] = conv1

        self.pool1 = nn.Sequential(*modules[:4])
        self.resnet = nn.Sequential(*modules[4:])
        self.ASPP_layer = ASPP_Module(2048, [6, 12, 18], [6, 12, 18], 1024)

    def forward(self, input):
        x, ind = self.pool1(input)

        x = self.resnet(x)
        x = self.ASPP_layer(x)

        return x, ind

class UNetASPPDecoder(nn.Module):

    """
    The decoder network of the generator is the same as the UNetDecoder. It
    has seven upsampling convolutional blocks. Each upsampling convolutional block has an
    upsampling layer followed by a convolutional layer, a batch normalization layer and a ReLU activation function. The only
    difference in this decoder is that its input has 1024 channels (the output of the ASPP module) instead of 2048.
    """
    def __init__(self, gpu_ids=[]):
        super(UNetASPPDecoder, self).__init__()
        model = [nn.Conv2d(1024, 1024, kernel_size=3, padding=1),
                 nn.BatchNorm2d(1024),
                 nn.ReLU(True),
                 nn.ConvTranspose2d(1024, 1024, kernel_size=1, stride=2, output_padding=1, bias=False),
                 nn.BatchNorm2d(1024),
                 nn.ReLU(True)]
        model += [nn.Conv2d(1024, 1024, kernel_size=3, padding=1),
                  nn.BatchNorm2d(1024),
                  nn.ReLU(True),
                  nn.ConvTranspose2d(1024, 512, kernel_size=1, stride=2, output_padding=1, bias=False),
                  nn.BatchNorm2d(512),
                  nn.ReLU(True)]
        model += [nn.Conv2d(512, 512, kernel_size=5, padding=2),
                  nn.BatchNorm2d(512),
                  nn.ReLU(True),
                  nn.ConvTranspose2d(512, 256, kernel_size=1, stride=2, output_padding=1, bias=False),
                  # nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False),
                  nn.BatchNorm2d(256),
                  nn.ReLU(True)]
        model += [nn.Conv2d(256, 256, kernel_size=5, padding=2),
                  nn.BatchNorm2d(256),
                  nn.ReLU(True),
                  nn.Conv2d(256, 64, kernel_size=1, stride=1, bias=False),
                  nn.BatchNorm2d(64),
                  nn.ReLU(True)]
        model += [nn.Conv2d(64, 64, kernel_size=5, padding=2),
                  nn.BatchNorm2d(64),
                  nn.ReLU(True)]
        self.model1 = nn.Sequential(*model)
        self.unpool = nn.MaxUnpool2d(kernel_size=2, stride=2)
        model = [nn.Conv2d(64, 64, kernel_size=5, padding=2),
                 nn.BatchNorm2d(64),
                 nn.ReLU(True),
                 nn.ConvTranspose2d(64, 64, kernel_size=7, stride=2, padding=3, output_padding=1, bias=False),
                 nn.BatchNorm2d(64),
                 nn.ReLU(True)]
        model += [nn.Conv2d(64, 64, kernel_size=5, padding=2),
                  nn.BatchNorm2d(64),
                  nn.ReLU(True),
                  nn.Conv2d(64, 1, kernel_size=5, padding=2),
                  nn.Sigmoid()]
        self.model2 = nn.Sequential(*model)

        init_weights(self.model1, 'xavier')
        init_weights(self.model2, 'xavier')

    def forward(self, input, ind):
        x = self.model1(input)
        x = self.unpool(x, ind)
        x = self.model2(x)

        return x


# Defines the PatchGAN discriminator with the specified arguments.
class NLayerDiscriminator(nn.Module):
    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, gpu_ids=[]):
        super(NLayerDiscriminator, self).__init__()
        self.gpu_ids = gpu_ids
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        kw = 4
        padw = 1
        sequence = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2, True)
        ]

        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2**n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                          kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2**n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                      kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]

        if use_sigmoid:
            sequence += [nn.Sigmoid()]

        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        if len(self.gpu_ids) and isinstance(input.data, torch.cuda.FloatTensor):
            return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
        else:
            return self.model(input)

class PixelDiscriminator(nn.Module):
    def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d, use_sigmoid=False, gpu_ids=[]):
        super(PixelDiscriminator, self).__init__()
        self.gpu_ids = gpu_ids
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        self.net = [
            nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
            norm_layer(ndf * 2),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)]

        if use_sigmoid:
            self.net.append(nn.Sigmoid())

        self.net = nn.Sequential(*self.net)

    def forward(self, input):
        if len(self.gpu_ids) and isinstance(input.data, torch.cuda.FloatTensor):
            return nn.parallel.data_parallel(self.net, input, self.gpu_ids)
        else:
            return self.net(input)

diff --git a/modules/ximgproc/src/alphagan_matting/models/simple_gan.py b/modules/ximgproc/src/alphagan_matting/models/simple_gan.py
new file mode 100644
index 00000000000..4472557ee68
--- /dev/null
+++ b/modules/ximgproc/src/alphagan_matting/models/simple_gan.py
@@ -0,0 +1,220 @@
import numpy as np
import torch
import os
from collections import OrderedDict
from torch.autograd import Variable
import itertools
import util.util as util
from util.image_pool import ImagePool
from .basemodel import BaseModel
from . import networks
import sys
import torch.nn as nn

class SimpleModel(BaseModel):
    def name(self):
        return 'SimpleModel'

    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        self.isTrain = opt.isTrain

        nb = opt.batchSize
        size = opt.fineSize

        # number of input channels: 4 (image + trimap); number of output channels: 1
        self.netG = networks.define_G(4, 1,
                                      opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids, pretrain=not opt.no_pretrain)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD = networks.define_D(4, opt.ndf,
                                          opt.which_model_netD,
                                          opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)

        if not self.isTrain or opt.continue_train:
            # load a saved network
            which_epoch = opt.which_epoch
            self.load_network(self.netG, 'G', which_epoch)
            if self.isTrain:
                self.load_network(self.netD, 'D', opt.which_epoch)

        if self.isTrain:
            # we use an image pool so that the discriminator does not forget what it got right/wrong before
            self.fake_AB_pool = ImagePool(opt.pool_size)
            self.old_lr = opt.lr
            # define loss functions
            self.criterionAlpha = networks.AlphaPredictionLoss()
            self.criterionComp = networks.CompLoss()
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            # initialize optimizers: Adam is used for both the discriminator and the generator
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers = []
            self.schedulers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, opt))

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG)
        if self.isTrain:
            networks.print_network(self.netD)
        print('-----------------------------------------------')

    def set_input(self, input):
        AtoB = self.opt.which_direction == 'AtoB'
        A_bg = input['A_bg']
        A_fg = input['A_fg']
        A_alpha = input['A_alpha']
        A_trimap = input['A_trimap']
        if len(self.gpu_ids) > 0:
            A_bg = A_bg.cuda(self.gpu_ids[0], non_blocking=True)
            A_fg = A_fg.cuda(self.gpu_ids[0], non_blocking=True)
            A_alpha = A_alpha.cuda(self.gpu_ids[0], non_blocking=True)
            A_trimap = A_trimap.cuda(self.gpu_ids[0], non_blocking=True)
        self.bg_A = A_bg
        self.fg_A = A_fg
        self.alpha_A = A_alpha
        self.trimap_A = A_trimap
        # the image is composited from the foreground and the background, using the alpha matte
        self.img_A = self.composite(self.alpha_A, self.fg_A, self.bg_A)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def set_input_predict(self, input):
        A_img = input['A_img']
        A_trimap = input['A_trimap']
        if len(self.gpu_ids) > 0:
            A_img = A_img.cuda(self.gpu_ids[0], non_blocking=True)
            A_trimap = A_trimap.cuda(self.gpu_ids[0], non_blocking=True)
        self.A_trimap = A_trimap
        self.A_img = A_img
        # our input is the image concatenated with the trimap; this is fed to the generator
        self.input_A = torch.cat((self.A_img, self.A_trimap), 1)
        self.image_paths = input['A_paths']

    def composite(self, alpha, fg, bg):
        img = torch.mul(alpha, fg) + torch.mul((1.0 - alpha), bg)
        return img

    def trimap_merge(self, alpha, trimap):
        # keep the regions already known from the trimap
        final_alpha = torch.where(torch.eq(torch.ge(trimap, 0.4), torch.le(trimap, 0.6)), alpha, trimap)
        return final_alpha

    def forward(self):
        self.A_input = Variable(torch.cat((self.img_A, self.trimap_A), 1))
        self.A_fg = Variable(self.fg_A)
        self.A_trimap = Variable(self.trimap_A)
        self.A_bg = Variable(self.bg_A)
        self.A_img = Variable(self.img_A)
        self.A_alpha = Variable(self.alpha_A)
        # self.A_disc = Variable(torch.cat((self.img_A, self.trimap_A, self.alpha_A), 1))

    def predict(self):
        self.netG.eval()
        with torch.no_grad():
            self.real_A = Variable(self.A_img)
            self.fake_B_alpha = self.netG(Variable(self.input_A))
            self.trimap_A = Variable(self.A_trimap)
            self.fake_B = self.trimap_merge(self.fake_B_alpha, self.trimap_A)

    # get image paths
    def get_image_paths(self):
        return self.image_paths

    def backward_D_basic(self, netD, real, fake):
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D

    def backward_D(self):
        fake_comp = self.fake_AB_pool.query(self.comp_disc)
        loss_D = self.backward_D_basic(self.netD, self.A_input, fake_comp)
        self.loss_D = loss_D.item()

    def backward_G(self):
        pred = self.netG(self.A_input)
        pred = self.trimap_merge(pred, self.A_trimap)
        comp = self.composite(pred, self.A_fg, self.A_bg)

        comp_disc = torch.cat((comp, self.A_trimap), 1)

        pred_fake = self.netD(comp_disc)
        loss_g = self.criterionGAN(pred_fake, True)

        loss_a = self.criterionAlpha(pred, self.A_alpha, self.A_trimap)

        loss_c = 0

        loss = loss_a + loss_c + loss_g

        loss.backward()

        self.pred = pred.data
        self.comp_disc = comp_disc.data

        self.loss_a = loss_a.item()
        # self.loss_c = loss_c.item()
        self.loss_c = 0
        self.loss_g = loss_g.item()

    def optimize_parameters(self):
        # forward
        self.forward()
        # G
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # D
        self.optimizer_D.zero_grad()
        self.backward_D()
        self.optimizer_D.step()

    def get_current_errors(self):
        ret_errors = OrderedDict([('alpha_loss', self.loss_a), ('comp_loss', self.loss_c), ('gan_loss', self.loss_g), ('D', self.loss_D)])
        return ret_errors

    def get_current_visuals(self):
        pred = util.tensor2im(self.pred)
        gt = util.tensor2im(self.A_alpha)
        img = util.tensor2im(self.A_img)
        trimap = util.tensor2im(self.A_trimap)
        bg = util.tensor2im(self.A_bg)
        fg = util.tensor2im(self.A_fg)

        ret_visuals = OrderedDict([('img', img), ('trimap', trimap), ('pred', pred), ('gt', gt), ('fg', fg), ('bg', bg)])
        return ret_visuals

    def get_current_visuals_predict(self):
        real_A = util.tensor2im(self.real_A.data)
        trimap_A = util.tensor2im(self.trimap_A.data)
        fake_B = util.tensor2im(self.fake_B.data)
        fake_B_alpha = util.tensor2im(self.fake_B_alpha)
        return OrderedDict([('real_A', real_A), ('trimap_A', trimap_A), ('fake_B_alpha', fake_B_alpha), ('fake_B', fake_B)])

    def save(self, label):
        # save the generator and discriminator weights separately
        self.save_network(self.netG, 'G', label, self.gpu_ids)
        self.save_network(self.netD, 'D', label, self.gpu_ids)

diff --git a/modules/ximgproc/src/alphagan_matting/models/test_model.py b/modules/ximgproc/src/alphagan_matting/models/test_model.py
new file mode 100644
index 00000000000..74b15de686c
--- /dev/null
+++ b/modules/ximgproc/src/alphagan_matting/models/test_model.py
@@ -0,0 +1,61 @@
from torch.autograd import Variable
from collections import OrderedDict
import util.util as util
from .basemodel import BaseModel
from . import networks
import torch


class TestModel(BaseModel):
    def name(self):
        return 'TestModel'

    def initialize(self, opt):
        assert(not opt.isTrain)
        BaseModel.initialize(self, opt)
        self.netG = networks.define_G(opt.input_nc, opt.output_nc,
                                      opt.ngf, opt.which_model_netG,
                                      opt.norm, not opt.no_dropout,
                                      opt.init_type,
                                      self.gpu_ids)
        which_epoch = opt.which_epoch
        self.load_network(self.netG, 'G', which_epoch)

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG)
        print('-----------------------------------------------')

    def set_input(self, input):
        # requires dataset_mode testData (foreground + trimap only)
        A_fg = input['A_fg']
        A_trimap = input['A_trimap']

        if len(self.gpu_ids) > 0:
            A_fg = A_fg.cuda(self.gpu_ids[0], non_blocking=True)
            A_trimap = A_trimap.cuda(self.gpu_ids[0], non_blocking=True)

        self.trimap_A = A_trimap
        self.fg_A = A_fg
        self.input_A = torch.cat((self.fg_A, self.trimap_A), 1)
        self.image_paths = input['A_paths']

    def test(self):
        self.real_A = Variable(self.fg_A)
        self.fake_B = self.trimap_merge(self.netG(self.input_A), self.trimap_A)

    # get image paths
    def get_image_paths(self):
        return self.image_paths

    def get_current_visuals(self):
        real_A = util.tensor2im(self.real_A.data)
        fake_B = util.tensor2im(self.fake_B.data)
        return OrderedDict([('real_A', real_A), ('fake_B', fake_B)])

    def trimap_merge(self, alpha, trimap):
        # keep the regions already known from the trimap
        final_alpha = torch.where(torch.eq(torch.ge(trimap, 0.4), torch.le(trimap, 0.6)), alpha, trimap)
        return final_alpha

diff --git a/modules/ximgproc/src/alphagan_matting/options/base_options.py b/modules/ximgproc/src/alphagan_matting/options/base_options.py
new file mode 100644
index 00000000000..193ff48ebfa
--- /dev/null
+++ b/modules/ximgproc/src/alphagan_matting/options/base_options.py
@@ -0,0 +1,75 @@
import argparse
import os
from util import util
import torch


class BaseOptions():
    def __init__(self):
        self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        self.initialized = False

    def initialize(self):
        self.parser.add_argument('--dataroot', required=True, help='path to images')
        self.parser.add_argument('--batchSize', type=int, default=1, help='input batch size')
        self.parser.add_argument('--loadSize', type=int, default=286, help='scale images to this size')
        self.parser.add_argument('--fineSize', type=int, default=256, help='then crop to this size')
        self.parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels')
        self.parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels')
        self.parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer')
        self.parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer')
        self.parser.add_argument('--which_model_netD', type=str, default='basic', help='selects model to use for netD')
self.parser.add_argument('--which_model_netG', type=str, default='resnet50', help='selects model to use for netG')
+ self.parser.add_argument('--n_layers_D', type=int, default=3, help='only used if which_model_netD==n_layers')
+ self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids, e.g. 0 or 0,1,2 or 0,2. use -1 for CPU')
+ self.parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')
+ self.parser.add_argument('--dataset_mode', type=str, default='generated_simple', help='chooses how datasets are loaded. [generated_simple | testData]')
+ self.parser.add_argument('--model', type=str, default='simple',
+ help='chooses which model to use. simple, test')
+
+ self.parser.add_argument('--nThreads', default=2, type=int, help='# threads for loading data')
+ self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
+ self.parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization')
+ self.parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
+ self.parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator')
+ self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
+ self.parser.add_argument('--resize_or_crop', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]')
+ self.parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
+ self.parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal|xavier|kaiming|orthogonal]')
+
+ self.initialized = True
+
+ def parse(self):
+ if not self.initialized:
+ self.initialize()
+ self.opt = self.parser.parse_args()
+ self.opt.isTrain = self.isTrain # train or test
+
+ str_ids = self.opt.gpu_ids.split(',')
+ self.opt.gpu_ids = []
+ for str_id in str_ids:
+ gpu_id = int(str_id)
+ if gpu_id >= 0:
+ self.opt.gpu_ids.append(gpu_id)
+
+ # set gpu ids
+ if len(self.opt.gpu_ids) > 0:
+ torch.cuda.set_device(self.opt.gpu_ids[0])
+
+ args = vars(self.opt)
+
+ print('------------ Options -------------')
+ for k, v in sorted(args.items()):
+ print('%s: %s' % (str(k), str(v)))
+ print('-------------- End ----------------')
+
+ # save to the disk
+ expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name)
+ util.mkdirs(expr_dir)
+ file_name = os.path.join(expr_dir, 'opt.txt')
+ with open(file_name, 'wt') as opt_file:
+ opt_file.write('------------ Options -------------\n')
+ for k, v in sorted(args.items()):
+ opt_file.write('%s: %s\n' % (str(k), str(v)))
+ opt_file.write('-------------- End ----------------\n')
+ return self.opt
diff --git a/modules/ximgproc/src/alphagan_matting/options/test_options.py b/modules/ximgproc/src/alphagan_matting/options/test_options.py new file mode 100644 index 00000000000..6b79860fd50 --- /dev/null +++ b/modules/ximgproc/src/alphagan_matting/options/test_options.py @@ -0,0 +1,13 @@ +from .base_options import BaseOptions
+
+
+class TestOptions(BaseOptions):
+ def initialize(self):
+ BaseOptions.initialize(self)
+ self.parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.')
+ self.parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
+ self.parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
+ self.parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
+ self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
+ self.parser.add_argument('--how_many', type=int, default=50, help='how many test images to run')
+ self.isTrain = False
diff --git a/modules/ximgproc/src/alphagan_matting/options/train_options.py b/modules/ximgproc/src/alphagan_matting/options/train_options.py new file mode 100644 index 00000000000..f0ff541bc6c --- /dev/null +++ b/modules/ximgproc/src/alphagan_matting/options/train_options.py @@ -0,0 +1,29 @@ +from .base_options import BaseOptions
+
+
+class TrainOptions(BaseOptions):
+ def initialize(self):
+ BaseOptions.initialize(self)
+ self.parser.add_argument('--display_freq', type=int, default=100, help='frequency of showing training results on screen')
+ self.parser.add_argument('--display_single_pane_ncols', type=int, default=0, help='if positive, display all images in a single visdom web panel with certain number of images per row.')
+ self.parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
+ self.parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
+ self.parser.add_argument('--save_epoch_freq', type=int, default=50, help='frequency of saving checkpoints at the end of epochs')
+ self.parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
+ self.parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
+ self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
+ self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
+ self.parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate')
+ self.parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero')
+ self.parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
+ self.parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
+ self.parser.add_argument('--no_lsgan', action='store_true', help='do *not* use least square GAN; if set, the vanilla GAN loss is used instead')
+ self.parser.add_argument('--lambda_A', type=float, default=10.0, help='weight for cycle loss (A -> B -> A)')
+ self.parser.add_argument('--lambda_B', type=float, default=10.0, help='weight for cycle loss (B -> A -> B)')
+ self.parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images')
+ self.parser.add_argument('--lr_policy', type=str, default='lambda', help='learning rate policy: lambda|step|plateau')
+ self.parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')
+ self.parser.add_argument('--identity', type=float, default=0.5, help='use identity mapping. Setting identity other than 1 has an effect of scaling the weight of the identity mapping loss. 
For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set opt.identity = 0.1')
+ self.parser.add_argument('--no_pretrain', action='store_true', help='do not initialize the generator encoder with pretrained ImageNet weights')
+
+ self.isTrain = True
diff --git a/modules/ximgproc/src/alphagan_matting/test.py b/modules/ximgproc/src/alphagan_matting/test.py new file mode 100644 index 00000000000..a6a7c25c2ef --- /dev/null +++ b/modules/ximgproc/src/alphagan_matting/test.py @@ -0,0 +1,36 @@ +import os
+import ntpath
+import numpy as np
+from PIL import Image
+from options.test_options import TestOptions
+from data.data_loader import CreateDataLoader
+from models.models import create_model
+from util import util
+
+opt = TestOptions().parse()
+opt.nThreads = 1 # test code only supports nThreads = 1
+opt.batchSize = 1 # test code only supports batchSize = 1
+opt.serial_batches = True # no shuffle
+opt.no_flip = True # no flip
+
+data_loader = CreateDataLoader(opt)
+dataset = data_loader.load_data()
+model = create_model(opt)
+
+# results are written to <results_dir>/<name>/<phase>_<which_epoch>
+image_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch))
+util.mkdirs(image_dir)
+
+for i, data in enumerate(dataset):
+ if i >= opt.how_many:
+ break
+ model.set_input(data)
+ model.test()
+ w = data['w']
+ h = data['h']
+ visuals = model.get_current_visuals()
+ img_path = model.get_image_paths()
+ print('%04d: process image... %s' % (i, img_path))
+ # name saved files after the input image and the visual's label
+ name = os.path.splitext(ntpath.basename(img_path[0]))[0]
+ aspect_ratio = opt.aspect_ratio
+ for label, im in visuals.items():
+ image_name = '%s_%s.png' % (name, label)
+ save_path = os.path.join(image_dir, image_name)
+ if aspect_ratio >= 1.0:
+ im = np.array(Image.fromarray(im).resize((h, int(w * aspect_ratio))))
+ if aspect_ratio < 1.0:
+ im = np.array(Image.fromarray(im).resize((int(h / aspect_ratio), w)))
+ util.save_image(im, save_path)
diff --git a/modules/ximgproc/src/alphagan_matting/train.py b/modules/ximgproc/src/alphagan_matting/train.py new file mode 100644 index 00000000000..835940f548c --- /dev/null +++ b/modules/ximgproc/src/alphagan_matting/train.py @@ -0,0 +1,62 @@ +import time
+from options.train_options import TrainOptions
+from data.data_loader import CreateDataLoader
+from models.models import create_model
+from util.visualizer import Visualizer
+
+opt = TrainOptions().parse()
+data_loader = CreateDataLoader(opt)
+dataset = data_loader.load_data()
+dataset_size = len(data_loader)
+print('#training images = %d' % dataset_size)
+
+model = create_model(opt)
+#visualizer = Visualizer(opt)
+total_steps = 0
+
+for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
+ epoch_start_time = time.time()
+ epoch_iter = 0
+
+ for i, data in enumerate(dataset):
+ iter_start_time = time.time()
+ #visualizer.reset()
+ total_steps += opt.batchSize
+ epoch_iter += opt.batchSize
+ model.set_input(data)
+ #if opt.model == 'alpha_gan_merge':
+ #model.optimize_parameters(epoch)
+ #else:
+ model.optimize_parameters()
+
+ #if total_steps % opt.display_freq == 0:
+ #save_result = total_steps % opt.update_html_freq == 0
+ #if opt.model == 'alpha_gan_merge':
+ #visualizer.display_current_results(model.get_current_visuals(epoch), epoch, save_result)
+ #else:
+ #visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)
+
+ if total_steps % opt.print_freq == 0:
+ #if opt.model == 'alpha_gan_merge':
+ #errors = model.get_current_errors(epoch)
+ #else:
+ errors = model.get_current_errors()
+ t = (time.time() - iter_start_time) / opt.batchSize
+ #visualizer.print_current_errors(epoch, epoch_iter, errors, t)
+ #if opt.display_id > 0:
+ #visualizer.plot_current_errors(epoch, float(epoch_iter)/dataset_size, opt, errors)
+
+ 
if total_steps % opt.save_latest_freq == 0:
+ print('saving the latest model (epoch %d, total_steps %d)' %
+ (epoch, total_steps))
+ model.save('latest')
+
+ if epoch % opt.save_epoch_freq == 0:
+ print('saving the model at the end of epoch %d, iters %d' %
+ (epoch, total_steps))
+ model.save('latest')
+ model.save(epoch)
+
+ print('End of epoch %d / %d \t Time Taken: %d sec' %
+ (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
+ model.update_learning_rate()
diff --git a/modules/ximgproc/src/alphagan_matting/util/image_pool.py b/modules/ximgproc/src/alphagan_matting/util/image_pool.py new file mode 100644 index 00000000000..ada16271ffc --- /dev/null +++ b/modules/ximgproc/src/alphagan_matting/util/image_pool.py @@ -0,0 +1,34 @@ +import random
+import torch
+from torch.autograd import Variable
+
+
+class ImagePool():
+ def __init__(self, pool_size):
+ self.pool_size = pool_size
+ if self.pool_size > 0:
+ self.num_imgs = 0
+ self.images = []
+
+ def query(self, images):
+ if self.pool_size == 0:
+ return Variable(images)
+ return_images = []
+ for image in images:
+ image = torch.unsqueeze(image, 0)
+ if self.num_imgs < self.pool_size:
+ self.num_imgs = self.num_imgs + 1
+ self.images.append(image)
+ return_images.append(image)
+ else:
+ # with probability 0.5, return an older image from the pool and store
+ # the current one in its place; otherwise return the current image
+ p = random.uniform(0, 1)
+ if p > 0.5:
+ random_id = random.randint(0, self.pool_size-1)
+ tmp = self.images[random_id].clone()
+ self.images[random_id] = image
+ return_images.append(tmp)
+ else:
+ return_images.append(image)
+ return_images = Variable(torch.cat(return_images, 0))
+ return return_images
diff --git a/modules/ximgproc/src/alphagan_matting/util/util.py b/modules/ximgproc/src/alphagan_matting/util/util.py new file mode 100644 index 00000000000..220b382b39f --- /dev/null +++ b/modules/ximgproc/src/alphagan_matting/util/util.py @@ -0,0 +1,61 @@ +from __future__ import print_function
+import torch
+import numpy as np
+from PIL import Image
+import os
+
+
+# Converts a Tensor into a Numpy array
+# |imtype|: the desired type of the converted numpy array
+def tensor2im(image_tensor, imtype=np.uint8):
+ image_numpy = image_tensor[0].cpu().float().numpy()
+ if image_numpy.shape[0] == 1:
+ image_numpy = np.tile(image_numpy, (3, 1, 1))
+ # image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
+ image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0
+ return image_numpy.astype(imtype)
+
+
+def diagnose_network(net, name='network'):
+ mean = 0.0
+ count = 0
+ for param in net.parameters():
+ if param.grad is not None:
+ mean += torch.mean(torch.abs(param.grad.data))
+ count += 1
+ if count > 0:
+ mean = mean / count
+ print(name)
+ print(mean)
+
+
+def save_image(image_numpy, image_path):
+ image_pil = Image.fromarray(image_numpy)
+ image_pil.save(image_path)
+
+
+def print_numpy(x, val=True, shp=False):
+ x = x.astype(np.float64)
+ if shp:
+ print('shape,', x.shape)
+ if val:
+ x = x.flatten()
+ print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
+ np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
+
+
+def mkdirs(paths):
+ if isinstance(paths, list) and not isinstance(paths, str):
+ for path in paths:
+ mkdir(path)
+ else:
+ mkdir(paths)
+
+
+def mkdir(path):
+ if not os.path.exists(path):
+ os.makedirs(path)
From 43e929653b0bad02fe4ea2e2888498862cafb58b Mon Sep 17 00:00:00 2001 From: Vedanta Jha Date: Sun, 1 Sep 2019 11:03:41 +0530 Subject: [PATCH 2/5] Update models.py --- 
modules/ximgproc/src/alphagan_matting/models/models.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/modules/ximgproc/src/alphagan_matting/models/models.py b/modules/ximgproc/src/alphagan_matting/models/models.py index b97c44fa163..0688e1d16bb 100644 --- a/modules/ximgproc/src/alphagan_matting/models/models.py +++ b/modules/ximgproc/src/alphagan_matting/models/models.py @@ -3,7 +3,8 @@ def create_model(opt): model = None print(opt.model) - + + if opt.model == 'simple': assert(opt.dataset_mode == 'generated_simple') from .simple_gan import SimpleModel model = SimpleModel() From 7648aa66fd8c2dff0b62fc613e20be99ce24a6f2 Mon Sep 17 00:00:00 2001 From: Vedanta Jha Date: Sun, 1 Sep 2019 11:04:17 +0530 Subject: [PATCH 3/5] Update custom_dataset_dataloader.py --- .../src/alphagan_matting/data/custom_dataset_dataloader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/ximgproc/src/alphagan_matting/data/custom_dataset_dataloader.py b/modules/ximgproc/src/alphagan_matting/data/custom_dataset_dataloader.py index 5d243c6c57f..90cece8c247 100644 --- a/modules/ximgproc/src/alphagan_matting/data/custom_dataset_dataloader.py +++ b/modules/ximgproc/src/alphagan_matting/data/custom_dataset_dataloader.py @@ -5,7 +5,7 @@ def CreateDataset(opt): dataset = None - elif opt.dataset_mode == 'testData': + if opt.dataset_mode == 'testData': from data.test_dataset import TestDataset dataset = TestDataset() elif opt.dataset_mode == 'generated_simple': From b571515d2e174bcb477ea800bb5730366a8923b0 Mon Sep 17 00:00:00 2001 From: Vedanta Jha Date: Sun, 1 Sep 2019 11:10:54 +0530 Subject: [PATCH 4/5] Update networks.py --- modules/ximgproc/src/alphagan_matting/models/networks.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/modules/ximgproc/src/alphagan_matting/models/networks.py b/modules/ximgproc/src/alphagan_matting/models/networks.py index f733a3083b9..a89ce40dd57 100644 --- a/modules/ximgproc/src/alphagan_matting/models/networks.py +++ b/modules/ximgproc/src/alphagan_matting/models/networks.py @@ -272,10 +272,10 @@ def forward(self, input, target, trimap): class CompLoss(nn.Module): - """ - Compostion Loss : Absolute difference between the ground truth RGB colors and the predicted RGB colors composited - by the groundtruth foreground, the ground truth background and the predicted alpha mattes - """ + """ + Composition Loss : Absolute difference between the ground truth RGB colors and the predicted RGB colors composited + by the ground truth foreground, the ground truth background and the predicted alpha mattes + """ def __init__(self): super(CompLoss, self).__init__() From 41d924ed4e3d9683a766ab5c7288936a0c99d233 Mon Sep 17 00:00:00 2001 From: Vedanta Jha Date: Sun, 1 Sep 2019 15:10:59 +0530 Subject: [PATCH 5/5] Removed whitespaces --- modules/ximgproc/src/alphagan_matting/models/models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/ximgproc/src/alphagan_matting/models/models.py b/modules/ximgproc/src/alphagan_matting/models/models.py index 0688e1d16bb..2a0ee799c40 100644 --- a/modules/ximgproc/src/alphagan_matting/models/models.py +++ b/modules/ximgproc/src/alphagan_matting/models/models.py @@ -3,8 +3,8 @@ def create_model(opt): model = None print(opt.model) - - if opt.model == 'simple': + + if opt.model == 'simple': assert(opt.dataset_mode == 'generated_simple') from .simple_gan import SimpleModel model = SimpleModel()
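
Note for reviewers: the trimap-merge and composition steps used by simple_gan.py and test_model.py can be checked in isolation. The following is a minimal, self-contained sketch; the toy tensors and helper names are illustrative only and are not part of the patch. It assumes the trimap convention used above: values near 0 are known background, values in [0.4, 0.6] are unknown, values near 1 are known foreground.

    import torch

    def trimap_merge(alpha, trimap):
        # keep the network prediction only where the trimap is unknown ([0.4, 0.6])
        unknown = (trimap >= 0.4) & (trimap <= 0.6)
        return torch.where(unknown, alpha, trimap)

    def composite(alpha, fg, bg):
        # matting equation: I = alpha * F + (1 - alpha) * B
        return alpha * fg + (1.0 - alpha) * bg

    alpha = torch.rand(1, 1, 4, 4)          # stand-in for the generator output
    trimap = torch.full((1, 1, 4, 4), 0.5)  # unknown everywhere ...
    trimap[..., 0] = 0.0                    # ... except a known background column
    trimap[..., -1] = 1.0                   # ... and a known foreground column
    fg = torch.rand(1, 3, 4, 4)
    bg = torch.rand(1, 3, 4, 4)

    merged = trimap_merge(alpha, trimap)    # known columns are copied from the trimap
    img = composite(merged, fg, bg)         # alpha broadcasts over the 3 RGB channels
    print(merged.shape, img.shape)          # (1, 1, 4, 4) and (1, 3, 4, 4)

In the known regions the merged alpha is exactly 0 or 1, so the composite reproduces the background or foreground pixel unchanged; the GAN and alpha losses can therefore only act on the unknown band, which is the behaviour backward_G relies on.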
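
As a usage note, the option parsers above imply invocations along the following lines. The dataset path and experiment name are placeholders, and the `--model test` branch of create_model is assumed to mirror the `simple` branch shown in models.py:

    # training: model 'simple' asserts dataset_mode == 'generated_simple'
    python train.py --dataroot ./datasets/matting --name alphagan_exp --model simple --dataset_mode generated_simple

    # inference with the latest checkpoint from ./checkpoints/alphagan_exp
    python test.py --dataroot ./datasets/matting --name alphagan_exp --model test --dataset_mode testData --which_epoch latest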