[GSOC 2019] AlphaGAN Matting #2198

Open · wants to merge 5 commits into 4.x
1 change: 1 addition & 0 deletions modules/ximgproc/README.md
@@ -16,3 +16,4 @@ Extended Image Processing
- Pei&Lin Normalization
- Ridge Detection Filter
- Binary morphology on run-length encoded images
- AlphaGAN matting
93 changes: 93 additions & 0 deletions modules/ximgproc/src/alphagan_matting/README.md
@@ -0,0 +1,93 @@
## Architecture ##

#### Generator ####
An encoder-decoder architecture is used.


There are two options for the generator encoder:

<b>a.</b> Resnet50 minus the last 2 layers
<b>b.</b> Resnet50 + ASPP module
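
For option b, a minimal sketch of an ASPP module in PyTorch; the dilation rates (6/12/18, as in DeepLab) and the single-convolution branches are illustrative assumptions, not values taken from this implementation:

```python
import torch
import torch.nn as nn

class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling: parallel dilated convolutions
    whose outputs are concatenated and fused by a 1x1 convolution."""
    def __init__(self, in_ch, out_ch, rates=(1, 6, 12, 18)):
        super(ASPP, self).__init__()
        self.branches = nn.ModuleList([
            nn.Conv2d(in_ch, out_ch,
                      kernel_size=1 if r == 1 else 3,
                      padding=0 if r == 1 else r,
                      dilation=r)
            for r in rates])
        self.project = nn.Conv2d(out_ch * len(rates), out_ch, kernel_size=1)

    def forward(self, x):
        # Each branch preserves spatial size; concatenate along channels.
        return self.project(torch.cat([b(x) for b in self.branches], dim=1))
```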

The decoder of the generator has seven upsampling convolutional blocks.
Each block consists of an upsampling layer, followed by a convolutional layer, a batch-normalization layer, and a ReLU activation.
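
A minimal sketch of one such block, assuming nearest-neighbour upsampling and 3x3 convolutions (both are assumptions; the PR may use different settings):

```python
import torch.nn as nn

class UpBlock(nn.Module):
    """One decoder block: upsample -> conv -> batch norm -> ReLU."""
    def __init__(self, in_ch, out_ch):
        super(UpBlock, self).__init__()
        self.block = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True))

    def forward(self, x):
        return self.block(x)
```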

#### Discriminator ####
The discriminator used here is a PatchGAN discriminator. The implementation is inspired by the CycleGAN implementation at<br/>
https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix

Again, two different types of discriminator are used:

<b>a.</b> N-layer PatchGAN discriminator, which classifies NxN patches; the patch size is taken as 3x3 here.
<b>b.</b> Pixel PatchGAN discriminator, which classifies every pixel.
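
As a sketch of option b, modeled on the PixelDiscriminator from the pix2pix/CycleGAN repository linked above (the layer widths below are that repository's defaults, not necessarily the ones used here):

```python
import torch.nn as nn

class PixelDiscriminator(nn.Module):
    """1x1-convolution PatchGAN: outputs a real/fake score per pixel."""
    def __init__(self, input_nc, ndf=64):
        super(PixelDiscriminator, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(input_nc, ndf, kernel_size=1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf, ndf * 2, kernel_size=1),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf * 2, 1, kernel_size=1))  # per-pixel score map

    def forward(self, x):
        return self.net(x)
```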

## How to use
Use the `--dataroot` argument to point to the directory where the data is stored.
Structure the data as follows:

- train
  - alpha
  - bg
  - fg
- test
  - fg
  - trimap

The background images used here come from the MSCOCO dataset.
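
For reference, the fg/bg/alpha triplets relate through the standard compositing equation C = alpha * F + (1 - alpha) * B. A minimal NumPy sketch, assuming fg and bg have already been brought to the same size and using placeholder file names (the exact compositing step inside this PR may differ):

```python
import numpy as np
from PIL import Image

# Placeholder paths following the directory layout above.
fg = np.asarray(Image.open('train/fg/0001.png').convert('RGB'), dtype=np.float32)
bg = np.asarray(Image.open('train/bg/0001.jpg').convert('RGB'), dtype=np.float32)
alpha = np.asarray(Image.open('train/alpha/0001.png').convert('L'), dtype=np.float32) / 255.0

# C = alpha * F + (1 - alpha) * B, applied per channel.
a = alpha[..., None]
composite = (a * fg + (1.0 - a) * bg).astype(np.uint8)
Image.fromarray(composite).save('composite.png')
```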


To train the model using Resnet50 without the ASPP module:

`python train.py --dataroot ./ --model simple --dataset_mode generated_simple --which_model_netG resnet50 --name resnet50`

To test the model using Resnet50 without the ASPP module:

`python test.py --dataroot ./ --dataset_mode single --which_model_netG resnet50 --ntest 8 --model test --name resnet50`

To train the model using Resnet50 with the ASPP module:

`python train.py --dataroot ./ --model simple --dataset_mode generated_simple --which_model_netG resnet50ASPP --name resnet50ASPP`

To test the model using Resnet50 with the ASPP module:

`python test.py --dataroot ./ --dataset_mode single --which_model_netG resnet50ASPP --ntest 8 --model test`

## Results

#### Input:
![input](https://raw.githubusercontent.com/Nerdyvedi/GSOC-Opencv-matting/master/2.png)
> **Review comment (Member):** Why are the documentation images stored outside the repository?

#### Trimap:
![Trimap](https://raw.githubusercontent.com/Nerdyvedi/GSOC-Opencv-matting/master/donkey_tri.png)


#### AlphaGAN matting:
##### Generator: Resnet50, Discriminator: N-layer PatchGAN
![Output2](https://raw.githubusercontent.com/Nerdyvedi/GSOC-Opencv-matting/master/donkey_resnet50.png)

##### Generator: Resnet50, Discriminator: Pixel PatchGAN
![Output3](https://raw.githubusercontent.com/Nerdyvedi/GSOC-Opencv-matting/master/donkey.png)

##### Generator: Resnet50 + ASPP module, Discriminator: N-layer PatchGAN
![Output4](https://raw.githubusercontent.com/Nerdyvedi/GSOC-Opencv-matting/master/donkey_deeplab.png)


### Comparing with the original implementation
(Average rank on alphamatting.com is shown; lower is better.)

| Error type | Original implementation | Resnet50 + N-layer | Resnet50 + Pixel | Resnet50 + ASPP module |
| ----------- | ------------------------ | ------------------- | ----------------- | -----------|
| Sum of absolute differences | 11.7 | 42.8 | 43.8 | 53 |
| Mean square error | 15 | 45.8 | 45.6 | 54.2 |
| Gradient error | 14 | 52.9 | 52.7 | 55 |
| Connectivity error | 29.6 | 23.3 | 22.6 | 32.8 |


### Training dataset used
I created the training dataset myself using GIMP.
[Link to created dataset](https://drive.google.com/open?id=1zQbk2Cu7QOBwzg4vVGqCWJwHGTwGppFe)

### What could be wrong?

1. The dataset I created is not perfect. I tried to make it as accurate as possible, in some cases marking individual pixels, but it is still imperfect.

2. The output of the generator is 320x320, and the alpha matte is then resized to the original image size. Some detail may be lost in this resizing, and a better upsampling method might improve the outputs (see the sketch after this list).

3. The main reason this implementation isn't as good as the original is that the authors haven't clearly described how the skip connections are used. I am confident about the architecture of the encoder and decoder; the only thing I am unsure about is how the skip connections are wired.
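
To illustrate point 2, a small sketch of upsampling a predicted matte with different PIL filters; the file name and target size below are placeholders:

```python
from PIL import Image

# Hypothetical generator output and original image size (placeholders).
matte = Image.open('pred_matte.png').convert('L')   # 320x320 prediction
orig_size = (1920, 1080)                            # (width, height)

# Bicubic is a common default; Lanczos tends to preserve edges better.
matte_bicubic = matte.resize(orig_size, Image.BICUBIC)
matte_lanczos = matte.resize(orig_size, Image.LANCZOS)
matte_lanczos.save('matte_fullres.png')
```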
9 changes: 9 additions & 0 deletions modules/ximgproc/src/alphagan_matting/data/base_data_loader.py
@@ -0,0 +1,9 @@
class BaseDataLoader():
    def __init__(self):
        pass

    def initialize(self, opt):
        self.opt = opt

    def load_data(self):
        # Subclasses override this to return an iterable over batches.
        return None
45 changes: 45 additions & 0 deletions modules/ximgproc/src/alphagan_matting/data/base_dataset.py
@@ -0,0 +1,45 @@
import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms


class BaseDataset(data.Dataset):
    def __init__(self):
        super(BaseDataset, self).__init__()

    def name(self):
        return 'BaseDataset'

    def initialize(self, opt):
        pass


def get_transform(opt):
    transform_list = []
    if opt.resize_or_crop == 'resize_and_crop':
        osize = [opt.loadSize, opt.loadSize]
        # transforms.Scale was renamed to Resize in torchvision 0.2.
        transform_list.append(transforms.Resize(osize, Image.BICUBIC))
        transform_list.append(transforms.RandomCrop(opt.fineSize))
    elif opt.resize_or_crop == 'crop':
        transform_list.append(transforms.RandomCrop(opt.fineSize))
    elif opt.resize_or_crop == 'scale_width':
        transform_list.append(transforms.Lambda(
            lambda img: __scale_width(img, opt.fineSize)))
    elif opt.resize_or_crop == 'scale_width_and_crop':
        transform_list.append(transforms.Lambda(
            lambda img: __scale_width(img, opt.loadSize)))
        transform_list.append(transforms.RandomCrop(opt.fineSize))

    if opt.isTrain and not opt.no_flip:
        transform_list.append(transforms.RandomHorizontalFlip())

    transform_list += [transforms.ToTensor(),
                       transforms.Normalize((0.5, 0.5, 0.5),
                                            (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)


def __scale_width(img, target_width):
    ow, oh = img.size
    if ow == target_width:
        return img
    w = target_width
    h = int(target_width * oh / ow)
    return img.resize((w, h), Image.BICUBIC)
45 changes: 45 additions & 0 deletions modules/ximgproc/src/alphagan_matting/data/custom_dataset_data_loader.py
@@ -0,0 +1,45 @@
import torch.utils.data
from data.base_data_loader import BaseDataLoader


def CreateDataset(opt):
    dataset = None

    if opt.dataset_mode == 'testData':
        from data.test_dataset import TestDataset
        dataset = TestDataset()
    elif opt.dataset_mode == 'generated_simple':
        from data.generated_dataset_simple import GeneratedDatasetSimple
        dataset = GeneratedDatasetSimple()
    else:
        raise ValueError("Dataset [%s] not recognized." % opt.dataset_mode)

    print("dataset [%s] was created" % (dataset.name()))
    dataset.initialize(opt)
    return dataset


class CustomDatasetDataLoader(BaseDataLoader):
    def name(self):
        return 'CustomDatasetDataLoader'

    def initialize(self, opt):
        BaseDataLoader.initialize(self, opt)
        self.dataset = CreateDataset(opt)
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batchSize,
            shuffle=not opt.serial_batches,
            num_workers=int(opt.nThreads))

    def load_data(self):
        return self

    def __len__(self):
        return min(len(self.dataset), self.opt.max_dataset_size)

    def __iter__(self):
        for i, data in enumerate(self.dataloader):
            if i >= self.opt.max_dataset_size:
                break
            yield data
6 changes: 6 additions & 0 deletions modules/ximgproc/src/alphagan_matting/data/data_loader.py
@@ -0,0 +1,6 @@
def CreateDataLoader(opt):
    from data.custom_dataset_data_loader import CustomDatasetDataLoader
    data_loader = CustomDatasetDataLoader()
    print(data_loader.name())
    data_loader.initialize(opt)
    return data_loader
189 changes: 189 additions & 0 deletions modules/ximgproc/src/alphagan_matting/data/generated_dataset_simple.py
@@ -0,0 +1,189 @@
import os.path
import torchvision.transforms as transforms
from data.base_dataset import BaseDataset, get_transform
from data.image_folder import make_dataset
from PIL import Image
import PIL
import random
import scipy.ndimage
import numpy as np
import math
# import pbcvt
# import colour_transfer


class GeneratedDatasetSimple(BaseDataset):
    def initialize(self, opt):
        self.opt = opt
        self.root = opt.dataroot
        self.dir_AB = os.path.join(opt.dataroot, opt.phase)
        self.dir_alpha = os.path.join(self.dir_AB, 'alpha')
        self.dir_fg = os.path.join(self.dir_AB, 'fg')
        self.dir_bg = os.path.join(self.dir_AB, 'bg')
        self.alpha_paths = sorted(make_dataset(self.dir_alpha))
        self.fg_paths = sorted(make_dataset(self.dir_fg))
        self.bg_paths = make_dataset(self.dir_bg)
        self.alpha_size = len(self.alpha_paths)
        self.bg_size = len(self.bg_paths)

    def __getitem__(self, index):
        index = index % self.alpha_size
        alpha_path = self.alpha_paths[index]
        fg_path = self.fg_paths[index]
        # Pair each fg/alpha with a randomly chosen background.
        index_bg = random.randint(0, self.bg_size - 1)
        bg_path = self.bg_paths[index_bg]

        A_bg = Image.open(bg_path).convert('RGB')
        A_fg = Image.open(fg_path).convert('RGB')

        A_alpha = Image.open(alpha_path).convert('L')
        assert A_alpha.mode == 'L'

        A_trimap = self.generate_trimap(A_alpha)

        # Ensure the background is larger than 320x320, then take a
        # random 320x320 crop.
        # A_bg = self.resize_bg(A_bg, A_fg)
        w_bg, h_bg = A_bg.size
        if w_bg < 321 or h_bg < 321:
            x = w_bg if w_bg < h_bg else h_bg
            ratio = 321 / float(x)
            A_bg = A_bg.resize((int(np.ceil(w_bg * ratio) + 1), int(np.ceil(h_bg * ratio) + 1)), Image.BICUBIC)
            w_bg, h_bg = A_bg.size
        assert w_bg > 320 and h_bg > 320, '{} {}'.format(w_bg, h_bg)
        x = random.randint(0, w_bg - 320 - 1)
        y = random.randint(0, h_bg - 320 - 1)
        A_bg = A_bg.crop((x, y, x + 320, y + 320))

        # Crop fg/alpha/trimap around a random pixel from the trimap's
        # unknown (128) region.
        crop_size = random.choice([320, 480, 640])
        # crop_size = random.choice([320,400,480,560,640,720])
        crop_center = self.find_crop_center(A_trimap)
        start_index_height = max(min(A_fg.size[1] - crop_size, crop_center[0] - int(crop_size / 2) + 1), 0)
        start_index_width = max(min(A_fg.size[0] - crop_size, crop_center[1] - int(crop_size / 2) + 1), 0)

        bbox = (start_index_width, start_index_height, start_index_width + crop_size, start_index_height + crop_size)

        # A_bg = A_bg.crop(bbox)
        A_fg = A_fg.crop(bbox)
        A_alpha = A_alpha.crop(bbox)
        A_trimap = A_trimap.crop(bbox)

        if self.opt.which_model_netG == 'unet_256':
            A_bg = A_bg.resize((256, 256))
            A_fg = A_fg.resize((256, 256))
            A_alpha = A_alpha.resize((256, 256))
            A_trimap = A_trimap.resize((256, 256))
            assert A_alpha.mode == 'L'
        else:
            A_bg = A_bg.resize((320, 320))
            A_fg = A_fg.resize((320, 320))
            A_alpha = A_alpha.resize((320, 320))
            A_trimap = A_trimap.resize((320, 320))
            assert A_alpha.mode == 'L'

        # Random horizontal flips; the background is flipped independently
        # of the foreground/alpha/trimap triple.
        if random.randint(0, 1):
            A_bg = A_bg.transpose(PIL.Image.FLIP_LEFT_RIGHT)

        if random.randint(0, 1):
            A_fg = A_fg.transpose(PIL.Image.FLIP_LEFT_RIGHT)
            A_alpha = A_alpha.transpose(PIL.Image.FLIP_LEFT_RIGHT)
            A_trimap = A_trimap.transpose(PIL.Image.FLIP_LEFT_RIGHT)

        ## COLOR TRANSFER ##
        # if random.randint(0, 2) != 0:
        #     A_old = A_fg
        #     target = np.array(A_fg)
        #     palette = np.array(A_palette)
        #     recolor = colour_transfer.runCT(target, palette)
        #     A_fg = Image.fromarray(recolor)

        if self.opt.which_direction == 'BtoA':
            input_nc = self.opt.output_nc
            output_nc = self.opt.input_nc
        else:
            input_nc = self.opt.input_nc
            output_nc = self.opt.output_nc

        A_bg = transforms.ToTensor()(A_bg)
        A_fg = transforms.ToTensor()(A_fg)
        A_alpha = transforms.ToTensor()(A_alpha)
        A_trimap = transforms.ToTensor()(A_trimap)

        return {'A_bg': A_bg,
                'A_fg': A_fg,
                'A_alpha': A_alpha,
                'A_trimap': A_trimap,
                'A_paths': alpha_path}

    def resize_bg(self, bg, fg):
        bbox = fg.size
        w = bbox[0]
        h = bbox[1]
        bg_bbox = bg.size
        bw = bg_bbox[0]
        bh = bg_bbox[1]
        wratio = w / float(bw)
        hratio = h / float(bh)
        ratio = wratio if wratio > hratio else hratio
        if ratio > 1:
            bg = bg.resize((int(np.ceil(bw * ratio) + 1), int(np.ceil(bh * ratio) + 1)), Image.BICUBIC)
        bg = bg.crop((0, 0, w, h))

        return bg

    # def generate_trimap(self, alpha):
    #     trimap = np.array(alpha)
    #     kernel_sizes = [val for val in range(5, 40)]
    #     kernel = random.choice(kernel_sizes)
    #     trimap[np.where((scipy.ndimage.grey_dilation(alpha, size=(kernel, kernel)) - alpha != 0))] = 128
    #
    #     return Image.fromarray(trimap)

    def generate_trimap(self, alpha):
        # Mark semi-transparent pixels (0 < alpha < 255) as unknown (128),
        # then dilate the unknown region with a randomly sized kernel.
        trimap = np.array(alpha)
        grey = np.zeros_like(trimap)
        kernel_sizes = [val for val in range(2, 20)]
        kernel = random.choice(kernel_sizes)
        # trimap[np.where((scipy.ndimage.grey_dilation(alpha,size=(kernel,kernel)) - alpha!=0))] = 128
        grey = np.where(np.logical_and(trimap > 0, trimap < 255), 128, 0)
        grey = scipy.ndimage.grey_dilation(grey, size=(kernel, kernel))
        trimap[grey == 128] = 128

        return Image.fromarray(trimap)

    def find_crop_center(self, trimap):
        # Return the (row, col) of a random pixel in the unknown (128) region.
        t = np.array(trimap)
        target = np.where(t == 128)
        index = random.choice([i for i in range(len(target[0]))])
        return np.array(target)[:, index][:2]

    def rotatedRectWithMaxArea(self, w, h, angle):
        """
        Given a rectangle of size wxh that has been rotated by 'angle' (in
        radians), computes the width and height of the largest possible
        axis-aligned rectangle (maximal area) within the rotated rectangle.
        """
        if w <= 0 or h <= 0:
            return 0, 0
        width_is_longer = w >= h
        side_long, side_short = (w, h) if width_is_longer else (h, w)

        # since the solutions for angle, -angle and 180-angle are all the same,
        # it suffices to look at the first quadrant and the absolute values of sin,cos:
        sin_a, cos_a = abs(math.sin(angle)), abs(math.cos(angle))
        if side_short <= 2. * sin_a * cos_a * side_long or abs(sin_a - cos_a) < 1e-10:
            # half constrained case: two crop corners touch the longer side,
            # the other two corners are on the mid-line parallel to the longer line
            x = 0.5 * side_short
            wr, hr = (x / sin_a, x / cos_a) if width_is_longer else (x / cos_a, x / sin_a)
        else:
            # fully constrained case: crop touches all 4 sides
            cos_2a = cos_a * cos_a - sin_a * sin_a
            wr, hr = (w * cos_a - h * sin_a) / cos_2a, (h * cos_a - w * sin_a) / cos_2a

        return wr, hr

    def __len__(self):
        return len(self.alpha_paths)

    def name(self):
        return 'GeneratedDataset'