[GSOC 2019]AlphaGAN Matting #2198
Open: Nerdyvedi wants to merge 5 commits into opencv:4.x from Nerdyvedi:alphaGAN
@@ -0,0 +1,93 @@
## Architecture ##

#### Generator ####
An encoder-decoder architecture is used.

There are two options for the generator encoder:

<b>a.</b> ResNet50 minus the last 2 layers
<b>b.</b> ResNet50 + ASPP (atrous spatial pyramid pooling) module

The decoder of the generator has seven upsampling convolutional blocks.
Each block consists of an upsampling layer followed by a convolutional layer, a batch normalization layer, and a ReLU activation.

#### Discriminator ####
The discriminator is a PatchGAN discriminator. The implementation is inspired by the CycleGAN implementation at<br/>
https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix

Again, two types of discriminator are used:

<b>a.</b> N-layer PatchGAN discriminator, where the size of the patch is NxN; it is taken as 3x3 here
<b>b.</b> Pixel PatchGAN discriminator, which classifies every pixel

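In the referenced pytorch-CycleGAN-and-pix2pix code, the N-layer discriminator is a stack of 4x4 convolutions, so the area of the input that one output score covers follows from the kernel sizes and strides. The sketch below works that out. The layer layout (`n_layers` stride-2 convolutions followed by two stride-1 convolutions) is an assumption taken from that repository and may not match the discriminator in this PR; `patchgan_layers` and `receptive_field` are illustrative names.

```python
def patchgan_layers(n_layers):
    """Return (kernel, stride) pairs for the assumed N-layer PatchGAN layout."""
    # Assumption: n_layers stride-2 convolutions, then two stride-1
    # convolutions, all with 4x4 kernels.
    return [(4, 2)] * n_layers + [(4, 1), (4, 1)]


def receptive_field(layers):
    """Side length of the input window seen by a single output value."""
    size = 1
    # Walk from the output back to the input, growing the window per layer.
    for kernel, stride in reversed(layers):
        size = (size - 1) * stride + kernel
    return size


# Receptive field for a few depths under the assumed layout.
fields = {n: receptive_field(patchgan_layers(n)) for n in (1, 2, 3)}
```

Under the same reasoning, a discriminator built only from 1x1 convolutions has a receptive field of one pixel, which is consistent with option b classifying every pixel.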
## How to use
Use the `--dataroot` argument to specify the directory where the data is stored.
Structure the data as follows:

train<br/>
-alpha -bg -fg

test<br/>
-fg -trimap

The background images used here come from the MSCOCO dataset.

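The layout above can be checked before starting a run. This is a minimal sketch using only the standard library; `missing_folders` and `EXPECTED` are hypothetical names that are not part of this PR, and the folder names are taken from the layout as written.

```python
from pathlib import Path

# Sub-folders named in the layout above, per split.
EXPECTED = {
    "train": ("alpha", "bg", "fg"),
    "test": ("fg", "trimap"),
}


def missing_folders(dataroot):
    """Return the expected sub-folders that do not exist under dataroot."""
    root = Path(dataroot)
    return [
        f"{split}/{name}"
        for split, names in EXPECTED.items()
        for name in names
        if not (root / split / name).is_dir()
    ]
```

Calling `missing_folders("./")` with the same value passed to `--dataroot` returns an empty list when every folder is in place.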
The commands below are written for a notebook cell; drop the leading `!` when running them in a regular shell.

To train the model using ResNet50 without the ASPP module:

`!python train.py --dataroot ./ --model simple --dataset_mode generated_simple --which_model_netG resnet50 --name resnet50`

To test the model using ResNet50 without the ASPP module:

`!python test.py --dataroot ./ --dataset_mode single --which_model_netG resnet50 --ntest 8 --model test --name resnet50`

To train the model using ResNet50 with the ASPP module:

`!python train.py --dataroot ./ --model simple --dataset_mode generated_simple --which_model_netG resnet50ASPP --name resnet50ASPP`

To test the model using ResNet50 with the ASPP module:

`!python test.py --dataroot ./ --dataset_mode single --which_model_netG resnet50ASPP --ntest 8 --model test`

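## Results

#### Input:
![input](https://raw.githubusercontent.com/Nerdyvedi/GSOC-Opencv-matting/master/input_resized.png)

#### Trimap:
![trimap](https://raw.githubusercontent.com/Nerdyvedi/GSOC-Opencv-matting/master/trimap_resized.png)

#### AlphaGAN matting:
##### Generator: ResNet50, Discriminator: N-layer PatchGAN
![resnet50_nlayer](https://raw.githubusercontent.com/Nerdyvedi/GSOC-Opencv-matting/master/d_nlayers_resized.png)

##### Generator: ResNet50, Discriminator: Pixel PatchGAN
![resnet50_pixel](https://raw.githubusercontent.com/Nerdyvedi/GSOC-Opencv-matting/master/d_pixel_resized.png)

##### Generator: ResNet50 + ASPP module, Discriminator: N-layer PatchGAN
![resnet50_aspp](https://raw.githubusercontent.com/Nerdyvedi/GSOC-Opencv-matting/master/aspp_resized.png)
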
### Comparing with the original implementation
(The average rank on alphamatting.com is shown.)

| Error type | Original implementation | ResNet50 + N-layer | ResNet50 + Pixel | ResNet50 + ASPP module |
| ----------- | ------------------------ | ------------------- | ----------------- | ----------- |
| Sum of absolute differences | 11.7 | 42.8 | 43.8 | 53 |
| Mean square error | 15 | 45.8 | 45.6 | 54.2 |
| Gradient error | 14 | 52.9 | 52.7 | 55 |
| Connectivity error | 29.6 | 23.3 | 22.6 | 32.8 |

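For reference, the first two error types in the table are simple per-pixel measures. The following is a minimal sketch, assuming alpha values in [0, 1] stored as flat lists of equal length. The alphamatting.com evaluation compares only the unknown trimap region and may scale the values differently, and the table above reports ranks rather than raw errors, so this does not reproduce those numbers.

```python
def sum_of_absolute_differences(pred, truth):
    """Sum of |predicted alpha - ground-truth alpha| over all pixels."""
    return sum(abs(p - t) for p, t in zip(pred, truth))


def mean_square_error(pred, truth):
    """Average squared difference between predicted and ground-truth alpha."""
    return sum((p - t) ** 2 for p, t in zip(pred, truth)) / len(pred)


# Four-pixel toy example: two pixels match, two do not.
predicted = [0.0, 0.5, 1.0, 1.0]
ground_truth = [0.0, 1.0, 1.0, 0.0]
sad = sum_of_absolute_differences(predicted, ground_truth)
mse = mean_square_error(predicted, ground_truth)
```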
### Training dataset used
I created the training dataset myself using GIMP.
[Link to created dataset](https://drive.google.com/open?id=1zQbk2Cu7QOBwzg4vVGqCWJwHGTwGppFe)

### What could be wrong?

1. The dataset I created is not perfect. I tried to make it as accurate as possible, in some cases marking every pixel by hand, but it is still not perfect.

2. The output of the generator is 320x320, and the alpha matte is then resized to the size of the original image. Some detail may be lost in this resizing, and a better upsampling method might improve the outputs.

3. The main reason the implementation is not as good as the original is that the author does not clearly describe how the skip connections are used. I am confident about the architecture of the encoder and decoder; the only part I am unsure about is how the skip connections are used.
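
Point 2 concerns how the 320x320 matte is brought back to the original resolution. The text does not state which interpolation the code uses, so the following only illustrates one option, bilinear interpolation, written without dependencies. `resize_bilinear` is a hypothetical helper, not code from this PR, and it maps the corner pixels of the input onto the corner pixels of the output.

```python
def resize_bilinear(matte, out_h, out_w):
    """Resize a 2-D list of alpha values with bilinear interpolation."""
    in_h, in_w = len(matte), len(matte[0])
    out = []
    for i in range(out_h):
        # Fractional source row for this output row (corners aligned).
        y = i * (in_h - 1) / (out_h - 1) if out_h > 1 else 0.0
        y0 = int(y)
        y1 = min(y0 + 1, in_h - 1)
        fy = y - y0
        row = []
        for j in range(out_w):
            # Fractional source column for this output column.
            x = j * (in_w - 1) / (out_w - 1) if out_w > 1 else 0.0
            x0 = int(x)
            x1 = min(x0 + 1, in_w - 1)
            fx = x - x0
            # Blend horizontally on the two neighbouring rows, then vertically.
            top = matte[y0][x0] * (1 - fx) + matte[y0][x1] * fx
            bottom = matte[y1][x0] * (1 - fx) + matte[y1][x1] * fx
            row.append(top * (1 - fy) + bottom * fy)
        out.append(row)
    return out


# A hard 0/1 edge becomes a soft ramp after upsampling.
upsampled = resize_bilinear([[0.0, 1.0], [0.0, 1.0]], 3, 3)
```

The toy example shows the effect the point describes: interpolation introduces intermediate alpha values along edges that were sharp at the lower resolution.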
9 changes: 9 additions & 0 deletions
modules/ximgproc/src/alphagan_matting/data/base_data_loader.py
@@ -0,0 +1,9 @@
class BaseDataLoader():
    def __init__(self):
        pass
    def initialize(self, opt):
        # Keep the parsed options so that subclasses can read them.
        self.opt = opt

    def load_data(self):
        return None
45 changes: 45 additions & 0 deletions
modules/ximgproc/src/alphagan_matting/data/base_dataset.py
@@ -0,0 +1,45 @@
import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms

class BaseDataset(data.Dataset):
    def __init__(self):
        super(BaseDataset, self).__init__()

    def name(self):
        return 'BaseDataset'

    def initialize(self, opt):
        pass

def get_transform(opt):
    transform_list = []
    if opt.resize_or_crop == 'resize_and_crop':
        osize = [opt.loadSize, opt.loadSize]
        # transforms.Resize replaces the deprecated transforms.Scale.
        transform_list.append(transforms.Resize(osize, Image.BICUBIC))
        transform_list.append(transforms.RandomCrop(opt.fineSize))
    elif opt.resize_or_crop == 'crop':
        transform_list.append(transforms.RandomCrop(opt.fineSize))
    elif opt.resize_or_crop == 'scale_width':
        transform_list.append(transforms.Lambda(
            lambda img: __scale_width(img, opt.fineSize)))
    elif opt.resize_or_crop == 'scale_width_and_crop':
        transform_list.append(transforms.Lambda(
            lambda img: __scale_width(img, opt.loadSize)))
        transform_list.append(transforms.RandomCrop(opt.fineSize))

    if opt.isTrain and not opt.no_flip:
        transform_list.append(transforms.RandomHorizontalFlip())

    # Convert to a tensor and map each channel from [0, 1] to [-1, 1].
    transform_list += [transforms.ToTensor(),
                       transforms.Normalize((0.5, 0.5, 0.5),
                                            (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)

def __scale_width(img, target_width):
    # Resize to the target width while keeping the aspect ratio.
    ow, oh = img.size
    if (ow == target_width):
        return img
    w = target_width
    h = int(target_width * oh / ow)
    return img.resize((w, h), Image.BICUBIC)
45 changes: 45 additions & 0 deletions
modules/ximgproc/src/alphagan_matting/data/custom_dataset_dataloader.py
@@ -0,0 +1,45 @@
import torch.utils.data
from data.base_data_loader import BaseDataLoader


def CreateDataset(opt):
    dataset = None

    if opt.dataset_mode == 'testData':
        from data.test_dataset import TestDataset
        dataset = TestDataset()
    elif opt.dataset_mode == 'generated_simple':
        from data.generated_dataset_simple import GeneratedDatasetSimple
        dataset = GeneratedDatasetSimple()
    else:
        raise ValueError("Dataset [%s] not recognized." % opt.dataset_mode)

    print("dataset [%s] was created" % (dataset.name()))
    dataset.initialize(opt)
    return dataset


class CustomDatasetDataLoader(BaseDataLoader):
    def name(self):
        return 'CustomDatasetDataLoader'

    def initialize(self, opt):
        BaseDataLoader.initialize(self, opt)
        self.dataset = CreateDataset(opt)
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batchSize,
            shuffle=not opt.serial_batches,
            num_workers=int(opt.nThreads))

    def load_data(self):
        return self

    def __len__(self):
        return min(len(self.dataset), self.opt.max_dataset_size)

    def __iter__(self):
        for i, data in enumerate(self.dataloader):
            if i >= self.opt.max_dataset_size:
                break
            yield data
@@ -0,0 +1,6 @@
def CreateDataLoader(opt):
    from data.custom_dataset_dataloader import CustomDatasetDataLoader  # module name matches data/custom_dataset_dataloader.py
    data_loader = CustomDatasetDataLoader()
    print(data_loader.name())
    data_loader.initialize(opt)
    return data_loader
189 changes: 189 additions & 0 deletions
modules/ximgproc/src/alphagan_matting/data/generated_dataset_simple.py
@@ -0,0 +1,189 @@
import os.path
import torchvision.transforms as transforms
from data.base_dataset import BaseDataset, get_transform
from data.image_folder import make_dataset
from PIL import Image
import PIL
import random
import scipy.ndimage
import numpy as np
import math
# import pbcvt
# import colour_transfer

class GeneratedDatasetSimple(BaseDataset):
    def initialize(self, opt):
        self.opt = opt
        self.root = opt.dataroot
        self.dir_AB = os.path.join(opt.dataroot, opt.phase)
        self.dir_alpha = os.path.join(self.dir_AB, 'alpha')
        self.dir_fg = os.path.join(self.dir_AB, 'fg')
        self.dir_bg = os.path.join(self.dir_AB, 'bg')
        # Alpha and foreground lists are sorted so that index i refers to the same image in both.
        self.alpha_paths = sorted(make_dataset(self.dir_alpha))
        self.fg_paths = sorted(make_dataset(self.dir_fg))
        self.bg_paths = make_dataset(self.dir_bg)
        self.alpha_size = len(self.alpha_paths)
        self.bg_size = len(self.bg_paths)

    def __getitem__(self, index):
        index = index % self.alpha_size
        alpha_path = self.alpha_paths[index]
        fg_path = self.fg_paths[index]
        # Pair each foreground with a randomly chosen background.
        index_bg = random.randint(0, self.bg_size - 1)
        bg_path = self.bg_paths[index_bg]

        A_bg = Image.open(bg_path).convert('RGB')
        A_fg = Image.open(fg_path).convert('RGB')

        A_alpha = Image.open(alpha_path).convert('L')
        assert A_alpha.mode == 'L'

        A_trimap = self.generate_trimap(A_alpha)

        # A_bg = self.resize_bg(A_bg, A_fg)
        # Upscale backgrounds that are too small, then take a random 320x320 crop.
        w_bg, h_bg = A_bg.size
        if w_bg < 321 or h_bg < 321:
            x = w_bg if w_bg < h_bg else h_bg
            ratio = 321/float(x)
            A_bg = A_bg.resize((int(np.ceil(w_bg*ratio)+1),int(np.ceil(h_bg*ratio)+1)), Image.BICUBIC)
            w_bg, h_bg = A_bg.size
            assert w_bg > 320 and h_bg > 320, '{} {}'.format(w_bg, h_bg)
        x = random.randint(0, w_bg-320-1)
        y = random.randint(0, h_bg-320-1)
        A_bg = A_bg.crop((x,y, x+320, y+320))

        # Crop a square window centred on a random unknown (128) trimap pixel,
        # clamped so that the window starts inside the image.
        crop_size = random.choice([320,480,640])
        # crop_size = random.choice([320,400,480,560,640,720])
        crop_center = self.find_crop_center(A_trimap)
        start_index_height = max(min(A_fg.size[1]-crop_size, crop_center[0] - int(crop_size/2) + 1), 0)
        start_index_width = max(min(A_fg.size[0]-crop_size, crop_center[1] - int(crop_size/2) + 1), 0)

        bbox = ((start_index_width,start_index_height,start_index_width+crop_size,start_index_height+crop_size))

        # A_bg = A_bg.crop(bbox)
        A_fg = A_fg.crop(bbox)
        A_alpha = A_alpha.crop(bbox)
        A_trimap = A_trimap.crop(bbox)

        if self.opt.which_model_netG == 'unet_256':
            A_bg = A_bg.resize((256,256))
            A_fg = A_fg.resize((256,256))
            A_alpha = A_alpha.resize((256,256))
            A_trimap = A_trimap.resize((256,256))
            assert A_alpha.mode == 'L'
        else:
            A_bg = A_bg.resize((320,320))
            A_fg = A_fg.resize((320,320))
            A_alpha = A_alpha.resize((320,320))
            A_trimap = A_trimap.resize((320,320))
            assert A_alpha.mode == 'L'

        # Flip the background independently of the foreground, alpha and trimap.
        if random.randint(0, 1):
            A_bg = A_bg.transpose(PIL.Image.FLIP_LEFT_RIGHT)

        if random.randint(0, 1):
            A_fg = A_fg.transpose(PIL.Image.FLIP_LEFT_RIGHT)
            A_alpha = A_alpha.transpose(PIL.Image.FLIP_LEFT_RIGHT)
            A_trimap = A_trimap.transpose(PIL.Image.FLIP_LEFT_RIGHT)

        ## COLOR TRANSFER ##
        # if random.randint(0, 2) != 0:
        #     A_old = A_fg
        #     target = np.array(A_fg)
        #     palette = np.array(A_palette)
        #     recolor = colour_transfer.runCT(target, palette)
        #     A_fg = Image.fromarray(recolor)

        if self.opt.which_direction == 'BtoA':
            input_nc = self.opt.output_nc
            output_nc = self.opt.input_nc
        else:
            input_nc = self.opt.input_nc
            output_nc = self.opt.output_nc

        A_bg = transforms.ToTensor()(A_bg)
        A_fg = transforms.ToTensor()(A_fg)
        A_alpha = transforms.ToTensor()(A_alpha)
        A_trimap = transforms.ToTensor()(A_trimap)

        return {'A_bg': A_bg,
                'A_fg': A_fg,
                'A_alpha': A_alpha,
                'A_trimap': A_trimap,
                'A_paths': alpha_path}

    def resize_bg(self, bg, fg):
        # Scale the background up until it covers the foreground, then crop to the foreground size.
        bbox = fg.size
        w = bbox[0]
        h = bbox[1]
        bg_bbox = bg.size
        bw = bg_bbox[0]
        bh = bg_bbox[1]
        wratio = w / float(bw)
        hratio = h / float(bh)
        ratio = wratio if wratio > hratio else hratio
        if ratio > 1:
            bg = bg.resize((int(np.ceil(bw*ratio)+1),int(np.ceil(bh*ratio)+1)), Image.BICUBIC)
        bg = bg.crop((0,0,w,h))

        return bg

    # def generate_trimap(self, alpha):
    #     trimap = np.array(alpha)
    #     kernel_sizes = [val for val in range(5,40)]
    #     kernel = random.choice(kernel_sizes)
    #     trimap[np.where((scipy.ndimage.grey_dilation(alpha,size=(kernel,kernel)) - alpha!=0))] = 128

    #     return Image.fromarray(trimap)
    def generate_trimap(self, alpha):
        trimap = np.array(alpha)
        # Random dilation size between 2 and 19 pixels, inclusive.
        kernel = random.randint(2, 19)
        # trimap[np.where((scipy.ndimage.grey_dilation(alpha,size=(kernel,kernel)) - alpha!=0))] = 128
        # Mark partially transparent pixels as unknown, then grow that region by dilation.
        grey = np.where(np.logical_and(trimap>0, trimap<255), 128, 0)
        grey = scipy.ndimage.grey_dilation(grey, size=(kernel,kernel))
        trimap[grey==128] = 128

        return Image.fromarray(trimap)

    def find_crop_center(self, trimap):
        # Pick a random pixel from the unknown (128) region as (row, column).
        t = np.array(trimap)
        target = np.where(t==128)
        index = random.randrange(len(target[0]))
        return np.array(target)[:,index][:2]

    def rotatedRectWithMaxArea(self, w, h, angle):
        """
        Given a rectangle of size wxh that has been rotated by 'angle' (in
        radians), computes the width and height of the largest possible
        axis-aligned rectangle (maximal area) within the rotated rectangle.
        """
        if w <= 0 or h <= 0:
            return 0,0
        width_is_longer = w >= h
        side_long, side_short = (w,h) if width_is_longer else (h,w)

        # since the solutions for angle, -angle and 180-angle are all the same,
        # it suffices to look at the first quadrant and the absolute values of sin,cos:
        sin_a, cos_a = abs(math.sin(angle)), abs(math.cos(angle))
        if side_short <= 2.*sin_a*cos_a*side_long or abs(sin_a-cos_a) < 1e-10:
            # half constrained case: two crop corners touch the longer side,
            # the other two corners are on the mid-line parallel to the longer line
            x = 0.5*side_short
            wr,hr = (x/sin_a,x/cos_a) if width_is_longer else (x/cos_a,x/sin_a)
        else:
            # fully constrained case: crop touches all 4 sides
            cos_2a = cos_a*cos_a - sin_a*sin_a
            wr,hr = (w*cos_a - h*sin_a)/cos_2a, (h*cos_a - w*sin_a)/cos_2a

        return wr,hr

    def __len__(self):
        return len(self.alpha_paths)

    def name(self):
        return 'GeneratedDataset'
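
For readers following `generate_trimap` in the file above: it marks every partially transparent pixel as unknown (128) and then grows that region by dilation, leaving the remaining pixels at their alpha value. The sketch below shows the same idea without NumPy or SciPy on a nested list. `make_trimap` and its `radius` parameter are hypothetical; the symmetric `(2 * radius + 1)` window is a simplification of the `size=(kernel, kernel)` dilation used in the file.

```python
def make_trimap(alpha, radius):
    """Mark pixels near any partially transparent pixel as unknown (128)."""
    height, width = len(alpha), len(alpha[0])
    trimap = [row[:] for row in alpha]
    for y in range(height):
        for x in range(width):
            # Window around (y, x), clipped to the image borders.
            rows = range(max(0, y - radius), min(height, y + radius + 1))
            cols = range(max(0, x - radius), min(width, x + radius + 1))
            # Unknown if any pixel in the window is neither 0 nor 255.
            if any(0 < alpha[ny][nx] < 255 for ny in rows for nx in cols):
                trimap[y][x] = 128
    return trimap


# One image row with a single soft-edge pixel (value 100).
row_trimap = make_trimap([[0, 0, 0, 100, 255, 255, 255]], radius=1)
```

A larger radius widens the unknown band, which is the role the random kernel size plays in the dataset code.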
Why are the documentation images stored outside?