
Commit ccce523

Add files via upload
Update README.md
Add files via upload
Delete epoch132_gt.png
Delete epoch132_img.png
Delete epoch132_pred.png
Add files via upload
Update README.md
Update README.md
Update README.md
Update README.md
Update README.md
1 parent 0915b7e commit ccce523

33 files changed: +712 additions, -0 deletions
Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
import random
import numpy as np
import torch
from torch.autograd import Variable


class ImagePool():
    """
    This class implements an image buffer that stores previously generated images.
    The buffer lets us update discriminators using a history of generated images
    rather than only the ones produced by the latest generators.
    """

    def __init__(self, pool_size):
        # pool_size is the size of the image buffer; if pool_size = 0, no buffer is created
        self.pool_size = pool_size
        if self.pool_size > 0:
            self.num_imgs = 0
            self.images = []

    def query(self, images):
        if self.pool_size == 0:
            return Variable(images)
        return_images = []
        for image in images:
            image = torch.unsqueeze(image, 0)
            if self.num_imgs < self.pool_size:  # if the buffer is not full, keep inserting images into it
                self.num_imgs = self.num_imgs + 1
                self.images.append(image)
                return_images.append(image)
            else:
                p = random.uniform(0, 1)
                if p > 0.5:  # 50% chance: return a previously stored image and insert the current image into the buffer
                    random_id = random.randint(0, self.pool_size - 1)
                    tmp = self.images[random_id].clone()
                    self.images[random_id] = image
                    return_images.append(tmp)
                else:  # other 50% chance: return the current image
                    return_images.append(image)
        return_images = Variable(torch.cat(return_images, 0))  # collect all images and return
        return return_images
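
A minimal usage sketch of the pool (the generator call, its output shape, and the pool size of 50 are assumptions for illustration, not part of this file):

# Hedged sketch: wiring the buffer into a discriminator update
pool = ImagePool(pool_size=50)
fake = netG(real_input).detach()   # assumed generator output, shape (N, C, H, W)
fake_for_D = pool.query(fake)      # mix of current and previously generated images
# feed fake_for_D (instead of the raw batch) to the discriminator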
Lines changed: 233 additions & 0 deletions
@@ -0,0 +1,233 @@
import torch
import torch.nn as nn
from torch.autograd import Variable
from torchvision import models

# get_norm_layer, init_weights and NLayerDiscriminator are referenced below but
# defined elsewhere in the repository (not shown in this diff)


def define_G(which_model_netG, norm='batch', init_type='normal', gpu_ids=[], pretrain=True):

    netG = None
    use_gpu = len(gpu_ids) > 0
    norm_layer = get_norm_layer(norm_type=norm)

    if use_gpu:
        assert(torch.cuda.is_available())

    netG = ResNetX(gpu_ids=gpu_ids, pretrain=pretrain)

    if len(gpu_ids) > 0:
        netG.cuda(gpu_ids[0])

    if pretrain:
        print('Using pretrained weights')
    else:
        print('Not using pretrained weights')

    init_weights(netG, init_type=init_type)

    return netG


def define_D(which_model_netD, norm='batch', use_sigmoid=False, init_type='normal', gpu_ids=[]):

    netD = None
    use_gpu = len(gpu_ids) > 0
    norm_layer = get_norm_layer(norm_type=norm)

    if use_gpu:
        assert(torch.cuda.is_available())

    netD = NLayerDiscriminator(4, 64, n_layers=3, norm_layer=norm_layer, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)

    if use_gpu:
        netD.cuda(gpu_ids[0])

    init_weights(netD, init_type=init_type)
    return netD

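
A hedged sketch of how these factory functions are typically called (the argument values are illustrative; get_norm_layer, init_weights and NLayerDiscriminator are assumed available as noted above):

# Hedged sketch: building the generator and the 4-channel discriminator on GPU 0
netG = define_G('resnetX', norm='batch', init_type='normal', gpu_ids=[0], pretrain=True)
netD = define_D('basic', norm='batch', use_sigmoid=True, init_type='normal', gpu_ids=[0])
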
class GANLoss(nn.Module):

    def __init__(self, target_real_label=1.0, target_fake_label=0.0, tensor=torch.FloatTensor):

        super(GANLoss, self).__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        self.fake_label_var = None
        self.real_label_var = None

        self.Tensor = tensor

        self.loss = nn.BCELoss()

    def get_target_tensor(self, input, target_is_real):

        target_tensor = None
        if target_is_real:
            # recreate the label tensor only when its size no longer matches the input
            create_label = ((self.real_label_var is None) or self.real_label_var.numel() != input.numel())

            if create_label:
                real_tensor = self.Tensor(input.size()).fill_(self.real_label)
                self.real_label_var = Variable(real_tensor, requires_grad=False)
            target_tensor = self.real_label_var

        else:

            create_label = ((self.fake_label_var is None) or self.fake_label_var.numel() != input.numel())

            if create_label:
                fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)
                self.fake_label_var = Variable(fake_tensor, requires_grad=False)
            target_tensor = self.fake_label_var

        return target_tensor

    def __call__(self, input, target_is_real):
        # completes the class so it can be invoked as a loss (standard CycleGAN pattern):
        # BCE against the real/fake label tensor of matching size
        target_tensor = self.get_target_tensor(input, target_is_real)
        return self.loss(input, target_tensor)

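
A hedged sketch of this loss in a discriminator update (netD and the input pairs are assumptions; since nn.BCELoss is used, the discriminator is assumed to emit values in (0, 1)):

# Hedged sketch: real/fake targets are built automatically by get_target_tensor
criterion_gan = GANLoss(tensor=torch.FloatTensor)
loss_real = criterion_gan(netD(real_pair), True)
loss_fake = criterion_gan(netD(fake_pair.detach()), False)
loss_D = 0.5 * (loss_real + loss_fake)
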
class AlphaPredictionLoss(nn.Module):

    def __init__(self):

        super(AlphaPredictionLoss, self).__init__()

    def forward(self, input, target, trimap):

        # trimap weights: 1 in the unknown region (trimap in [0.4, 0.6]), 0 in the known regions
        trimap_weights = torch.where((trimap >= 0.4) & (trimap <= 0.6), torch.ones_like(trimap), torch.zeros_like(trimap))
        unknown_region_size = trimap_weights.sum()
        diff = torch.sqrt(torch.add(torch.pow(input - target, 2), 1e-12))
        return torch.mul(diff, trimap_weights).sum() / unknown_region_size


class CompLoss(nn.Module):

    def __init__(self):

        super(CompLoss, self).__init__()

    def forward(self, input, target, trimap, fg, bg):

        # trimap weights: 1 in the unknown region, 0 in the known regions
        trimap_weights = torch.where((trimap >= 0.4) & (trimap <= 0.6), torch.ones_like(trimap), torch.zeros_like(trimap))
        unknown_region_size = trimap_weights.sum()

        # composite the predicted and ground-truth alphas over the same foreground/background
        comp_target = torch.mul(target, fg) + torch.mul((1.0 - target), bg)
        comp_input = torch.mul(input, fg) + torch.mul((1.0 - input), bg)

        diff = torch.sqrt(torch.add(torch.pow(comp_input - comp_target, 2), 1e-12))
        return torch.mul(diff, trimap_weights).sum() / unknown_region_size

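
The two terms are typically combined into a single matting objective; a hedged sketch (the equal 0.5/0.5 weighting follows the Deep Image Matting paper and is an assumption here, as are the tensor names):

# Hedged sketch: combined alpha-prediction + compositional loss
alpha_criterion = AlphaPredictionLoss()
comp_criterion = CompLoss()
pred_alpha = netG(torch.cat([image, trimap], dim=1))  # assumed 4-channel input
loss_G = 0.5 * alpha_criterion(pred_alpha, gt_alpha, trimap) \
       + 0.5 * comp_criterion(pred_alpha, gt_alpha, trimap, fg, bg)
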
class ResNetX(nn.Module):

    def __init__(self, gpu_ids, pretrain):

        super(ResNetX, self).__init__()
        self.encoder = ResnetXEncoder(pretrain)
        self.decoder = ResnetXDecoder(gpu_ids)

    def forward(self, input):

        # the encoder also gives us the saved pooling indices
        x, ind = self.encoder(input)
        x = self.decoder(x, ind)
        return x


class ResnetXEncoder(nn.Module):
    # The encoder has the same structure as ResNet50, but with the last 2 layers removed
    def __init__(self, pretrain):
        super(ResnetXEncoder, self).__init__()

        resnet = models.resnet50(pretrained=pretrain)

        # remove the last 2 layers (average pooling and the fully connected classifier)
        modules = list(resnet.children())[:-2]

        # save the pooling indices so the decoder can unpool later
        for m in modules:
            if 'MaxPool' in m.__class__.__name__:
                m.return_indices = True

        # Change the input shape of the first convolutional layer.
        # ResNet takes 3 channels, but this task needs 4, since the trimap is added as an extra channel.
        conv1 = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3, bias=False)
        weights = torch.zeros(64, 4, 7, 7)
        weights[:, :3, :, :] = modules[0].weight.data.view(64, 3, 7, 7)
        conv1.weight.data.copy_(weights)
        modules[0] = conv1

        self.pool1 = nn.Sequential(*modules[:4])
        self.resnet = nn.Sequential(*modules[4:])

    def forward(self, input):

        # pool1 ends in the MaxPool layer, which returns (output, indices)
        x, ind = self.pool1(input)
        x = self.resnet(x)

        return x, ind

class ResnetXDecoder(nn.Module):
    def __init__(self, gpu_ids=[]):
        super(ResnetXDecoder, self).__init__()
        model = [nn.Conv2d(2048, 2048, kernel_size=1, padding=0),
                 nn.BatchNorm2d(2048),
                 nn.ReLU(True),
                 nn.ConvTranspose2d(2048, 1024, kernel_size=1, stride=2, output_padding=1, bias=False),
                 # nn.ConvTranspose2d(2048, 1024, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False),
                 nn.BatchNorm2d(1024),
                 nn.ReLU(True)]
        model += [nn.Conv2d(1024, 1024, kernel_size=5, padding=2),
                  nn.BatchNorm2d(1024),
                  nn.ReLU(True),
                  nn.ConvTranspose2d(1024, 512, kernel_size=1, stride=2, output_padding=1, bias=False),
                  # nn.ConvTranspose2d(1024, 512, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False),
                  nn.BatchNorm2d(512),
                  nn.ReLU(True)]
        model += [nn.Conv2d(512, 512, kernel_size=5, padding=2),
                  nn.BatchNorm2d(512),
                  nn.ReLU(True),
                  nn.ConvTranspose2d(512, 256, kernel_size=1, stride=2, output_padding=1, bias=False),
                  # nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False),
                  nn.BatchNorm2d(256),
                  nn.ReLU(True)]
        model += [nn.Conv2d(256, 256, kernel_size=5, padding=2),
                  nn.BatchNorm2d(256),
                  nn.ReLU(True),
                  nn.Conv2d(256, 64, kernel_size=1, stride=1, bias=False),
                  nn.BatchNorm2d(64),
                  nn.ReLU(True)]
        model += [nn.Conv2d(64, 64, kernel_size=5, padding=2),
                  nn.BatchNorm2d(64),
                  nn.ReLU(True)]
        self.model1 = nn.Sequential(*model)
        # undo the encoder's max pooling using the saved indices
        self.unpool = nn.MaxUnpool2d(kernel_size=2, stride=2)
        model = [nn.Conv2d(64, 64, kernel_size=5, padding=2),
                 nn.BatchNorm2d(64),
                 nn.ReLU(True),
                 nn.ConvTranspose2d(64, 64, kernel_size=7, stride=2, padding=3, output_padding=1, bias=False),
                 nn.BatchNorm2d(64),
                 nn.ReLU(True)]
        model += [nn.Conv2d(64, 64, kernel_size=5, padding=2),
                  nn.BatchNorm2d(64),
                  nn.ReLU(True),
                  nn.Conv2d(64, 1, kernel_size=5, padding=2),  # single-channel alpha matte
                  nn.Sigmoid()]
        self.model2 = nn.Sequential(*model)

        init_weights(self.model1, 'xavier')
        init_weights(self.model2, 'xavier')

    def forward(self, input, ind):
        x = self.model1(input)
        x = self.unpool(x, ind)
        x = self.model2(x)

        return x
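
A hedged end-to-end sketch of the generator's forward pass (assumes init_weights is importable from elsewhere in the repository; the 256x256 resolution is illustrative and yields a matte of the same size, since the decoder mirrors the encoder's downsampling):

# Hedged sketch: one forward pass through the encoder-decoder
net = ResNetX(gpu_ids=[], pretrain=False)
image = torch.randn(1, 3, 256, 256)   # RGB image
trimap = torch.rand(1, 1, 256, 256)   # trimap concatenated as a 4th channel
alpha = net(torch.cat([image, trimap], dim=1))  # (1, 1, 256, 256) alpha matte in (0, 1)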
Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
import argparse
import os
from util import util
import torch


class BaseOptions():
    def __init__(self):
        self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        self.initialized = False

    def initialize(self):
        self.parser.add_argument('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')
        self.parser.add_argument('--batchSize', type=int, default=1, help='input batch size')
        self.parser.add_argument('--loadSize', type=int, default=286, help='scale images to this size')
        self.parser.add_argument('--fineSize', type=int, default=256, help='then crop to this size')

        self.parser.add_argument('--which_model_netD', type=str, default='basic', help='selects model to use for netD')
        self.parser.add_argument('--which_model_netG', type=str, default='resnet_9blocks', help='selects model to use for netG')
        self.parser.add_argument('--n_layers_D', type=int, default=3, help='only used if which_model_netD==n_layers')
        self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0  0,1,2  0,2. use -1 for CPU')
        self.parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')
        self.parser.add_argument('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single]')
        self.parser.add_argument('--model', type=str, default='cycle_gan', help='chooses which model to use. cycle_gan, pix2pix, test')
        self.parser.add_argument('--which_direction', type=str, default='AtoB', help='AtoB or BtoA')
        self.parser.add_argument('--nThreads', default=2, type=int, help='# threads for loading data')
        self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
        self.parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization')
        self.parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal|xavier|kaiming|orthogonal]')

        self.initialized = True

    def parse(self):
        if not self.initialized:
            self.initialize()
        self.opt = self.parser.parse_args()
        self.opt.isTrain = self.isTrain  # train or test

        str_ids = self.opt.gpu_ids.split(',')
        self.opt.gpu_ids = []
        for str_id in str_ids:
            id = int(str_id)
            if id >= 0:
                self.opt.gpu_ids.append(id)

        # set gpu ids
        if len(self.opt.gpu_ids) > 0:
            torch.cuda.set_device(self.opt.gpu_ids[0])

        args = vars(self.opt)

        print('------------ Options -------------')
        for k, v in sorted(args.items()):
            print('%s: %s' % (str(k), str(v)))
        print('-------------- End ----------------')

        # save to the disk
        expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name)
        util.mkdirs(expr_dir)
        file_name = os.path.join(expr_dir, 'opt.txt')
        with open(file_name, 'wt') as opt_file:
            opt_file.write('------------ Options -------------\n')
            for k, v in sorted(args.items()):
                opt_file.write('%s: %s\n' % (str(k), str(v)))
            opt_file.write('-------------- End ----------------\n')
        return self.opt
Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
from .base_options import BaseOptions


class TrainOptions(BaseOptions):
    def initialize(self):
        BaseOptions.initialize(self)

        self.parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
        self.parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
        self.parser.add_argument('--save_epoch_freq', type=int, default=50, help='frequency of saving checkpoints at the end of epochs')
        self.parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
        self.parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
        self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
        self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
        self.parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate')
        self.parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero')

        self.parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')

        self.parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images')

        self.parser.add_argument('--lr_policy', type=str, default='lambda', help='learning rate policy: lambda|step|plateau')
        self.parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')

        self.parser.add_argument('--no_pretrain', action='store_true', help='do not initialize the encoder with pretrained imagenet weights')

        self.isTrain = True  # required by BaseOptions.parse()
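
A hedged sketch of the options pipeline end to end (the script name and flag values are illustrative):

# Hedged sketch: e.g. python train.py --dataroot ./datasets/matting --name matting_gan
opt = TrainOptions().parse()   # prints all options and saves them to checkpoints/<name>/opt.txt
pool = ImagePool(opt.pool_size)
netG = define_G(opt.which_model_netG, norm=opt.norm, init_type=opt.init_type,
                gpu_ids=opt.gpu_ids, pretrain=not opt.no_pretrain)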