|
| 1 | +import torchvision.transforms as transforms |
| 2 | +import numpy as np |
| 3 | +import torch as t |
| 4 | +from model.AlphaGAN import NetG |
| 5 | +from PIL import Image |
| 6 | +import os |
| 7 | +from visualize import Visualizer |
| 8 | +import tqdm |
| 9 | + |
| 10 | +os.environ["CUDA_VISIBLE_DEVICES"] = '2' |
| 11 | +vis = Visualizer('alphaGAN_eval') |
| 12 | + |
| 13 | +transform = transforms.Compose([ |
| 14 | + transforms.ToTensor(), |
| 15 | + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) |
| 16 | + ]) |
| 17 | + |
| 18 | +to_pil = transforms.Compose([ |
| 19 | + transforms.ToPILImage() |
| 20 | +]) |
| 21 | + |
| 22 | +MODEL_DIR = '/data1/zzl/model/alphaGAN/new_trainset/netG/netG_40.pth' |
| 23 | + |
| 24 | + |
| 25 | +def padding_img(img): |
| 26 | + |
| 27 | + img_size = np.shape(img) |
| 28 | + if len(img_size) == 3: |
| 29 | + (h, w, c) = img_size |
| 30 | + w_padding = (int(w/320) + 1) * 320 |
| 31 | + h_padding = (int(h/320) + 1) * 320 |
| 32 | + |
| 33 | + padding_result = np.pad(img, ((0, h_padding - h), (0, w_padding - w), (0, 0)), 'mean') |
| 34 | + |
| 35 | + return Image.fromarray(padding_result), int(w_padding/320), int(h_padding/320) |
| 36 | + elif len(img_size) == 2: |
| 37 | + (h, w) = img_size |
| 38 | + w_padding = (int(w/320) + 1) * 320 |
| 39 | + h_padding = (int(h/320) + 1) * 320 |
| 40 | + |
| 41 | + padding_result = np.pad(img, ((0, h_padding - h), (0, w_padding - w)), 'constant', constant_values=0) |
| 42 | + |
| 43 | + return Image.fromarray(padding_result), int(w_padding/320), int(h_padding/320) |
| 44 | + else: |
| 45 | + exit(1) |
| 46 | + |
| 47 | + |
| 48 | +def clip_img(img, h_clip, w_clip): |
| 49 | + |
| 50 | + img_list = [] |
| 51 | + for x in range(w_clip): |
| 52 | + for y in range(h_clip): |
| 53 | + region = (x*320, y*320, x*320+320, y*320+320) |
| 54 | + crop_img = img.crop(region) |
| 55 | + crop_img = transform(crop_img) |
| 56 | + crop_img = crop_img[None] |
| 57 | + |
| 58 | + img_list.append(crop_img) |
| 59 | + |
| 60 | + crop_img = img_list[0] |
| 61 | + for i in range(1, len(img_list)): |
| 62 | + crop_img = t.cat((crop_img, img_list[i]), dim=0) |
| 63 | + |
| 64 | + return crop_img |
| 65 | + |
| 66 | + |
| 67 | +def combination(img_list, h_clip, w_clip): |
| 68 | + |
| 69 | + column =[] |
| 70 | + for y in range(w_clip): |
| 71 | + for x in range(h_clip): |
| 72 | + if x == 0: |
| 73 | + column.append(img_list[y*h_clip + x]) |
| 74 | + else: |
| 75 | + column[y] = t.cat((column[y], img_list[y*h_clip + x]), dim=1) |
| 76 | + |
| 77 | + com = column[0] |
| 78 | + |
| 79 | + for i in range(1, len(column)): |
| 80 | + com = t.cat((com, column[i]), dim=2) |
| 81 | + |
| 82 | + return com |
| 83 | + |
| 84 | + |
| 85 | +def clip_input(): |
| 86 | + with t.no_grad(): |
| 87 | + net_G = NetG().cuda() |
| 88 | + net_G.eval() |
| 89 | + net_G.load_state_dict(t.load(MODEL_DIR, map_location=t.device('cpu'))) |
| 90 | + |
| 91 | + img_root = '/data1/zzl/dataset/matting/alphamatting/input_lowers' |
| 92 | + trimap_root = '/data1/zzl/dataset/matting/alphamatting/trimap_lowres' |
| 93 | + |
| 94 | + img_name = os.listdir(img_root) |
| 95 | + |
| 96 | + for name in tqdm.tqdm(img_name): |
| 97 | + for i in range(1, 4): |
| 98 | + |
| 99 | + trimap_floder = 'Trimap' + str(i) |
| 100 | + |
| 101 | + img = Image.open(os.path.join(img_root, name)) |
| 102 | + print('img_size', img.size) |
| 103 | + print('img_shape', np.shape(img)) |
| 104 | + img, w_clip, h_clip = padding_img(img) |
| 105 | + print('img.shape', np.shape(img)) |
| 106 | + # print('img', w_clip, h_clip) |
| 107 | + |
| 108 | + crop_img = clip_img(img, h_clip, w_clip) |
| 109 | + |
| 110 | + img = transform(img) |
| 111 | + |
| 112 | + trimap = Image.open(os.path.join(trimap_root, trimap_floder, name)) |
| 113 | + (h_r, w_r) = np.shape(trimap) |
| 114 | + trimap, w_clip, h_clip = padding_img(trimap) |
| 115 | + |
| 116 | + # print('trimap', w_clip, h_clip) |
| 117 | + |
| 118 | + crop_tri = clip_img(trimap, h_clip, w_clip) |
| 119 | + |
| 120 | + input_img = t.cat((crop_img, crop_tri), dim=1) |
| 121 | + input_img = input_img.cuda() |
| 122 | + |
| 123 | + fake_alpha = net_G(input_img) |
| 124 | + |
| 125 | + com_fake = combination(fake_alpha, h_clip, w_clip) |
| 126 | + |
| 127 | + vis.images(com_fake.cpu().numpy(), win='fake_alpha') |
| 128 | + vis.images(img.numpy() * 0.5 + 0.5, win='input') |
| 129 | + # print(fake_alpha[0].size()) |
| 130 | + # print(com_fake.size()) |
| 131 | + save_alpha = to_pil(com_fake.cpu()) |
| 132 | + save_alpha = save_alpha.convert('L') |
| 133 | + print('fake_alpha.shape', np.shape(save_alpha)) |
| 134 | + box = (0, 0, w_r, h_r) |
| 135 | + save_alpha = save_alpha.crop(box) |
| 136 | + |
| 137 | + if not os.path.exists(trimap_floder): |
| 138 | + os.mkdir(trimap_floder) |
| 139 | + print('save_alpha.shape', np.shape(save_alpha)) |
| 140 | + save_alpha.save(trimap_floder + '/' + name) |
| 141 | + return |
| 142 | + |
| 143 | + |
| 144 | +def full_input(): |
| 145 | + with t.no_grad(): |
| 146 | + net_G = NetG().cuda() |
| 147 | + net_G.eval() |
| 148 | + net_G.load_state_dict(t.load(MODEL_DIR, map_location=t.device('cpu'))) |
| 149 | + |
| 150 | + img_root = '/data1/zzl/dataset/matting/alphamatting/input_lowers' |
| 151 | + trimap_root = '/data1/zzl/dataset/matting/alphamatting/trimap_lowres' |
| 152 | + |
| 153 | + img_name = os.listdir(img_root) |
| 154 | + |
| 155 | + for name in tqdm.tqdm(img_name): |
| 156 | + for i in range(1, 4): |
| 157 | + |
| 158 | + trimap_floder = 'Trimap' + str(i) |
| 159 | + |
| 160 | + img = Image.open(os.path.join(img_root, name)) |
| 161 | + img, _1, _2 = padding_img(img) |
| 162 | + img = transform(img) |
| 163 | + |
| 164 | + trimap = Image.open(os.path.join(trimap_root, trimap_floder, name)) |
| 165 | + (h_r, w_r) = np.shape(trimap) |
| 166 | + trimap, _1, _2 = padding_img(trimap) |
| 167 | + (w, h) = np.shape(trimap) |
| 168 | + trimap = np.reshape(trimap, (w, h, 1)) |
| 169 | + trimap = transform(trimap) |
| 170 | + |
| 171 | + input_img = t.cat((img, trimap), dim=0) |
| 172 | + input_img = input_img[None] |
| 173 | + input_img = input_img.cuda() |
| 174 | + |
| 175 | + fake_alpha = net_G(input_img) |
| 176 | + vis.images(fake_alpha.cpu().numpy(), win='fake_alpha') |
| 177 | + vis.images(img.numpy() * 0.5 + 0.5, win='input') |
| 178 | + #print(fake_alpha[0].size()) |
| 179 | + save_alpha = to_pil(fake_alpha.cpu()[0]) |
| 180 | + save_alpha = save_alpha.convert('L') |
| 181 | + |
| 182 | + box = (0, 0, w_r, h_r) |
| 183 | + save_alpha = save_alpha.crop(box) |
| 184 | + |
| 185 | + if not os.path.exists(trimap_floder): |
| 186 | + os.mkdir(trimap_floder) |
| 187 | + print(np.shape(save_alpha)) |
| 188 | + save_alpha.save(trimap_floder + '/' + name) |
| 189 | + |
| 190 | + |
| 191 | +def resize_input(): |
| 192 | + with t.no_grad(): |
| 193 | + net_G = NetG().cuda() |
| 194 | + net_G.eval() |
| 195 | + net_G.load_state_dict(t.load(MODEL_DIR, map_location=t.device('cpu'))) |
| 196 | + |
| 197 | + img_root = '/data1/zzl/dataset/matting/alphamatting/input_lowers' |
| 198 | + trimap_root = '/data1/zzl/dataset/matting/alphamatting/trimap_lowres' |
| 199 | + |
| 200 | + img_name = os.listdir(img_root) |
| 201 | + |
| 202 | + for name in tqdm.tqdm(img_name): |
| 203 | + for i in range(1, 4): |
| 204 | + |
| 205 | + trimap_floder = 'Trimap' + str(i) |
| 206 | + |
| 207 | + img = Image.open(os.path.join(img_root, name)) |
| 208 | + (w, h) = img.size |
| 209 | + w_large = w//320 + 1 |
| 210 | + h_large = h//320 + 1 |
| 211 | + |
| 212 | + img = img.resize((w_large * 320, h_large * 320)) |
| 213 | + |
| 214 | + img = transform(img) |
| 215 | + |
| 216 | + trimap = Image.open(os.path.join(trimap_root, trimap_floder, name)) |
| 217 | + trimap = trimap.resize((w_large * 320, h_large * 320)) |
| 218 | + (w, h) = np.shape(trimap) |
| 219 | + trimap = np.reshape(trimap, (w, h, 1)) |
| 220 | + trimap = transform(trimap) |
| 221 | + |
| 222 | + input_img = t.cat((img, trimap), dim=0) |
| 223 | + input_img = input_img[None] |
| 224 | + input_img = input_img.cuda() |
| 225 | + |
| 226 | + fake_alpha = net_G(input_img) |
| 227 | + vis.images(fake_alpha.cpu().numpy(), win='fake_alpha') |
| 228 | + vis.images(img.numpy() * 0.5 + 0.5, win='input') |
| 229 | + # print(fake_alpha[0].size()) |
| 230 | + save_alpha = to_pil(fake_alpha.cpu()[0]) |
| 231 | + save_alpha = save_alpha.convert('L') |
| 232 | + box = (0, 0, w, h) |
| 233 | + save_alpha = save_alpha.crop(box) |
| 234 | + if not os.path.exists(trimap_floder): |
| 235 | + os.mkdir(trimap_floder) |
| 236 | + print(np.shape(save_alpha)) |
| 237 | + save_alpha.save(trimap_floder + '/' + name) |
| 238 | + |
| 239 | + |
| 240 | +full_input() |
| 241 | + |
0 commit comments