diff --git a/keras_segmentation/predict.py b/keras_segmentation/predict.py index 622b82adf..e2e6e9752 100644 --- a/keras_segmentation/predict.py +++ b/keras_segmentation/predict.py @@ -19,10 +19,8 @@ def model_from_checkpoint_path(checkpoints_path): - from .models.all_models import model_from_name - assert (os.path.isfile(checkpoints_path+"_config.json") - ), "Checkpoint not found." + assert (os.path.isfile(os.path.join(checkpoints_path, "_config.json"))), "Checkpoint not found." model_config = json.loads( open(checkpoints_path+"_config.json", "r").read()) latest_weights = find_latest_checkpoint(checkpoints_path) diff --git a/keras_segmentation/train.py b/keras_segmentation/train.py index 27e37f65d..41947c58b 100755 --- a/keras_segmentation/train.py +++ b/keras_segmentation/train.py @@ -6,16 +6,16 @@ import six from keras.callbacks import Callback from keras.callbacks import ModelCheckpoint -import tensorflow as tf import glob import sys + def find_latest_checkpoint(checkpoints_path, fail_safe=True): # This is legacy code, there should always be a "checkpoint" file in your directory def get_epoch_number_from_path(path): - return path.replace(checkpoints_path, "").strip(".") + return os.path.basename(path).replace(".index", "").strip(".") # Get all matching files all_checkpoint_files = glob.glob(checkpoints_path + ".*") @@ -41,6 +41,7 @@ def get_epoch_number_from_path(path): return latest_epoch_checkpoint + def masked_categorical_crossentropy(gt, pr): from keras.losses import categorical_crossentropy mask = 1 - gt[:, :, 0]