|
| 1 | +import os.path |
| 2 | +import torchvision.transforms as transforms |
| 3 | +from data.base_dataset import BaseDataset, get_transform |
| 4 | +from data.image_folder import make_dataset |
| 5 | +from PIL import Image |
| 6 | +import PIL |
| 7 | +import random |
| 8 | +import scipy.ndimage |
| 9 | +import numpy as np |
| 10 | +import math |
| 11 | +# import pbcvt |
| 12 | +# import colour_transfer |
| 13 | + |
| 14 | +class GeneratedDatasetSimple(BaseDataset): |
| 15 | + def initialize(self, opt): |
| 16 | + self.opt = opt |
| 17 | + self.root = opt.dataroot |
| 18 | + self.dir_AB = os.path.join(opt.dataroot, opt.phase) |
| 19 | + self.dir_alpha = os.path.join(self.dir_AB, 'alpha') |
| 20 | + self.dir_fg = os.path.join(self.dir_AB, 'fg') |
| 21 | + self.dir_bg = os.path.join(self.dir_AB, 'bg') |
| 22 | + self.alpha_paths = sorted(make_dataset(self.dir_alpha)) |
| 23 | + self.fg_paths = sorted(make_dataset(self.dir_fg)) |
| 24 | + self.bg_paths = make_dataset(self.dir_bg) |
| 25 | + self.alpha_size = len(self.alpha_paths) |
| 26 | + self.bg_size = len(self.bg_paths) |
| 27 | + |
| 28 | + |
| 29 | + def __getitem__(self, index): |
| 30 | + index = index % self.alpha_size |
| 31 | + alpha_path = self.alpha_paths[index] |
| 32 | + fg_path = self.fg_paths[index] |
| 33 | + index_bg = random.randint(0, self.bg_size - 1) |
| 34 | + bg_path = self.bg_paths[index_bg] |
| 35 | + |
| 36 | + |
| 37 | + A_bg = Image.open(bg_path).convert('RGB') |
| 38 | + A_fg = Image.open(fg_path).convert('RGB') |
| 39 | + |
| 40 | + A_alpha = Image.open(alpha_path).convert('L') |
| 41 | + assert A_alpha.mode == 'L' |
| 42 | + |
| 43 | + |
| 44 | + A_trimap = self.generate_trimap(A_alpha) |
| 45 | + |
| 46 | + # A_bg = self.resize_bg(A_bg, A_fg) |
| 47 | + w_bg, h_bg = A_bg.size |
| 48 | + if w_bg < 321 or h_bg < 321: |
| 49 | + x = w_bg if w_bg < h_bg else h_bg |
| 50 | + ratio = 321/float(x) |
| 51 | + A_bg = A_bg.resize((int(np.ceil(w_bg*ratio)+1),int(np.ceil(h_bg*ratio)+1)), Image.BICUBIC) |
| 52 | + w_bg, h_bg = A_bg.size |
| 53 | + assert w_bg > 320 and h_bg > 320, '{} {}'.format(w_bg, h_bg) |
| 54 | + x = random.randint(0, w_bg-320-1) |
| 55 | + y = random.randint(0, h_bg-320-1) |
| 56 | + A_bg = A_bg.crop((x,y, x+320, y+320)) |
| 57 | + |
| 58 | + crop_size = random.choice([320,480,640]) |
| 59 | + # crop_size = random.choice([320,400,480,560,640,720]) |
| 60 | + crop_center = self.find_crop_center(A_trimap) |
| 61 | + start_index_height = max(min(A_fg.size[1]-crop_size, crop_center[0] - int(crop_size/2) + 1), 0) |
| 62 | + start_index_width = max(min(A_fg.size[0]-crop_size, crop_center[1] - int(crop_size/2) + 1), 0) |
| 63 | + |
| 64 | + bbox = ((start_index_width,start_index_height,start_index_width+crop_size,start_index_height+crop_size)) |
| 65 | + |
| 66 | + # A_bg = A_bg.crop(bbox) |
| 67 | + A_fg = A_fg.crop(bbox) |
| 68 | + A_alpha = A_alpha.crop(bbox) |
| 69 | + A_trimap = A_trimap.crop(bbox) |
| 70 | + |
| 71 | + if self.opt.which_model_netG == 'unet_256': |
| 72 | + A_bg = A_bg.resize((256,256)) |
| 73 | + A_fg = A_fg.resize((256,256)) |
| 74 | + A_alpha = A_alpha.resize((256,256)) |
| 75 | + A_trimap = A_trimap.resize((256,256)) |
| 76 | + assert A_alpha.mode == 'L' |
| 77 | + else: |
| 78 | + A_bg = A_bg.resize((320,320)) |
| 79 | + A_fg = A_fg.resize((320,320)) |
| 80 | + A_alpha = A_alpha.resize((320,320)) |
| 81 | + A_trimap = A_trimap.resize((320,320)) |
| 82 | + assert A_alpha.mode == 'L' |
| 83 | + |
| 84 | + if random.randint(0, 1): |
| 85 | + A_bg = A_bg.transpose(PIL.Image.FLIP_LEFT_RIGHT) |
| 86 | + |
| 87 | + if random.randint(0, 1): |
| 88 | + A_fg = A_fg.transpose(PIL.Image.FLIP_LEFT_RIGHT) |
| 89 | + A_alpha = A_alpha.transpose(PIL.Image.FLIP_LEFT_RIGHT) |
| 90 | + A_trimap = A_trimap.transpose(PIL.Image.FLIP_LEFT_RIGHT) |
| 91 | + |
| 92 | + ## COLOR TRANSFER ## |
| 93 | + # if random.randint(0, 2) != 0: |
| 94 | + # A_old = A_fg |
| 95 | + # target = np.array(A_fg) |
| 96 | + # palette = np.array(A_palette) |
| 97 | + # recolor = colour_transfer.runCT(target, palette) |
| 98 | + # A_fg = Image.fromarray(recolor) |
| 99 | + |
| 100 | + if self.opt.which_direction == 'BtoA': |
| 101 | + input_nc = self.opt.output_nc |
| 102 | + output_nc = self.opt.input_nc |
| 103 | + else: |
| 104 | + input_nc = self.opt.input_nc |
| 105 | + output_nc = self.opt.output_nc |
| 106 | + |
| 107 | + A_bg = transforms.ToTensor()(A_bg) |
| 108 | + A_fg = transforms.ToTensor()(A_fg) |
| 109 | + A_alpha = transforms.ToTensor()(A_alpha) |
| 110 | + A_trimap = transforms.ToTensor()(A_trimap) |
| 111 | + |
| 112 | + return {'A_bg': A_bg, |
| 113 | + 'A_fg': A_fg, |
| 114 | + 'A_alpha': A_alpha, |
| 115 | + 'A_trimap': A_trimap, |
| 116 | + 'A_paths': alpha_path} |
| 117 | + |
| 118 | + def resize_bg(self, bg, fg): |
| 119 | + bbox = fg.size |
| 120 | + w = bbox[0] |
| 121 | + h = bbox[1] |
| 122 | + bg_bbox = bg.size |
| 123 | + bw = bg_bbox[0] |
| 124 | + bh = bg_bbox[1] |
| 125 | + wratio = w / float(bw) |
| 126 | + hratio = h / float(bh) |
| 127 | + ratio = wratio if wratio > hratio else hratio |
| 128 | + if ratio > 1: |
| 129 | + bg = bg.resize((int(np.ceil(bw*ratio)+1),int(np.ceil(bh*ratio)+1)), Image.BICUBIC) |
| 130 | + bg = bg.crop((0,0,w,h)) |
| 131 | + |
| 132 | + return bg |
| 133 | + |
| 134 | + # def generate_trimap(self, alpha): |
| 135 | + # trimap = np.array(alpha) |
| 136 | + # kernel_sizes = [val for val in range(5,40)] |
| 137 | + # kernel = random.choice(kernel_sizes) |
| 138 | + # trimap[np.where((scipy.ndimage.grey_dilation(alpha,size=(kernel,kernel)) - alpha!=0))] = 128 |
| 139 | + |
| 140 | + # return Image.fromarray(trimap) |
| 141 | + def generate_trimap(self, alpha): |
| 142 | + trimap = np.array(alpha) |
| 143 | + grey = np.zeros_like(trimap) |
| 144 | + kernel_sizes = [val for val in range(2,20)] |
| 145 | + kernel = random.choice(kernel_sizes) |
| 146 | + # trimap[np.where((scipy.ndimage.grey_dilation(alpha,size=(kernel,kernel)) - alpha!=0))] = 128 |
| 147 | + grey = np.where(np.logical_and(trimap>0, trimap<255), 128, 0) |
| 148 | + grey = scipy.ndimage.grey_dilation(grey, size=(kernel,kernel)) |
| 149 | + trimap[grey==128] = 128 |
| 150 | + |
| 151 | + return Image.fromarray(trimap) |
| 152 | + |
| 153 | + def find_crop_center(self, trimap): |
| 154 | + t = np.array(trimap) |
| 155 | + target = np.where(t==128) |
| 156 | + index = random.choice([i for i in range(len(target[0]))]) |
| 157 | + return np.array(target)[:,index][:2] |
| 158 | + |
| 159 | + def rotatedRectWithMaxArea(self, w, h, angle): |
| 160 | + """ |
| 161 | + Given a rectangle of size wxh that has been rotated by 'angle' (in |
| 162 | + radians), computes the width and height of the largest possible |
| 163 | + axis-aligned rectangle (maximal area) within the rotated rectangle. |
| 164 | + """ |
| 165 | + if w <= 0 or h <= 0: |
| 166 | + return 0,0 |
| 167 | + width_is_longer = w >= h |
| 168 | + side_long, side_short = (w,h) if width_is_longer else (h,w) |
| 169 | + |
| 170 | + # since the solutions for angle, -angle and 180-angle are all the same, |
| 171 | + # if suffices to look at the first quadrant and the absolute values of sin,cos: |
| 172 | + sin_a, cos_a = abs(math.sin(angle)), abs(math.cos(angle)) |
| 173 | + if side_short <= 2.*sin_a*cos_a*side_long or abs(sin_a-cos_a) < 1e-10: |
| 174 | + # half constrained case: two crop corners touch the longer side, |
| 175 | + # the other two corners are on the mid-line parallel to the longer line |
| 176 | + x = 0.5*side_short |
| 177 | + wr,hr = (x/sin_a,x/cos_a) if width_is_longer else (x/cos_a,x/sin_a) |
| 178 | + else: |
| 179 | + # fully constrained case: crop touches all 4 sides |
| 180 | + cos_2a = cos_a*cos_a - sin_a*sin_a |
| 181 | + wr,hr = (w*cos_a - h*sin_a)/cos_2a, (h*cos_a - w*sin_a)/cos_2a |
| 182 | + |
| 183 | + return wr,hr |
| 184 | + |
| 185 | + def __len__(self): |
| 186 | + return len(self.alpha_paths) |
| 187 | + |
| 188 | + def name(self): |
| 189 | + return 'GeneratedDataset' |
0 commit comments