import functools

import torch
import torch.nn as nn
import torchvision.models as models
from torch.autograd import Variable
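

# NOTE (assumption): get_norm_layer, init_weights and NLayerDiscriminator are
# used below but are not part of this diff. The sketches in this file are
# modeled on the pix2pix-style networks.py helpers this code appears to adapt;
# the repository may define them differently elsewhere.
def get_norm_layer(norm_type='batch'):
    # Map a norm name onto a normalization-layer constructor
    if norm_type == 'batch':
        norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
    elif norm_type == 'instance':
        norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
    else:
        raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
    return norm_layer


def init_weights(net, init_type='normal', gain=0.02):
    # Initialize Conv/Linear/BatchNorm parameters in-place
    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and ('Conv' in classname or 'Linear' in classname):
            if init_type == 'normal':
                nn.init.normal_(m.weight.data, 0.0, gain)
            elif init_type == 'xavier':
                nn.init.xavier_normal_(m.weight.data, gain=gain)
            elif init_type == 'kaiming':
                nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                nn.init.constant_(m.bias.data, 0.0)
        elif 'BatchNorm2d' in classname:
            nn.init.normal_(m.weight.data, 1.0, gain)
            nn.init.constant_(m.bias.data, 0.0)

    net.apply(init_func)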


def define_G(which_model_netG, norm='batch', init_type='normal', gpu_ids=[], pretrain=True):

    netG = None
    use_gpu = len(gpu_ids) > 0
    norm_layer = get_norm_layer(norm_type=norm)

    if use_gpu:
        assert(torch.cuda.is_available())

    # Encoder-decoder generator built on a ResNet-50 backbone
    netG = ResNetX(gpu_ids=gpu_ids, pretrain=pretrain)

    if use_gpu:
        netG.cuda(gpu_ids[0])

    if pretrain:
        print('Using pretrained weights')
    else:
        print('Not using pretrained weights')

    init_weights(netG, init_type=init_type)

    return netG


def define_D(which_model_netD, norm='batch', use_sigmoid=False, init_type='normal', gpu_ids=[]):

    netD = None
    use_gpu = len(gpu_ids) > 0
    norm_layer = get_norm_layer(norm_type=norm)

    if use_gpu:
        assert(torch.cuda.is_available())

    netD = NLayerDiscriminator(4, 64, n_layers=3, norm_layer=norm_layer, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)

    if use_gpu:
        netD.cuda(gpu_ids[0])

    init_weights(netD, init_type=init_type)
    return netD
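

# NOTE (assumption): a PatchGAN discriminator in the style of pix2pix's
# NLayerDiscriminator, sketched here because define_D above depends on it;
# the repository's own definition may differ.
class NLayerDiscriminator(nn.Module):

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d,
                 use_sigmoid=False, gpu_ids=[]):
        super(NLayerDiscriminator, self).__init__()
        self.gpu_ids = gpu_ids

        kw, padw = 4, 1
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
                    nn.LeakyReLU(0.2, True)]

        # Progressively double the channel count while halving the resolution
        nf_mult = 1
        for n in range(1, n_layers):
            nf_mult_prev, nf_mult = nf_mult, min(2 ** n, 8)
            sequence += [nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                                   kernel_size=kw, stride=2, padding=padw),
                         norm_layer(ndf * nf_mult),
                         nn.LeakyReLU(0.2, True)]

        nf_mult_prev, nf_mult = nf_mult, min(2 ** n_layers, 8)
        sequence += [nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                               kernel_size=kw, stride=1, padding=padw),
                     norm_layer(ndf * nf_mult),
                     nn.LeakyReLU(0.2, True),
                     # 1-channel map of per-patch real/fake scores
                     nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]

        if use_sigmoid:
            sequence += [nn.Sigmoid()]

        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        return self.model(input)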


class GANLoss(nn.Module):

    def __init__(self, target_real_label=1.0, target_fake_label=0.0, tensor=torch.FloatTensor):

        super(GANLoss, self).__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        self.real_label_var = None
        self.fake_label_var = None

        self.Tensor = tensor

        self.loss = nn.BCELoss()

    def get_target_tensor(self, input, target_is_real):

        # Cache the label tensor and rebuild it only when the input size changes
        target_tensor = None
        if target_is_real:
            create_label = ((self.real_label_var is None) or
                            (self.real_label_var.numel() != input.numel()))
            if create_label:
                real_tensor = self.Tensor(input.size()).fill_(self.real_label)
                self.real_label_var = Variable(real_tensor, requires_grad=False)
            target_tensor = self.real_label_var
        else:
            create_label = ((self.fake_label_var is None) or
                            (self.fake_label_var.numel() != input.numel()))
            if create_label:
                fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)
                self.fake_label_var = Variable(fake_tensor, requires_grad=False)
            target_tensor = self.fake_label_var

        return target_tensor
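
    # NOTE (assumption): the diff defines self.loss (BCELoss) but never applies
    # it; pix2pix's GANLoss closes the loop with a __call__ like this:
    def __call__(self, input, target_is_real):
        target_tensor = self.get_target_tensor(input, target_is_real)
        return self.loss(input, target_tensor)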


class AlphaPredictionLoss(nn.Module):

    def __init__(self):

        super(AlphaPredictionLoss, self).__init__()

    def forward(self, input, target, trimap):

        # Trimap mask: 1 in the unknown region, 0 in the known regions
        trimap_weights = torch.where((trimap >= 0.4) & (trimap <= 0.6),
                                     torch.ones_like(trimap), torch.zeros_like(trimap))
        unknown_region_size = trimap_weights.sum()
        # Charbonnier-style absolute difference; the epsilon keeps sqrt differentiable at 0
        diff = torch.sqrt(torch.add(torch.pow(input - target, 2), 1e-12))
        return torch.mul(diff, trimap_weights).sum() / unknown_region_size


class CompLoss(nn.Module):

    def __init__(self):

        super(CompLoss, self).__init__()

    def forward(self, input, target, trimap, fg, bg):

        # Trimap mask: 1 in the unknown region, 0 in the known regions
        trimap_weights = torch.where((trimap >= 0.4) & (trimap <= 0.6),
                                     torch.ones_like(trimap), torch.zeros_like(trimap))
        unknown_region_size = trimap_weights.sum()

        # Composite the image from the ground-truth / predicted alpha mattes
        comp_target = torch.mul(target, fg) + torch.mul((1.0 - target), bg)
        comp_input = torch.mul(input, fg) + torch.mul((1.0 - input), bg)

        diff = torch.sqrt(torch.add(torch.pow(comp_input - comp_target, 2), 1e-12))
        return torch.mul(diff, trimap_weights).sum() / unknown_region_size
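

# NOTE (assumption): Deep Image Matting weights the alpha-prediction and
# compositional losses equally; a wrapper such as this (the name and weights
# are illustrative, not part of the diff) is one plausible way the two
# modules above are combined:
class OverallLoss(nn.Module):

    def __init__(self, w_alpha=0.5, w_comp=0.5):
        super(OverallLoss, self).__init__()
        self.w_alpha = w_alpha
        self.w_comp = w_comp
        self.alpha_loss = AlphaPredictionLoss()
        self.comp_loss = CompLoss()

    def forward(self, input, target, trimap, fg, bg):
        return (self.w_alpha * self.alpha_loss(input, target, trimap) +
                self.w_comp * self.comp_loss(input, target, trimap, fg, bg))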


class ResNetX(nn.Module):

    def __init__(self, gpu_ids, pretrain):

        super(ResNetX, self).__init__()
        self.encoder = ResnetXEncoder(pretrain)
        self.decoder = ResnetXDecoder(gpu_ids)

    def forward(self, input):

        # The encoder also returns the saved max-pooling indices for unpooling
        x, ind = self.encoder(input)
        x = self.decoder(x, ind)
        return x


class ResnetXEncoder(nn.Module):
    # The encoder has the same structure as ResNet-50, with the last 2 layers
    # (average pooling and the fully connected classifier) removed
    def __init__(self, pretrain):
        super(ResnetXEncoder, self).__init__()

        resnet = models.resnet50(pretrained=pretrain)

        # Remove the last 2 layers
        modules = list(resnet.children())[:-2]

        # Make the max-pooling layer return its indices so the decoder can unpool
        for m in modules:
            if 'MaxPool' in m.__class__.__name__:
                m.return_indices = True

        # Change the input shape of the first convolutional layer:
        # ResNet expects 3 channels, but this task needs 4, since the trimap
        # is concatenated to the RGB image; the extra channel starts at zero
        conv1 = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3, bias=False)
        weights = torch.zeros(64, 4, 7, 7)
        weights[:, :3, :, :] = modules[0].weight.data.view(64, 3, 7, 7)
        conv1.weight.data.copy_(weights)
        modules[0] = conv1

        # pool1 ends with the max-pooling layer, so it returns (x, indices)
        self.pool1 = nn.Sequential(*modules[:4])
        self.resnet = nn.Sequential(*modules[4:])

    def forward(self, input):

        x, ind = self.pool1(input)
        x = self.resnet(x)

        return x, ind


class ResnetXDecoder(nn.Module):
    def __init__(self, gpu_ids=[]):
        super(ResnetXDecoder, self).__init__()
        # Stage 1: upsample the 2048-channel encoder features to 64 channels,
        # doubling the spatial resolution at each transposed convolution
        model = [nn.Conv2d(2048, 2048, kernel_size=1, padding=0),
                 nn.BatchNorm2d(2048),
                 nn.ReLU(True),
                 nn.ConvTranspose2d(2048, 1024, kernel_size=1, stride=2, output_padding=1, bias=False),
                 # nn.ConvTranspose2d(2048, 1024, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False),
                 nn.BatchNorm2d(1024),
                 nn.ReLU(True)]
        model += [nn.Conv2d(1024, 1024, kernel_size=5, padding=2),
                  nn.BatchNorm2d(1024),
                  nn.ReLU(True),
                  nn.ConvTranspose2d(1024, 512, kernel_size=1, stride=2, output_padding=1, bias=False),
                  # nn.ConvTranspose2d(1024, 512, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False),
                  nn.BatchNorm2d(512),
                  nn.ReLU(True)]
        model += [nn.Conv2d(512, 512, kernel_size=5, padding=2),
                  nn.BatchNorm2d(512),
                  nn.ReLU(True),
                  nn.ConvTranspose2d(512, 256, kernel_size=1, stride=2, output_padding=1, bias=False),
                  # nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False),
                  nn.BatchNorm2d(256),
                  nn.ReLU(True)]
        model += [nn.Conv2d(256, 256, kernel_size=5, padding=2),
                  nn.BatchNorm2d(256),
                  nn.ReLU(True),
                  nn.Conv2d(256, 64, kernel_size=1, stride=1, bias=False),
                  nn.BatchNorm2d(64),
                  nn.ReLU(True)]
        model += [nn.Conv2d(64, 64, kernel_size=5, padding=2),
                  nn.BatchNorm2d(64),
                  nn.ReLU(True)]
        self.model1 = nn.Sequential(*model)
        # Undo the encoder's max-pooling using the saved indices
        self.unpool = nn.MaxUnpool2d(kernel_size=2, stride=2)
        # Stage 2: upsample to the input resolution and predict a 1-channel
        # alpha matte through a sigmoid
        model = [nn.Conv2d(64, 64, kernel_size=5, padding=2),
                 nn.BatchNorm2d(64),
                 nn.ReLU(True),
                 nn.ConvTranspose2d(64, 64, kernel_size=7, stride=2, padding=3, output_padding=1, bias=False),
                 nn.BatchNorm2d(64),
                 nn.ReLU(True)]
        model += [nn.Conv2d(64, 64, kernel_size=5, padding=2),
                  nn.BatchNorm2d(64),
                  nn.ReLU(True),
                  nn.Conv2d(64, 1, kernel_size=5, padding=2),
                  nn.Sigmoid()]
        self.model2 = nn.Sequential(*model)

        init_weights(self.model1, 'xavier')
        init_weights(self.model2, 'xavier')

    def forward(self, input, ind):
        x = self.model1(input)
        x = self.unpool(x, ind)
        x = self.model2(x)
        return x
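

# Minimal smoke test -- a sketch, assuming CPU execution and a 224x224 input;
# the shapes and the 'resnet50' / 'n_layers' model names are illustrative,
# not part of the diff.
if __name__ == '__main__':
    netG = define_G('resnet50', gpu_ids=[], pretrain=False)
    netD = define_D('n_layers', gpu_ids=[])

    image_trimap = torch.rand(1, 4, 224, 224)  # RGB image + trimap channel
    alpha = netG(image_trimap)                 # predicted matte: (1, 1, 224, 224)
    score = netD(torch.cat([image_trimap[:, :3], alpha], dim=1))
    print(alpha.shape, score.shape)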