
Commit b8e7b28

Add files via upload
Update README.md
Add files via upload
Delete epoch132_gt.png
Delete epoch132_img.png
Delete epoch132_pred.png
Add files via upload
Update README.md
Update README.md
Update README.md
Update README.md
Update README.md
Delete image_pool.py
Delete network.py
Delete train.py
Delete base_options.py
Delete train_options.py
Delete pred6.png
Delete epoch138_pred.png
Delete epoch132_gt.png
Delete epoch132_img.png
Delete epoch132_pred.png
Delete epoch134_gt.png
Delete epoch134_img.png
Delete epoch134_pred.png
Delete epoch138_gt.png
Delete epoch138_img.png
Delete pred5.png
Delete gt1.png
Delete gt2.png
Delete gt3.png
Delete gt4.png
Delete gt5.png
Delete gt6.png
Delete img1.png
Delete img2.png
Delete img3.png
Delete img4.png
Delete img5.png
Delete img6.png
Delete pred1.png
Delete pred2.png
Delete pred3.png
Delete pred4.png
Delete README.md
Create train.py
Add files via upload
Delete single_dataset.py
Add files via upload
Delete test_model.py
Add files via upload
Update README.md
Update networks.py
Delete deeplabv3.py
Update custom_dataset_dataloader.py
Create README.md
Update README.md
Update custom_dataset_dataloader.py
Update models.py
Update custom_dataset_dataloader.py
Update test.py
Update test.py
Update test.py
Rename modules/ximgproc/src/alphagan_matting/utils/image_pool.py to modules/ximgproc/src/alphagan_matting/util/image_pool.py
Rename modules/ximgproc/src/alphagan_matting/utils/utils.py to modules/ximgproc/src/alphagan_matting/util/util.py
Update networks.py
Update test_dataset.py
Add files via upload
Add files via upload
Update README.md
Update README.md
Update README.md
Update README.md
Update README.md
Update README.md
Update README.md
Update README.md
Update README.md
Update README.md
Update README.md
Update README.md
Update README.md
Update base_data_loader.py
Removing whitespaces
Removing whitespaces
Removing whitespaces
Removing whitespaces
Removing whitespaces
Removing whitespaces
1 parent 0915b7e commit b8e7b28

21 files changed: +1808 -0 lines changed

modules/ximgproc/README.md

Lines changed: 1 addition & 0 deletions
@@ -16,3 +16,4 @@ Extended Image Processing
 - Pei&Lin Normalization
 - Ridge Detection Filter
 - Binary morphology on run-length encoded images
+- AlphaGan matting
Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
## Architecture ##

#### Generator ####
An encoder-decoder based architecture is used.

There are 2 options for the generator encoder:

<b>a.</b> Resnet50 minus the last 2 layers
<b>b.</b> Resnet50 + ASPP module (a minimal sketch of the ASPP idea follows)
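
The ASPP (atrous spatial pyramid pooling) variant augments the Resnet50 encoder with parallel dilated convolutions. As a rough illustration only, here is a minimal sketch assuming the usual DeepLab-style dilation rates and channel widths; the actual module lives in networks.py and may differ:

```python
import torch
import torch.nn as nn

class ASPP(nn.Module):
    """Minimal ASPP sketch: parallel dilated convolutions over the encoder
    output, concatenated and fused by a 1x1 convolution. The rates
    (1, 6, 12, 18) and channel widths are assumptions, not from this repo."""
    def __init__(self, in_ch=2048, out_ch=256, rates=(1, 6, 12, 18)):
        super(ASPP, self).__init__()
        self.branches = nn.ModuleList([
            nn.Sequential(
                # 1x1 conv for rate 1, 3x3 dilated conv otherwise
                nn.Conv2d(in_ch, out_ch, kernel_size=3 if r > 1 else 1,
                          padding=r if r > 1 else 0, dilation=r, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True))
            for r in rates])
        self.project = nn.Conv2d(out_ch * len(rates), out_ch, kernel_size=1)

    def forward(self, x):
        # All branches keep the spatial size, so they can be concatenated.
        return self.project(torch.cat([b(x) for b in self.branches], dim=1))
```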

The decoder network of the generator has seven upsampling convolutional blocks.
Each upsampling convolutional block has an upsampling layer, followed by a convolutional layer, a batch normalization layer and a ReLU activation function.
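
A minimal sketch of one such upsampling block, assuming nearest-neighbour upsampling and 3x3 kernels (the actual layer widths are defined in networks.py):

```python
import torch.nn as nn

def up_block(in_ch, out_ch):
    # One decoder block as described above: upsample, then conv + BN + ReLU.
    return nn.Sequential(
        nn.Upsample(scale_factor=2, mode='nearest'),
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True))
```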

#### Discriminator ####
The discriminator used here is the PatchGAN discriminator. The implementation is inspired by the CycleGAN implementation from<br/>
https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix

Again, 2 different types of discriminators are used:
<b>a.</b> N Layer PatchGAN discriminator, where the size of the patch is NxN; it is taken as 3x3 here
<b>b.</b> Pixel PatchGAN discriminator, which classifies every pixel (a minimal sketch follows)
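
For intuition, here is a minimal sketch of the pixel variant, which uses only 1x1 convolutions so that each output score depends on a single input pixel. The input and channel widths are assumptions, not taken from networks.py:

```python
import torch.nn as nn

class PixelDiscriminator(nn.Module):
    # 1x1 convolutions only, so every output value classifies one pixel
    # as real or fake (a PatchGAN with a 1x1 receptive field).
    def __init__(self, input_nc=4, ndf=64):
        super(PixelDiscriminator, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(input_nc, ndf, kernel_size=1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf, ndf * 2, kernel_size=1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf * 2, 1, kernel_size=1))

    def forward(self, x):
        return self.net(x)
```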

## How to use
Use the dataroot argument to enter the directory where you have stored the data.
Structure the data in the following way:

train<br/>
-alpha -bg -fg

test<br/>
-fg -trimap

The backgrounds I have used here are from the MSCOCO dataset.

To train the model using Resnet50 without the ASPP module:

`python train.py --dataroot ./ --model simple --dataset_mode generated_simple --which_model_netG resnet50 --name resnet50`

To test the model using Resnet50 without the ASPP module:

`python test.py --dataroot ./ --dataset_mode single --which_model_netG resnet50 --ntest 8 --model test --name resnet50`

To train the model using Resnet50 with the ASPP module:

`python train.py --dataroot ./ --model simple --dataset_mode generated_simple --which_model_netG resnet50ASPP --name resnet50ASPP`

To test the model using Resnet50 with the ASPP module:

`python test.py --dataroot ./ --dataset_mode single --which_model_netG resnet50ASPP --ntest 8 --model test`

## Results

#### Input:
![input](https://raw.githubusercontent.com/Nerdyvedi/GSOC-Opencv-matting/master/2.png)

#### Trimap:
![Trimap](https://raw.githubusercontent.com/Nerdyvedi/GSOC-Opencv-matting/master/donkey_tri.png)

#### AlphaGAN matting:
##### Generator: Resnet50, Discriminator: N Layer PatchGAN
![Output2](https://raw.githubusercontent.com/Nerdyvedi/GSOC-Opencv-matting/master/donkey_resnet50.png)

##### Generator: Resnet50, Discriminator: Pixel PatchGAN
![Output3](https://raw.githubusercontent.com/Nerdyvedi/GSOC-Opencv-matting/master/donkey.png)

##### Generator: Resnet50 + ASPP module, Discriminator: N Layer PatchGAN
![Output4](https://raw.githubusercontent.com/Nerdyvedi/GSOC-Opencv-matting/master/donkey_deeplab.png)

### Comparing with the original implementation
(Average rank on alphamatting.com is shown; lower is better)

| Error type | Original implementation | Resnet50 + N Layer | Resnet50 + Pixel | Resnet50 + ASPP module |
| --- | --- | --- | --- | --- |
| Sum of absolute differences | 11.7 | 42.8 | 43.8 | 53 |
| Mean square error | 15 | 45.8 | 45.6 | 54.2 |
| Gradient error | 14 | 52.9 | 52.7 | 55 |
| Connectivity error | 29.6 | 23.3 | 22.6 | 32.8 |

### Training dataset used
I created the training dataset myself using GIMP.
[Link to created dataset](https://drive.google.com/open?id=1zQbk2Cu7QOBwzg4vVGqCWJwHGTwGppFe)

### What could be wrong?

1. The dataset I created is not perfect. I tried to make it as accurate as possible, in some cases marking individual pixels by hand, but it still is not perfect.

2. The output of the generator is 320x320, and the alpha matte is then resized to the size of the original image. There may be some loss in this resizing, and a better upsampling method might improve the outputs. A minimal sketch of the resizing step is shown below.
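
For example, upsampling a predicted matte back to the input resolution could look like the following. Bicubic resampling is an assumption here; the choice of filter is exactly the open question:

```python
from PIL import Image

def restore_matte(pred_320, orig_size):
    # pred_320: 320x320 grayscale ('L') matte from the generator.
    # orig_size: (width, height) of the original input image.
    return pred_320.resize(orig_size, Image.BICUBIC)
```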

3. The major reason this implementation isn't as good as the original is that the authors haven't clearly described how the skip connections are used. I am confident about the architecture of the encoder and decoder; the only thing I am unsure about is how the skip connections are wired.
Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
class BaseDataLoader():
    def __init__(self):
        pass

    def initialize(self, opt):
        self.opt = opt

    def load_data(self):
        # Subclasses override this to return the actual data loader.
        return None
Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms


class BaseDataset(data.Dataset):
    def __init__(self):
        super(BaseDataset, self).__init__()

    def name(self):
        return 'BaseDataset'

    def initialize(self, opt):
        pass


def get_transform(opt):
    transform_list = []
    if opt.resize_or_crop == 'resize_and_crop':
        osize = [opt.loadSize, opt.loadSize]
        # transforms.Scale is deprecated; Resize is the current equivalent.
        transform_list.append(transforms.Resize(osize, Image.BICUBIC))
        transform_list.append(transforms.RandomCrop(opt.fineSize))
    elif opt.resize_or_crop == 'crop':
        transform_list.append(transforms.RandomCrop(opt.fineSize))
    elif opt.resize_or_crop == 'scale_width':
        transform_list.append(transforms.Lambda(
            lambda img: __scale_width(img, opt.fineSize)))
    elif opt.resize_or_crop == 'scale_width_and_crop':
        transform_list.append(transforms.Lambda(
            lambda img: __scale_width(img, opt.loadSize)))
        transform_list.append(transforms.RandomCrop(opt.fineSize))

    if opt.isTrain and not opt.no_flip:
        transform_list.append(transforms.RandomHorizontalFlip())

    transform_list += [transforms.ToTensor(),
                       transforms.Normalize((0.5, 0.5, 0.5),
                                            (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)


def __scale_width(img, target_width):
    # Resize so the width equals target_width, preserving the aspect ratio.
    ow, oh = img.size
    if ow == target_width:
        return img
    w = target_width
    h = int(target_width * oh / ow)
    return img.resize((w, h), Image.BICUBIC)
Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
import torch.utils.data
from data.base_data_loader import BaseDataLoader


def CreateDataset(opt):
    dataset = None

    if opt.dataset_mode == 'testData':
        from data.test_dataset import TestDataset
        dataset = TestDataset()
    elif opt.dataset_mode == 'generated_simple':
        from data.generated_dataset_simple import GeneratedDatasetSimple
        dataset = GeneratedDatasetSimple()
    else:
        raise ValueError("Dataset [%s] not recognized." % opt.dataset_mode)

    print("dataset [%s] was created" % (dataset.name()))
    dataset.initialize(opt)
    return dataset


class CustomDatasetDataLoader(BaseDataLoader):
    def name(self):
        return 'CustomDatasetDataLoader'

    def initialize(self, opt):
        BaseDataLoader.initialize(self, opt)
        self.dataset = CreateDataset(opt)
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batchSize,
            shuffle=not opt.serial_batches,
            num_workers=int(opt.nThreads))

    def load_data(self):
        return self

    def __len__(self):
        return min(len(self.dataset), self.opt.max_dataset_size)

    def __iter__(self):
        # Stop early once max_dataset_size batches have been yielded.
        for i, data in enumerate(self.dataloader):
            if i >= self.opt.max_dataset_size:
                break
            yield data
Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
def CreateDataLoader(opt):
    from data.custom_dataset_data_loader import CustomDatasetDataLoader
    data_loader = CustomDatasetDataLoader()
    print(data_loader.name())
    data_loader.initialize(opt)
    return data_loader
Lines changed: 189 additions & 0 deletions
@@ -0,0 +1,189 @@
import os.path
import torchvision.transforms as transforms
from data.base_dataset import BaseDataset, get_transform
from data.image_folder import make_dataset
from PIL import Image
import PIL
import random
import scipy.ndimage
import numpy as np
import math
# import pbcvt
# import colour_transfer


class GeneratedDatasetSimple(BaseDataset):
    def initialize(self, opt):
        self.opt = opt
        self.root = opt.dataroot
        self.dir_AB = os.path.join(opt.dataroot, opt.phase)
        self.dir_alpha = os.path.join(self.dir_AB, 'alpha')
        self.dir_fg = os.path.join(self.dir_AB, 'fg')
        self.dir_bg = os.path.join(self.dir_AB, 'bg')
        self.alpha_paths = sorted(make_dataset(self.dir_alpha))
        self.fg_paths = sorted(make_dataset(self.dir_fg))
        self.bg_paths = make_dataset(self.dir_bg)
        self.alpha_size = len(self.alpha_paths)
        self.bg_size = len(self.bg_paths)

    def __getitem__(self, index):
        index = index % self.alpha_size
        alpha_path = self.alpha_paths[index]
        fg_path = self.fg_paths[index]
        # Pair each foreground with a randomly chosen background.
        index_bg = random.randint(0, self.bg_size - 1)
        bg_path = self.bg_paths[index_bg]

        A_bg = Image.open(bg_path).convert('RGB')
        A_fg = Image.open(fg_path).convert('RGB')

        A_alpha = Image.open(alpha_path).convert('L')
        assert A_alpha.mode == 'L'

        A_trimap = self.generate_trimap(A_alpha)

        # A_bg = self.resize_bg(A_bg, A_fg)
        # Upscale the background if it is too small, then take a random
        # 320x320 crop from it.
        w_bg, h_bg = A_bg.size
        if w_bg < 321 or h_bg < 321:
            x = w_bg if w_bg < h_bg else h_bg
            ratio = 321 / float(x)
            A_bg = A_bg.resize((int(np.ceil(w_bg * ratio) + 1), int(np.ceil(h_bg * ratio) + 1)), Image.BICUBIC)
        w_bg, h_bg = A_bg.size
        assert w_bg > 320 and h_bg > 320, '{} {}'.format(w_bg, h_bg)
        x = random.randint(0, w_bg - 320 - 1)
        y = random.randint(0, h_bg - 320 - 1)
        A_bg = A_bg.crop((x, y, x + 320, y + 320))

        # Crop foreground, alpha and trimap around a random unknown-region pixel.
        crop_size = random.choice([320, 480, 640])
        # crop_size = random.choice([320,400,480,560,640,720])
        crop_center = self.find_crop_center(A_trimap)
        start_index_height = max(min(A_fg.size[1] - crop_size, crop_center[0] - int(crop_size / 2) + 1), 0)
        start_index_width = max(min(A_fg.size[0] - crop_size, crop_center[1] - int(crop_size / 2) + 1), 0)

        bbox = (start_index_width, start_index_height, start_index_width + crop_size, start_index_height + crop_size)

        # A_bg = A_bg.crop(bbox)
        A_fg = A_fg.crop(bbox)
        A_alpha = A_alpha.crop(bbox)
        A_trimap = A_trimap.crop(bbox)

        if self.opt.which_model_netG == 'unet_256':
            A_bg = A_bg.resize((256, 256))
            A_fg = A_fg.resize((256, 256))
            A_alpha = A_alpha.resize((256, 256))
            A_trimap = A_trimap.resize((256, 256))
            assert A_alpha.mode == 'L'
        else:
            A_bg = A_bg.resize((320, 320))
            A_fg = A_fg.resize((320, 320))
            A_alpha = A_alpha.resize((320, 320))
            A_trimap = A_trimap.resize((320, 320))
            assert A_alpha.mode == 'L'

        # Random horizontal flips; foreground, alpha and trimap flip together.
        if random.randint(0, 1):
            A_bg = A_bg.transpose(PIL.Image.FLIP_LEFT_RIGHT)

        if random.randint(0, 1):
            A_fg = A_fg.transpose(PIL.Image.FLIP_LEFT_RIGHT)
            A_alpha = A_alpha.transpose(PIL.Image.FLIP_LEFT_RIGHT)
            A_trimap = A_trimap.transpose(PIL.Image.FLIP_LEFT_RIGHT)

        ## COLOR TRANSFER ##
        # if random.randint(0, 2) != 0:
        #     A_old = A_fg
        #     target = np.array(A_fg)
        #     palette = np.array(A_palette)
        #     recolor = colour_transfer.runCT(target, palette)
        #     A_fg = Image.fromarray(recolor)

        if self.opt.which_direction == 'BtoA':
            input_nc = self.opt.output_nc
            output_nc = self.opt.input_nc
        else:
            input_nc = self.opt.input_nc
            output_nc = self.opt.output_nc

        A_bg = transforms.ToTensor()(A_bg)
        A_fg = transforms.ToTensor()(A_fg)
        A_alpha = transforms.ToTensor()(A_alpha)
        A_trimap = transforms.ToTensor()(A_trimap)

        return {'A_bg': A_bg,
                'A_fg': A_fg,
                'A_alpha': A_alpha,
                'A_trimap': A_trimap,
                'A_paths': alpha_path}

    def resize_bg(self, bg, fg):
        # Scale the background up until it covers the foreground,
        # then crop it to the foreground's size.
        bbox = fg.size
        w = bbox[0]
        h = bbox[1]
        bg_bbox = bg.size
        bw = bg_bbox[0]
        bh = bg_bbox[1]
        wratio = w / float(bw)
        hratio = h / float(bh)
        ratio = wratio if wratio > hratio else hratio
        if ratio > 1:
            bg = bg.resize((int(np.ceil(bw * ratio) + 1), int(np.ceil(bh * ratio) + 1)), Image.BICUBIC)
        bg = bg.crop((0, 0, w, h))
        return bg

    # def generate_trimap(self, alpha):
    #     trimap = np.array(alpha)
    #     kernel_sizes = [val for val in range(5, 40)]
    #     kernel = random.choice(kernel_sizes)
    #     trimap[np.where((scipy.ndimage.grey_dilation(alpha, size=(kernel, kernel)) - alpha != 0))] = 128
    #     return Image.fromarray(trimap)

    def generate_trimap(self, alpha):
        # Mark semi-transparent pixels (0 < alpha < 255) as unknown (128),
        # then dilate the unknown region with a randomly sized kernel.
        trimap = np.array(alpha)
        grey = np.zeros_like(trimap)
        kernel_sizes = [val for val in range(2, 20)]
        kernel = random.choice(kernel_sizes)
        # trimap[np.where((scipy.ndimage.grey_dilation(alpha, size=(kernel, kernel)) - alpha != 0))] = 128
        grey = np.where(np.logical_and(trimap > 0, trimap < 255), 128, 0)
        grey = scipy.ndimage.grey_dilation(grey, size=(kernel, kernel))
        trimap[grey == 128] = 128
        return Image.fromarray(trimap)

    def find_crop_center(self, trimap):
        # Pick a random pixel from the unknown (128) region of the trimap.
        t = np.array(trimap)
        target = np.where(t == 128)
        index = random.choice([i for i in range(len(target[0]))])
        return np.array(target)[:, index][:2]

    def rotatedRectWithMaxArea(self, w, h, angle):
        """
        Given a rectangle of size wxh that has been rotated by 'angle' (in
        radians), computes the width and height of the largest possible
        axis-aligned rectangle (maximal area) within the rotated rectangle.
        """
        if w <= 0 or h <= 0:
            return 0, 0
        width_is_longer = w >= h
        side_long, side_short = (w, h) if width_is_longer else (h, w)

        # since the solutions for angle, -angle and 180-angle are all the same,
        # it suffices to look at the first quadrant and the absolute values of sin, cos:
        sin_a, cos_a = abs(math.sin(angle)), abs(math.cos(angle))
        if side_short <= 2. * sin_a * cos_a * side_long or abs(sin_a - cos_a) < 1e-10:
            # half constrained case: two crop corners touch the longer side,
            # the other two corners are on the mid-line parallel to the longer line
            x = 0.5 * side_short
            wr, hr = (x / sin_a, x / cos_a) if width_is_longer else (x / cos_a, x / sin_a)
        else:
            # fully constrained case: crop touches all 4 sides
            cos_2a = cos_a * cos_a - sin_a * sin_a
            wr, hr = (w * cos_a - h * sin_a) / cos_2a, (h * cos_a - w * sin_a) / cos_2a

        return wr, hr

    def __len__(self):
        return len(self.alpha_paths)

    def name(self):
        return 'GeneratedDataset'
