Skip to content

Main #408

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open

Main #408

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 33 additions & 16 deletions keras_segmentation/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,31 @@
from time import time

from .train import find_latest_checkpoint
from .data_utils.data_loader import get_image_array, get_segmentation_array,\
from .data_utils.data_loader import get_image_array, get_segmentation_array, \
DATA_LOADER_SEED, class_colors, get_pairs_from_paths
from .models.config import IMAGE_ORDERING


random.seed(DATA_LOADER_SEED)


def model_from_checkpoint_path(checkpoints_path):
def load_segmentation_model(model_config: dict, weights: str):
from .models.all_models import model_from_name

model = model_from_name[model_config['model_class']](
model_config['n_classes'], input_height=model_config['input_height'],
input_width=model_config['input_width'])
print("loaded weights ", weights)
status = model.load_weights(weights)

if status is not None:
status.expect_partial()

return model


def model_from_checkpoint_path(checkpoints_path):
from .models.all_models import model_from_name

assert (os.path.isfile(checkpoints_path+"_config.json")
), "Checkpoint not found."
model_config = json.loads(
Expand Down Expand Up @@ -76,7 +90,8 @@ def get_legends(class_names, colors=class_colors):
def overlay_seg_image(inp_img, seg_img):
orininal_h = inp_img.shape[0]
orininal_w = inp_img.shape[1]
seg_img = cv2.resize(seg_img, (orininal_w, orininal_h), interpolation=cv2.INTER_NEAREST)
seg_img = cv2.resize(seg_img, (orininal_w, orininal_h),
interpolation=cv2.INTER_NEAREST)

fused_img = (inp_img/2 + seg_img/2).astype('uint8')
return fused_img
Expand Down Expand Up @@ -108,10 +123,12 @@ def visualize_segmentation(seg_arr, inp_img=None, n_classes=None,
if inp_img is not None:
original_h = inp_img.shape[0]
original_w = inp_img.shape[1]
seg_img = cv2.resize(seg_img, (original_w, original_h), interpolation=cv2.INTER_NEAREST)
seg_img = cv2.resize(seg_img, (original_w, original_h),
interpolation=cv2.INTER_NEAREST)

if (prediction_height is not None) and (prediction_width is not None):
seg_img = cv2.resize(seg_img, (prediction_width, prediction_height), interpolation=cv2.INTER_NEAREST)
seg_img = cv2.resize(
seg_img, (prediction_width, prediction_height), interpolation=cv2.INTER_NEAREST)
if inp_img is not None:
inp_img = cv2.resize(inp_img,
(prediction_width, prediction_height))
Expand Down Expand Up @@ -139,13 +156,14 @@ def predict(model=None, inp=None, out_fname=None,
model = model_from_checkpoint_path(checkpoints_path)

assert (inp is not None)
assert ((type(inp) is np.ndarray) or isinstance(inp, six.string_types)),\
assert ((type(inp) is np.ndarray) or isinstance(inp, six.string_types)), \
"Input should be the CV image or the input file name"

if isinstance(inp, six.string_types):
inp = cv2.imread(inp, read_image_type)

assert (len(inp.shape) == 3 or len(inp.shape) == 1 or len(inp.shape) == 4), "Image should be h,w,3 "
assert (len(inp.shape) == 3 or len(inp.shape) ==
1 or len(inp.shape) == 4), "Image should be h,w,3 "

output_width = model.output_width
output_height = model.output_height
Expand Down Expand Up @@ -193,7 +211,6 @@ def predict_multiple(model=None, inps=None, inp_dir=None, out_dir=None,
if not os.path.exists(out_dir):
os.makedirs(out_dir)


for i, inp in enumerate(tqdm(inps)):
if out_dir is None:
out_fname = None
Expand Down Expand Up @@ -235,7 +252,7 @@ def predict_video(model=None, inp=None, output=None,
n_classes = model.n_classes

cap, video, fps = set_video(inp, output)
while(cap.isOpened()):
while (cap.isOpened()):
prev_time = time()
ret, frame = cap.read()
if frame is not None:
Expand All @@ -248,7 +265,7 @@ def predict_video(model=None, inp=None, output=None,
class_names=class_names,
prediction_width=prediction_width,
prediction_height=prediction_height
)
)
else:
break
print("FPS: {}".format(1/(time() - prev_time)))
Expand All @@ -268,14 +285,14 @@ def evaluate(model=None, inp_images=None, annotations=None,
inp_images_dir=None, annotations_dir=None, checkpoints_path=None, read_image_type=1):

if model is None:
assert (checkpoints_path is not None),\
"Please provide the model or the checkpoints_path"
assert (checkpoints_path is not None), \
"Please provide the model or the checkpoints_path"
model = model_from_checkpoint_path(checkpoints_path)

if inp_images is None:
assert (inp_images_dir is not None),\
"Please provide inp_images or inp_images_dir"
assert (annotations_dir is not None),\
assert (inp_images_dir is not None), \
"Please provide inp_images or inp_images_dir"
assert (annotations_dir is not None), \
"Please provide inp_images or inp_images_dir"

paths = get_pairs_from_paths(inp_images_dir, annotations_dir)
Expand Down
18 changes: 10 additions & 8 deletions keras_segmentation/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import glob
import sys


def find_latest_checkpoint(checkpoints_path, fail_safe=True):

# This is legacy code, there should always be a "checkpoint" file in your directory
Expand Down Expand Up @@ -41,6 +42,7 @@ def get_epoch_number_from_path(path):

return latest_epoch_checkpoint


def masked_categorical_crossentropy(gt, pr):
from keras.losses import categorical_crossentropy
mask = 1 - gt[:, :, 0]
Expand Down Expand Up @@ -87,7 +89,7 @@ def train(model,
read_image_type=1 # cv2.IMREAD_COLOR = 1 (rgb),
# cv2.IMREAD_GRAYSCALE = 0,
# cv2.IMREAD_UNCHANGED = -1 (4 channels like RGBA)
):
):
from .models.all_models import model_from_name
# check if user gives model name instead of the model object
if isinstance(model, six.string_types):
Expand Down Expand Up @@ -124,7 +126,7 @@ def train(model,
config_file = checkpoints_path + "_config.json"
dir_name = os.path.dirname(config_file)

if ( not os.path.exists(dir_name) ) and len( dir_name ) > 0 :
if (not os.path.exists(dir_name)) and len(dir_name) > 0:
os.makedirs(dir_name)

with open(config_file, "w") as f:
Expand Down Expand Up @@ -179,14 +181,14 @@ def train(model,
other_inputs_paths=other_inputs_paths,
preprocessing=preprocessing, read_image_type=read_image_type)

if callbacks is None and (not checkpoints_path is None) :
if callbacks is None and (not checkpoints_path is None):
default_callback = ModelCheckpoint(
filepath=checkpoints_path + ".{epoch:05d}",
save_weights_only=True,
verbose=True
)
filepath=checkpoints_path + ".{epoch:05d}" + ".weights.h5",
save_weights_only=True,
verbose=True
)

if sys.version_info[0] < 3: # for pyhton 2
if sys.version_info[0] < 3: # for pyhton 2
default_callback = CheckpointsCallback(checkpoints_path)

callbacks = [
Expand Down
36 changes: 18 additions & 18 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
cv_ver = ""
keras_ver = ">=2.0.0"
if sys.version_info.major < 3:
cv_ver = "<=4.2.0.32"
keras_ver = "<=2.3.0"
cv_ver = "<=4.2.0.32"
keras_ver = "<=2.3.0"


setup(name="keras_segmentation",
Expand All @@ -19,23 +19,23 @@
url="https://github.com/divamgupta/image-segmentation-keras",
packages=find_packages(exclude=["test"]),
entry_points={
'console_scripts': [
'keras_segmentation = keras_segmentation.__main__:main'
]
'console_scripts': [
'keras_segmentation = keras_segmentation.__main__:main'
]
},
install_requires=[
"h5py<=2.10.0",
"Keras"+keras_ver,
"imageio==2.5.0",
"imgaug>=0.4.0",
"opencv-python"+cv_ver,
"tqdm"],
"h5py",
"Keras",
"imageio==2.5.0",
"imgaug>=0.4.0",
"opencv-python",
"tqdm"],
extras_require={
# These requires provide different backends available with Keras
"tensorflow": ["tensorflow"],
"cntk": ["cntk"],
"theano": ["theano"],
# Default testing with tensorflow
"tests-default": ["tensorflow", "pytest"]
# These requires provide different backends available with Keras
"tensorflow": ["tensorflow"],
"cntk": ["cntk"],
"theano": ["theano"],
# Default testing with tensorflow
"tests-default": ["tensorflow", "pytest"]
}
)
)