2 changes: 2 additions & 0 deletions .gitignore
@@ -6,3 +6,5 @@ checkpoint
 *.npy
 model.ckpt-*
 
+snapshots/*
+*.sh
14 changes: 11 additions & 3 deletions network.py
@@ -73,10 +73,12 @@ def create_session(self):
         self.sess.run([global_init, local_init])
 
     def restore(self, data_path, var_list=None):
+        if var_list is None:
+            var_list = tf.global_variables()
         if data_path.endswith('.npy'):
-            self.load_npy(data_path, self.sess)
+            self.load_npy(data_path, self.sess, var_list=var_list)
         else:
-            loader = tf.train.Saver(var_list=tf.global_variables())
+            loader = tf.train.Saver(var_list=var_list)
             loader.restore(self.sess, data_path)
 
         print('Restore from {}'.format(data_path))
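
With var_list threaded through restore(), a checkpoint can now be loaded partially. A minimal usage sketch, assuming the classifier sits under a scope such as 'conv6_cls' (the scope name and the net object are illustrative, not part of this diff):

    import tensorflow as tf

    def build_restore_list(skip_scope='conv6_cls'):
        # Keep everything except the classification head, e.g. when
        # fine-tuning on a dataset with a different number of classes.
        return [v for v in tf.global_variables() if skip_scope not in v.name]

    # net.restore('./snapshots/model.ckpt-0', var_list=build_restore_list())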
@@ -92,13 +94,16 @@ def save(self, saver, save_dir, step):
         print('The checkpoint has been created, step: {}'.format(step))
 
     ## Restore from .npy
-    def load_npy(self, data_path, session, ignore_missing=False):
+    def load_npy(self, data_path, session, ignore_missing=False, var_list=None):
         '''Load network weights.
         data_path: The path to the numpy-serialized network weights
         session: The current TensorFlow session
         ignore_missing: If true, serialized weights for missing layers are ignored.
         '''
+        if var_list is None:
+            var_list = tf.global_variables()
         data_dict = np.load(data_path, encoding='latin1').item()
+        var_names = [v.name for v in var_list]
         for op_name in data_dict:
             with tf.variable_scope(op_name, reuse=True):
                 for param_name, data in data_dict[op_name].items():
@@ -107,6 +112,9 @@ def load_npy(self, data_path, session, ignore_missing=False):
                             param_name = BN_param_map[param_name]
 
                         var = tf.get_variable(param_name)
+                        if var.name not in var_names:
+                            print("Not restored: %s" % var.name)
+                            continue
                         session.run(var.assign(data))
                     except ValueError:
                         if not ignore_missing:
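
For reference, load_npy expects the .npy file to contain a pickled dict of layer scopes, each mapping parameter names to weight arrays. A minimal sketch of that layout (layer and parameter names are illustrative):

    import numpy as np

    weights = {
        'conv1': {
            'weights': np.zeros((3, 3, 3, 64), dtype=np.float32),
            'biases': np.zeros((64,), dtype=np.float32),
        },
    }
    np.save('weights.npy', weights)

    # Mirrors the load above; recent NumPy versions also require allow_pickle=True.
    data_dict = np.load('weights.npy', encoding='latin1', allow_pickle=True).item()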
15 changes: 12 additions & 3 deletions utils/image_reader.py
@@ -100,6 +100,7 @@ def _random_crop_and_pad_image_and_labels(image, label, crop_h, crop_w, ignore_label
     combined_crop = tf.random_crop(combined_pad, [crop_h, crop_w, 4])
     img_crop = combined_crop[:, :, :last_image_dim]
     label_crop = combined_crop[:, :, last_image_dim:]
+    label_crop = label_crop + ignore_label
     label_crop = tf.cast(label_crop, dtype=tf.uint8)
 
     # Set static shape so that tensorflow knows shape at compile time.
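
The added line undoes a shift applied earlier in this function (outside the shown hunk): padding fills new pixels with 0, so the label map is shifted down by ignore_label before padding and shifted back after cropping, which turns every padded pixel into ignore_label. A toy sketch of the idea (shapes and values are illustrative):

    import tensorflow as tf

    ignore_label = 255
    label = tf.constant([[1.0, 2.0]])[..., tf.newaxis]          # 1 x 2 x 1 toy label map
    shifted = label - ignore_label                              # real pixels become label - 255
    padded = tf.image.pad_to_bounding_box(shifted, 0, 0, 2, 2)  # new pixels are filled with 0
    restored = padded + ignore_label                            # padded pixels now equal 255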
@@ -133,10 +134,18 @@ def _infer_preprocess(img, swap_channel=False):
 
     return img, o_shape, n_shape
 
-def _eval_preprocess(img, label, shape, dataset):
-    if dataset == 'cityscapes':
+def _eval_preprocess(img, label, shape, dataset, ignore_label=255):
+    if 'cityscapes' in dataset:
         img = tf.image.pad_to_bounding_box(img, 0, 0, shape[0], shape[1])
         img.set_shape([shape[0], shape[1], 3])
+
+        label = tf.cast(label, dtype=tf.float32)
+        label = label - ignore_label  # pad_to_bounding_box pads with zeros, so shift first and add the offset back afterwards.
+        label = tf.image.pad_to_bounding_box(label, 0, 0, shape[0], shape[1])
+        label = label + ignore_label
+        label = tf.cast(label, dtype=tf.uint8)
+        label.set_shape([shape[0], shape[1], 1])
+
     else:
         img = tf.image.resize_images(img, shape, align_corners=True)
 
@@ -178,7 +187,7 @@ def create_tf_dataset(self, cfg):
 
         else:  # Evaluation phase
             dataset = dataset.map(lambda x, y:
-                _eval_preprocess(x, y, cfg.param['eval_size'], cfg.dataset),
+                _eval_preprocess(x, y, cfg.param['eval_size'], cfg.dataset, cfg.param['ignore_label']),
                 num_parallel_calls=cfg.N_WORKERS)
             dataset = dataset.batch(1)