From 381e1c096c661ab427488d2472c8391634b1d53a Mon Sep 17 00:00:00 2001 From: Peter Reutemann Date: Thu, 1 Sep 2022 15:27:33 +1200 Subject: [PATCH 1/2] model_from_checkpoint_path now uses os.path.join to concatenate checkpoints_path and _config.json to handle missing trailing slash in path --- keras_segmentation/predict.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/keras_segmentation/predict.py b/keras_segmentation/predict.py index 622b82adf..e2e6e9752 100644 --- a/keras_segmentation/predict.py +++ b/keras_segmentation/predict.py @@ -19,10 +19,8 @@ def model_from_checkpoint_path(checkpoints_path): - from .models.all_models import model_from_name - assert (os.path.isfile(checkpoints_path+"_config.json") - ), "Checkpoint not found." + assert (os.path.isfile(os.path.join(checkpoints_path, "_config.json"))), "Checkpoint not found." model_config = json.loads( open(checkpoints_path+"_config.json", "r").read()) latest_weights = find_latest_checkpoint(checkpoints_path) From 885884e618809f0a9d07ed732dab12052572c7aa Mon Sep 17 00:00:00 2001 From: Peter Reutemann Date: Thu, 1 Sep 2022 15:46:46 +1200 Subject: [PATCH 2/2] get_epoch_number_from_path now uses os.path.basename and also strips ".index" from the name (for TensorFlow 2) --- keras_segmentation/train.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/keras_segmentation/train.py b/keras_segmentation/train.py index 27e37f65d..41947c58b 100755 --- a/keras_segmentation/train.py +++ b/keras_segmentation/train.py @@ -6,16 +6,16 @@ import six from keras.callbacks import Callback from keras.callbacks import ModelCheckpoint -import tensorflow as tf import glob import sys + def find_latest_checkpoint(checkpoints_path, fail_safe=True): # This is legacy code, there should always be a "checkpoint" file in your directory def get_epoch_number_from_path(path): - return path.replace(checkpoints_path, "").strip(".") + return os.path.basename(path).replace(".index", "").strip(".") # Get all matching files all_checkpoint_files = glob.glob(checkpoints_path + ".*") @@ -41,6 +41,7 @@ def get_epoch_number_from_path(path): return latest_epoch_checkpoint + def masked_categorical_crossentropy(gt, pr): from keras.losses import categorical_crossentropy mask = 1 - gt[:, :, 0]