Skip to content

Commit 9346e06

Browse files
committed
Added implementations of AlphaGAN matting and Global sampling method for matting
Add files via upload Delete AlphaGAN.py Delete alphaGAN_eval.py Delete alphaGAN_test.py Delete alphaGAN_train.py Delete __init__.py Delete composition_code.py Delete crop_train.py Delete input_dataset.py Delete trainset_filter.py Delete visualize.py Delete README.md Delete README.md Delete globalmatting.cpp Delete globalmatting.h Delete guidedfilter.cpp Delete guidedfilter.h Delete doll.png Delete donkey.png Delete elephant.png Delete net.png Delete pineapple.png Delete plant.png Delete troll.png Delete plasticbag.png Delete doll.png Delete donkey.png Delete elephant.png Delete net.png Delete pineapple.png Delete plant.png Delete plasticbag.png Delete troll.png Rename AlphaGAN.py to alphagan.py
1 parent b6e2867 commit 9346e06

32 files changed

+2511
-0
lines changed
Lines changed: 241 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
import torchvision.transforms as transforms
2+
import numpy as np
3+
import torch as t
4+
from model.AlphaGAN import NetG
5+
from PIL import Image
6+
import os
7+
from visualize import Visualizer
8+
import tqdm
9+
10+
os.environ["CUDA_VISIBLE_DEVICES"] = '2'
11+
vis = Visualizer('alphaGAN_eval')
12+
13+
transform = transforms.Compose([
14+
transforms.ToTensor(),
15+
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
16+
])
17+
18+
to_pil = transforms.Compose([
19+
transforms.ToPILImage()
20+
])
21+
22+
MODEL_DIR = '/data1/zzl/model/alphaGAN/new_trainset/netG/netG_40.pth'
23+
24+
25+
def padding_img(img):
26+
27+
img_size = np.shape(img)
28+
if len(img_size) == 3:
29+
(h, w, c) = img_size
30+
w_padding = (int(w/320) + 1) * 320
31+
h_padding = (int(h/320) + 1) * 320
32+
33+
padding_result = np.pad(img, ((0, h_padding - h), (0, w_padding - w), (0, 0)), 'mean')
34+
35+
return Image.fromarray(padding_result), int(w_padding/320), int(h_padding/320)
36+
elif len(img_size) == 2:
37+
(h, w) = img_size
38+
w_padding = (int(w/320) + 1) * 320
39+
h_padding = (int(h/320) + 1) * 320
40+
41+
padding_result = np.pad(img, ((0, h_padding - h), (0, w_padding - w)), 'constant', constant_values=0)
42+
43+
return Image.fromarray(padding_result), int(w_padding/320), int(h_padding/320)
44+
else:
45+
exit(1)
46+
47+
48+
def clip_img(img, h_clip, w_clip):
49+
50+
img_list = []
51+
for x in range(w_clip):
52+
for y in range(h_clip):
53+
region = (x*320, y*320, x*320+320, y*320+320)
54+
crop_img = img.crop(region)
55+
crop_img = transform(crop_img)
56+
crop_img = crop_img[None]
57+
58+
img_list.append(crop_img)
59+
60+
crop_img = img_list[0]
61+
for i in range(1, len(img_list)):
62+
crop_img = t.cat((crop_img, img_list[i]), dim=0)
63+
64+
return crop_img
65+
66+
67+
def combination(img_list, h_clip, w_clip):
68+
69+
column =[]
70+
for y in range(w_clip):
71+
for x in range(h_clip):
72+
if x == 0:
73+
column.append(img_list[y*h_clip + x])
74+
else:
75+
column[y] = t.cat((column[y], img_list[y*h_clip + x]), dim=1)
76+
77+
com = column[0]
78+
79+
for i in range(1, len(column)):
80+
com = t.cat((com, column[i]), dim=2)
81+
82+
return com
83+
84+
85+
def clip_input():
86+
with t.no_grad():
87+
net_G = NetG().cuda()
88+
net_G.eval()
89+
net_G.load_state_dict(t.load(MODEL_DIR, map_location=t.device('cpu')))
90+
91+
img_root = '/data1/zzl/dataset/matting/alphamatting/input_lowers'
92+
trimap_root = '/data1/zzl/dataset/matting/alphamatting/trimap_lowres'
93+
94+
img_name = os.listdir(img_root)
95+
96+
for name in tqdm.tqdm(img_name):
97+
for i in range(1, 4):
98+
99+
trimap_floder = 'Trimap' + str(i)
100+
101+
img = Image.open(os.path.join(img_root, name))
102+
print('img_size', img.size)
103+
print('img_shape', np.shape(img))
104+
img, w_clip, h_clip = padding_img(img)
105+
print('img.shape', np.shape(img))
106+
# print('img', w_clip, h_clip)
107+
108+
crop_img = clip_img(img, h_clip, w_clip)
109+
110+
img = transform(img)
111+
112+
trimap = Image.open(os.path.join(trimap_root, trimap_floder, name))
113+
(h_r, w_r) = np.shape(trimap)
114+
trimap, w_clip, h_clip = padding_img(trimap)
115+
116+
# print('trimap', w_clip, h_clip)
117+
118+
crop_tri = clip_img(trimap, h_clip, w_clip)
119+
120+
input_img = t.cat((crop_img, crop_tri), dim=1)
121+
input_img = input_img.cuda()
122+
123+
fake_alpha = net_G(input_img)
124+
125+
com_fake = combination(fake_alpha, h_clip, w_clip)
126+
127+
vis.images(com_fake.cpu().numpy(), win='fake_alpha')
128+
vis.images(img.numpy() * 0.5 + 0.5, win='input')
129+
# print(fake_alpha[0].size())
130+
# print(com_fake.size())
131+
save_alpha = to_pil(com_fake.cpu())
132+
save_alpha = save_alpha.convert('L')
133+
print('fake_alpha.shape', np.shape(save_alpha))
134+
box = (0, 0, w_r, h_r)
135+
save_alpha = save_alpha.crop(box)
136+
137+
if not os.path.exists(trimap_floder):
138+
os.mkdir(trimap_floder)
139+
print('save_alpha.shape', np.shape(save_alpha))
140+
save_alpha.save(trimap_floder + '/' + name)
141+
return
142+
143+
144+
def full_input():
145+
with t.no_grad():
146+
net_G = NetG().cuda()
147+
net_G.eval()
148+
net_G.load_state_dict(t.load(MODEL_DIR, map_location=t.device('cpu')))
149+
150+
img_root = '/data1/zzl/dataset/matting/alphamatting/input_lowers'
151+
trimap_root = '/data1/zzl/dataset/matting/alphamatting/trimap_lowres'
152+
153+
img_name = os.listdir(img_root)
154+
155+
for name in tqdm.tqdm(img_name):
156+
for i in range(1, 4):
157+
158+
trimap_floder = 'Trimap' + str(i)
159+
160+
img = Image.open(os.path.join(img_root, name))
161+
img, _1, _2 = padding_img(img)
162+
img = transform(img)
163+
164+
trimap = Image.open(os.path.join(trimap_root, trimap_floder, name))
165+
(h_r, w_r) = np.shape(trimap)
166+
trimap, _1, _2 = padding_img(trimap)
167+
(w, h) = np.shape(trimap)
168+
trimap = np.reshape(trimap, (w, h, 1))
169+
trimap = transform(trimap)
170+
171+
input_img = t.cat((img, trimap), dim=0)
172+
input_img = input_img[None]
173+
input_img = input_img.cuda()
174+
175+
fake_alpha = net_G(input_img)
176+
vis.images(fake_alpha.cpu().numpy(), win='fake_alpha')
177+
vis.images(img.numpy() * 0.5 + 0.5, win='input')
178+
#print(fake_alpha[0].size())
179+
save_alpha = to_pil(fake_alpha.cpu()[0])
180+
save_alpha = save_alpha.convert('L')
181+
182+
box = (0, 0, w_r, h_r)
183+
save_alpha = save_alpha.crop(box)
184+
185+
if not os.path.exists(trimap_floder):
186+
os.mkdir(trimap_floder)
187+
print(np.shape(save_alpha))
188+
save_alpha.save(trimap_floder + '/' + name)
189+
190+
191+
def resize_input():
192+
with t.no_grad():
193+
net_G = NetG().cuda()
194+
net_G.eval()
195+
net_G.load_state_dict(t.load(MODEL_DIR, map_location=t.device('cpu')))
196+
197+
img_root = '/data1/zzl/dataset/matting/alphamatting/input_lowers'
198+
trimap_root = '/data1/zzl/dataset/matting/alphamatting/trimap_lowres'
199+
200+
img_name = os.listdir(img_root)
201+
202+
for name in tqdm.tqdm(img_name):
203+
for i in range(1, 4):
204+
205+
trimap_floder = 'Trimap' + str(i)
206+
207+
img = Image.open(os.path.join(img_root, name))
208+
(w, h) = img.size
209+
w_large = w//320 + 1
210+
h_large = h//320 + 1
211+
212+
img = img.resize((w_large * 320, h_large * 320))
213+
214+
img = transform(img)
215+
216+
trimap = Image.open(os.path.join(trimap_root, trimap_floder, name))
217+
trimap = trimap.resize((w_large * 320, h_large * 320))
218+
(w, h) = np.shape(trimap)
219+
trimap = np.reshape(trimap, (w, h, 1))
220+
trimap = transform(trimap)
221+
222+
input_img = t.cat((img, trimap), dim=0)
223+
input_img = input_img[None]
224+
input_img = input_img.cuda()
225+
226+
fake_alpha = net_G(input_img)
227+
vis.images(fake_alpha.cpu().numpy(), win='fake_alpha')
228+
vis.images(img.numpy() * 0.5 + 0.5, win='input')
229+
# print(fake_alpha[0].size())
230+
save_alpha = to_pil(fake_alpha.cpu()[0])
231+
save_alpha = save_alpha.convert('L')
232+
box = (0, 0, w, h)
233+
save_alpha = save_alpha.crop(box)
234+
if not os.path.exists(trimap_floder):
235+
os.mkdir(trimap_floder)
236+
print(np.shape(save_alpha))
237+
save_alpha.save(trimap_floder + '/' + name)
238+
239+
240+
full_input()
241+

0 commit comments

Comments
 (0)