From 172a98fcd72527941cf81d6a1efef1e7ff472d2d Mon Sep 17 00:00:00 2001
From: Sangwon Lee
Date: Fri, 7 Dec 2018 16:07:45 +0900
Subject: [PATCH 1/7] fix image_reader

---
 utils/image_reader.py | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

diff --git a/utils/image_reader.py b/utils/image_reader.py
index 3d7a65e..83edda4 100644
--- a/utils/image_reader.py
+++ b/utils/image_reader.py
@@ -100,6 +100,7 @@ def _random_crop_and_pad_image_and_labels(image, label, crop_h, crop_w, ignore_l
     combined_crop = tf.random_crop(combined_pad, [crop_h, crop_w, 4])
     img_crop = combined_crop[:, :, :last_image_dim]
     label_crop = combined_crop[:, :, last_image_dim:]
+    label_crop = label_crop + ignore_label
     label_crop = tf.cast(label_crop, dtype=tf.uint8)

     # Set static shape so that tensorflow knows shape at compile time.
@@ -133,10 +134,18 @@ def _infer_preprocess(img, swap_channel=False):

     return img, o_shape, n_shape

-def _eval_preprocess(img, label, shape, dataset):
-    if dataset == 'cityscapes':
+def _eval_preprocess(img, label, shape, dataset, ignore_label=255):
+    if dataset in ['cityscapes', 'cityscapes-mini']:
         img = tf.image.pad_to_bounding_box(img, 0, 0, shape[0], shape[1])
         img.set_shape([shape[0], shape[1], 3])
+
+        label = tf.cast(label, dtype=tf.float32)
+        label = label - ignore_label # Needs to be subtracted and later added due to 0 padding.
+        label = tf.image.pad_to_bounding_box(label, 0, 0, shape[0], shape[1])
+        label = label + ignore_label
+        label = tf.cast(label, dtype=tf.uint8)
+        label.set_shape([shape[0], shape[1], 1])
+
     else:
         img = tf.image.resize_images(img, shape, align_corners=True)

@@ -178,7 +187,7 @@ def create_tf_dataset(self, cfg):
         else:
             # Evaluation phase
             dataset = dataset.map(lambda x, y:
-                _eval_preprocess(x, y, cfg.param['eval_size'], cfg.dataset),
+                _eval_preprocess(x, y, cfg.param['eval_size'], cfg.dataset, cfg.param['ignore_label']),
                 num_parallel_calls=cfg.N_WORKERS)

             dataset = dataset.batch(1)

From 986f626e7291cae3196604995333f71d576a7327 Mon Sep 17 00:00:00 2001
From: Sangwon Lee
Date: Fri, 7 Dec 2018 16:27:33 +0900
Subject: [PATCH 2/7] rename network.py to networks.py

---
 network.py => networks.py | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 rename network.py => networks.py (100%)

diff --git a/network.py b/networks.py
similarity index 100%
rename from network.py
rename to networks.py

From b4397596431d65349490a60d1d365a4cb674a70f Mon Sep 17 00:00:00 2001
From: Sangwon Lee
Date: Mon, 3 Dec 2018 13:46:31 +0900
Subject: [PATCH 3/7] fix the code to work var_list while restore parameters

---
 networks.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/networks.py b/networks.py
index 1ea5939..716c18d 100644
--- a/networks.py
+++ b/networks.py
@@ -72,11 +72,11 @@ def create_session(self):
         self.sess = tf.Session(config=config)
         self.sess.run([global_init, local_init])

-    def restore(self, data_path, var_list=None):
+    def restore(self, data_path, var_list=tf.global_variables()):
         if data_path.endswith('.npy'):
-            self.load_npy(data_path, self.sess)
+            self.load_npy(data_path, self.sess, var_list=var_list)
         else:
-            loader = tf.train.Saver(var_list=tf.global_variables())
+            loader = tf.train.Saver(var_list=var_list)
             loader.restore(self.sess, data_path)

         print('Restore from {}'.format(data_path))
@@ -92,14 +92,17 @@ def save(self, saver, save_dir, step):
         print('The checkpoint has been created, step: {}'.format(step))

     ## Restore from .npy
-    def load_npy(self, data_path, session, ignore_missing=False):
+    def load_npy(self, data_path, session, ignore_missing=False, var_list=tf.global_variables()):
         '''Load network weights.
         data_path: The path to the numpy-serialized network weights
         session: The current TensorFlow session
         ignore_missing: If true, serialized weights for missing layers are ignored.
         '''
         data_dict = np.load(data_path, encoding='latin1').item()
+        var_names = [v.name for v in var_list]
         for op_name in data_dict:
+            if op_name not in var_names:
+                continue
             with tf.variable_scope(op_name, reuse=True):
                 for param_name, data in data_dict[op_name].items():
                     try:

From 73fb9f3308609f01f49b3eb1bfce419b07ebe24f Mon Sep 17 00:00:00 2001
From: Sangwon Lee
Date: Wed, 5 Dec 2018 18:31:34 +0900
Subject: [PATCH 4/7] fix the code to work var_list while restore parameters

---
 networks.py | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/networks.py b/networks.py
index 716c18d..7deb67a 100644
--- a/networks.py
+++ b/networks.py
@@ -72,7 +72,9 @@ def create_session(self):
         self.sess = tf.Session(config=config)
         self.sess.run([global_init, local_init])

-    def restore(self, data_path, var_list=tf.global_variables()):
+    def restore(self, data_path, var_list=None):
+        if var_list is None:
+            var_list = tf.global_variables()
         if data_path.endswith('.npy'):
             self.load_npy(data_path, self.sess, var_list=var_list)
         else:
@@ -92,17 +94,17 @@ def save(self, saver, save_dir, step):
         print('The checkpoint has been created, step: {}'.format(step))

     ## Restore from .npy
-    def load_npy(self, data_path, session, ignore_missing=False, var_list=tf.global_variables()):
+    def load_npy(self, data_path, session, ignore_missing=False, var_list=None):
         '''Load network weights.
         data_path: The path to the numpy-serialized network weights
         session: The current TensorFlow session
         ignore_missing: If true, serialized weights for missing layers are ignored.
         '''
+        if var_list is None:
+            var_list = tf.global_variables()
         data_dict = np.load(data_path, encoding='latin1').item()
         var_names = [v.name for v in var_list]
         for op_name in data_dict:
-            if op_name not in var_names:
-                continue
             with tf.variable_scope(op_name, reuse=True):
                 for param_name, data in data_dict[op_name].items():
                     try:
@@ -110,6 +112,9 @@ def load_npy(self, data_path, session, ignore_missing=False, var_list=tf.global_
                             param_name = BN_param_map[param_name]

                         var = tf.get_variable(param_name)
+                        if var.name not in var_names:
+                            print("Not restored: %s" % var.name)
+                            continue
                         session.run(var.assign(data))
                     except ValueError:
                         if not ignore_missing:

From f03547a3147ab495cf85dd1690ee6422042b269a Mon Sep 17 00:00:00 2001
From: Sangwon Lee
Date: Fri, 7 Dec 2018 16:30:44 +0900
Subject: [PATCH 5/7] rename networks.py to network.py

---
 networks.py => network.py | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 rename networks.py => network.py (100%)

diff --git a/networks.py b/network.py
similarity index 100%
rename from networks.py
rename to network.py

From 7f861ac235924de3db2025f9437c1c3171d56e00 Mon Sep 17 00:00:00 2001
From: Sangwon Lee
Date: Fri, 7 Dec 2018 18:32:21 +0900
Subject: [PATCH 6/7] update gitignore

---
 .gitignore | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/.gitignore b/.gitignore
index 6333025..d45e338 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,3 +6,5 @@ checkpoint
 *.npy

 model.ckpt-*
+snapshots/*
+*.sh
\ No newline at end of file

From c524097294a906aa764a212b6a94c327dbe7dda7 Mon Sep 17 00:00:00 2001
From: Sangwon Lee
Date: Tue, 11 Dec 2018 23:36:39 +0900
Subject: [PATCH 7/7] Remove code for personal use

---
 utils/image_reader.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/utils/image_reader.py b/utils/image_reader.py
index 83edda4..ec7d623 100644
--- a/utils/image_reader.py
+++ b/utils/image_reader.py
@@ -135,7 +135,7 @@ def _infer_preprocess(img, swap_channel=False):
     return img, o_shape, n_shape

 def _eval_preprocess(img, label, shape, dataset, ignore_label=255):
-    if dataset in ['cityscapes', 'cityscapes-mini']:
+    if 'cityscapes' in dataset:
         img = tf.image.pad_to_bounding_box(img, 0, 0, shape[0], shape[1])
         img.set_shape([shape[0], shape[1], 3])
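
Note on the label padding in PATCH 1/7: tf.image.pad_to_bounding_box always pads with zeros, so the patch shifts the label by -ignore_label before padding and shifts it back afterwards, which makes the padded border come out as ignore_label rather than class 0. A minimal standalone sketch of that trick (the helper name pad_label_with_ignore is hypothetical, not part of the repository):

    import tensorflow as tf  # TensorFlow 1.x, as used by the repository

    def pad_label_with_ignore(label, target_h, target_w, ignore_label=255):
        # pad_to_bounding_box fills the new area with zeros, so shift the
        # label down by ignore_label first and shift it back afterwards;
        # the zero-padded region then ends up equal to ignore_label.
        label = tf.cast(label, tf.float32)
        label = label - ignore_label
        label = tf.image.pad_to_bounding_box(label, 0, 0, target_h, target_w)
        label = label + ignore_label
        return tf.cast(label, tf.uint8)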
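
Note on PATCH 3/7 and PATCH 4/7: the intermediate version uses var_list=tf.global_variables() as a default argument, but Python evaluates default arguments once, when the function is defined, so the default can bind a stale (often empty) variable list; PATCH 4/7 therefore switches the default to None and resolves it inside the body. A minimal sketch of that pattern for the checkpoint path (restore_vars here is a hypothetical free function, not the class method in networks.py):

    import tensorflow as tf  # TensorFlow 1.x, as used by the repository

    def restore_vars(sess, data_path, var_list=None):
        # Resolve the default lazily so the variable list is taken from the
        # graph that exists at call time, not at function-definition time.
        if var_list is None:
            var_list = tf.global_variables()
        loader = tf.train.Saver(var_list=var_list)
        loader.restore(sess, data_path)
        print('Restore from {}'.format(data_path))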