diff --git a/README.md b/README.md index c31213855..172979aad 100644 --- a/README.md +++ b/README.md @@ -6,14 +6,14 @@ [![MIT license](https://img.shields.io/badge/License-MIT-blue.svg)](http://perso.crans.org/besson/LICENSE.html) [![Twitter](https://img.shields.io/twitter/url.svg?label=Follow%20%40divamgupta&style=social&url=https%3A%2F%2Ftwitter.com%2Fdivamgupta)](https://twitter.com/divamgupta) - - Implementation of various Deep Image Segmentation models in keras. -Link to the full blog post with tutorial : https://divamgupta.com/image-segmentation/2019/06/06/deep-learning-semantic-segmentation-keras.html +Link to the full blog post with tutorial: https://divamgupta.com/image-segmentation/2019/06/06/deep-learning-semantic-segmentation-keras.html

- + This image ilustrates a general convolutional architecture for image segmentation: its input is an image showing 
+   a cat and a dog and its output is the segmented image.

## Working Google Colab Examples: @@ -62,19 +62,19 @@ Following models are supported: Example results for the pre-trained models provided : -Input Image | Output Segmentation Image -:-------------------------:|:-------------------------: -![](sample_images/1_input.jpg) | ![](sample_images/1_output.png) -![](sample_images/3_input.jpg) | ![](sample_images/3_output.png) +Input Image | Output Segmentation Image +:--------------------------------------------------:|:---------------------------------------------------------: +![indoor bedroom scene](sample_images/1_input.jpg) | ![indoor bedroom segmented](sample_images/1_output.png) +![outdoor house](sample_images/3_input.jpg) | ![outdoor house segmented](sample_images/3_output.png) ## Getting Started ### Prerequisites -* Keras ( recommended version : 2.4.3 ) +* Keras (recommended version: 2.4.3) * OpenCV for Python -* Tensorflow ( recommended version : 2.4.1 ) +* Tensorflow (recommended version: 2.4.1) ```shell apt-get install -y libsm6 libxext6 libxrender-dev @@ -90,13 +90,13 @@ Recommended way: pip install --upgrade git+https://github.com/divamgupta/image-segmentation-keras ``` -### or +#### or ```shell pip install keras-segmentation ``` -### or +#### or ```shell git clone https://github.com/divamgupta/image-segmentation-keras @@ -145,7 +145,7 @@ import cv2 import numpy as np ann_img = np.zeros((30,30,3)).astype('uint8') -ann_img[ 3 , 4 ] = 1 # this would set the label of pixel 3,4 as 1 +ann_img[ 3 , 4 ] = 1 # this would set the label of pixel 3, 4 as 1 cv2.imwrite( "ann_1.png" ,ann_img ) ``` @@ -171,9 +171,9 @@ from keras_segmentation.models.unet import vgg_unet model = vgg_unet(n_classes=51 , input_height=416, input_width=608 ) model.train( - train_images = "dataset1/images_prepped_train/", - train_annotations = "dataset1/annotations_prepped_train/", - checkpoints_path = "/tmp/vgg_unet_1" , epochs=5 + train_images="dataset1/images_prepped_train/", + train_annotations="dataset1/annotations_prepped_train/", + checkpoints_path="/tmp/vgg_unet_1", epochs=5 ) out = model.predict_segmentation( @@ -185,7 +185,8 @@ import matplotlib.pyplot as plt plt.imshow(out) # evaluating the model -print(model.evaluate_segmentation( inp_images_dir="dataset1/images_prepped_test/" , annotations_dir="dataset1/annotations_prepped_test/" ) ) +print(model.evaluate_segmentation(inp_images_dir="dataset1/images_prepped_test/", + annotations_dir="dataset1/annotations_prepped_test/" ) ) ``` @@ -292,14 +293,35 @@ new_model = pspnet_50( n_classes=51 ) transfer_weights( pretrained_model , new_model ) # transfer weights from pre-trained model to your model -new_model.train( - train_images = "dataset1/images_prepped_train/", - train_annotations = "dataset1/annotations_prepped_train/", - checkpoints_path = "/tmp/vgg_unet_1" , epochs=5 +history = new_model.train( + train_images="dataset1/images_prepped_train/", + train_annotations="dataset1/annotations_prepped_train/", + checkpoints_path="/tmp/vgg_unet_1", epochs=5 ) +``` +Note that `history` is a [History](https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/History?) object which +contain logs from the training, so once the training finishes you can plot the loss and accuracy as usual: -``` +```python +import matplotlib.pyplot as plt + +fig, (ax0, ax1) = plt.subplots(1,2, figsize=(12,4)) + +ax0.plot(history.history['loss'], label='train loss') +ax0.set_title('Loss function') +ax0.set_xlabel('epochs') +ax0.set_ylabel('loss') +ax0.legend() + +ax1.plot(history.history['accuracy'], label='train acc') +ax1.set_title('Accuracy') +ax1.set_xlabel('epochs') +ax1.set_ylabel('loss') +ax1.legend() + +plt.show() +``` @@ -312,11 +334,11 @@ from keras_segmentation.predict import model_from_checkpoint_path from keras_segmentation.models.unet import unet_mini from keras_segmentation.model_compression import perform_distilation -model_large = model_from_checkpoint_path( "/checkpoints/path/of/trained/model" ) -model_small = unet_mini( n_classes=51, input_height=300, input_width=400 ) +model_large = model_from_checkpoint_path("/checkpoints/path/of/trained/model") +model_small = unet_mini(n_classes=51, input_height=300, input_width=400) -perform_distilation ( data_path="/path/to/large_image_set/" , checkpoints_path="path/to/save/checkpoints" , - teacher_model=model_large , student_model=model_small , distilation_loss='kl' , feats_distilation_loss='pa' ) +perform_distilation (data_path="/path/to/large_image_set/" , checkpoints_path="path/to/save/checkpoints" , + teacher_model=model_large, student_model=model_small, distilation_loss='kl', feats_distilation_loss='pa') ``` @@ -338,10 +360,10 @@ def custom_augmentation(): [ # apply the following augmenters to most images iaa.Fliplr(0.5), # horizontally flip 50% of all images - iaa.Flipud(0.5), # horizontally flip 50% of all images + iaa.Flipud(0.5) # vertically flip 50% of all images ]) -model = vgg_unet(n_classes=51 , input_height=416, input_width=608) +model = vgg_unet(n_classes=51, input_height=416, input_width=608) model.train( train_images = "dataset1/images_prepped_train/", @@ -364,9 +386,9 @@ model = vgg_unet(n_classes=51 , input_height=416, input_width=608, ) model.train( - train_images = "dataset1/images_prepped_train/", - train_annotations = "dataset1/annotations_prepped_train/", - checkpoints_path = "/tmp/vgg_unet_1" , epochs=5, + train_images="dataset1/images_prepped_train/", + train_annotations="dataset1/annotations_prepped_train/", + checkpoints_path="/tmp/vgg_unet_1" , epochs=5, read_image_type=0 # Sets how opencv will read the images # cv2.IMREAD_COLOR = 1 (rgb), # cv2.IMREAD_GRAYSCALE = 0, @@ -379,18 +401,17 @@ model.train( The following example shows how to set a custom image preprocessing function. ```python - from keras_segmentation.models.unet import vgg_unet def image_preprocessing(image): return image + 1 -model = vgg_unet(n_classes=51 , input_height=416, input_width=608) +model = vgg_unet(n_classes=51, input_height=416, input_width=608) model.train( - train_images = "dataset1/images_prepped_train/", - train_annotations = "dataset1/annotations_prepped_train/", - checkpoints_path = "/tmp/vgg_unet_1" , epochs=5, + train_images="dataset1/images_prepped_train/", + train_annotations="dataset1/annotations_prepped_train/", + checkpoints_path="/tmp/vgg_unet_1", epochs=5, preprocessing=image_preprocessing # Sets the preprocessing function ) ``` @@ -400,14 +421,13 @@ model.train( The following example shows how to set custom callbacks for the model training. ```python - from keras_segmentation.models.unet import vgg_unet from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping -model = vgg_unet(n_classes=51 , input_height=416, input_width=608 ) +model = vgg_unet(n_classes=51, input_height=416, input_width=608) # When using custom callbacks, the default checkpoint saver is removed -callbacks = [ +callbacks=[ ModelCheckpoint( filepath="checkpoints/" + model.name + ".{epoch:05d}", save_weights_only=True, @@ -417,9 +437,9 @@ callbacks = [ ] model.train( - train_images = "dataset1/images_prepped_train/", - train_annotations = "dataset1/annotations_prepped_train/", - checkpoints_path = "/tmp/vgg_unet_1" , epochs=5, + train_images="dataset1/images_prepped_train/", + train_annotations="dataset1/annotations_prepped_train/", + checkpoints_path="/tmp/vgg_unet_1", epochs=5, callbacks=callbacks ) ``` @@ -432,17 +452,15 @@ The following example shows how to add additional image inputs for models. from keras_segmentation.models.unet import vgg_unet -model = vgg_unet(n_classes=51 , input_height=416, input_width=608) +model = vgg_unet(n_classes=51, input_height=416, input_width=608) model.train( - train_images = "dataset1/images_prepped_train/", - train_annotations = "dataset1/annotations_prepped_train/", - checkpoints_path = "/tmp/vgg_unet_1" , epochs=5, + train_images="dataset1/images_prepped_train/", + train_annotations="dataset1/annotations_prepped_train/", + checkpoints_path="/tmp/vgg_unet_1", epochs=5, other_inputs_paths=[ "/path/to/other/directory" ], - - # Ability to add preprocessing preprocessing=[lambda x: x+1, lambda x: x+2, lambda x: x+3], # Different prepocessing for each input # OR @@ -452,7 +470,7 @@ model.train( ## Projects using keras-segmentation -Here are a few projects which are using our library : +Here are a few projects which are using our library: * https://github.com/SteliosTsop/QF-image-segmentation-keras [paper](https://arxiv.org/pdf/1908.02242.pdf) * https://github.com/willembressers/bouquet_quality * https://github.com/jqueguiner/image-segmentation @@ -487,5 +505,4 @@ Here are a few projects which are using our library : * https://github.com/rusito-23/mobile_unet_segmentation * https://github.com/Philliec459/ThinSection-image-segmentation-keras -If you use our code in a publicly available project, please add the link here ( by posting an issue or creating a PR ) - +If you use our code in a publicly available project, please add the link here (by posting an issue or creating a PR) diff --git a/keras_segmentation/train.py b/keras_segmentation/train.py index 6306f719d..a4853613e 100755 --- a/keras_segmentation/train.py +++ b/keras_segmentation/train.py @@ -6,10 +6,10 @@ import six from keras.callbacks import Callback from tensorflow.keras.callbacks import ModelCheckpoint -import tensorflow as tf import glob import sys + def find_latest_checkpoint(checkpoints_path, fail_safe=True): # This is legacy code, there should always be a "checkpoint" file in your directory @@ -41,6 +41,7 @@ def get_epoch_number_from_path(path): return latest_epoch_checkpoint + def masked_categorical_crossentropy(gt, pr): from keras.losses import categorical_crossentropy mask = 1 - gt[:, :, 0] @@ -48,6 +49,7 @@ def masked_categorical_crossentropy(gt, pr): class CheckpointsCallback(Callback): + def __init__(self, checkpoints_path): self.checkpoints_path = checkpoints_path @@ -124,7 +126,7 @@ def train(model, config_file = checkpoints_path + "_config.json" dir_name = os.path.dirname(config_file) - if ( not os.path.exists(dir_name) ) and len( dir_name ) > 0 : + if (not os.path.exists(dir_name)) and len(dir_name) > 0: os.makedirs(dir_name) with open(config_file, "w") as f: @@ -179,14 +181,14 @@ def train(model, other_inputs_paths=other_inputs_paths, preprocessing=preprocessing, read_image_type=read_image_type) - if callbacks is None and (not checkpoints_path is None) : + if callbacks is None and (checkpoints_path is not None): default_callback = ModelCheckpoint( filepath=checkpoints_path + ".{epoch:05d}", save_weights_only=True, verbose=True ) - if sys.version_info[0] < 3: # for pyhton 2 + if sys.version_info[0] < 3: # for python 2 default_callback = CheckpointsCallback(checkpoints_path) callbacks = [ @@ -197,12 +199,14 @@ def train(model, callbacks = [] if not validate: - model.fit(train_gen, steps_per_epoch=steps_per_epoch, - epochs=epochs, callbacks=callbacks, initial_epoch=initial_epoch) + history = model.fit(train_gen, steps_per_epoch=steps_per_epoch, + epochs=epochs, callbacks=callbacks, initial_epoch=initial_epoch) else: - model.fit(train_gen, - steps_per_epoch=steps_per_epoch, - validation_data=val_gen, - validation_steps=val_steps_per_epoch, - epochs=epochs, callbacks=callbacks, - use_multiprocessing=gen_use_multiprocessing, initial_epoch=initial_epoch) + history = model.fit(train_gen, + steps_per_epoch=steps_per_epoch, + validation_data=val_gen, + validation_steps=val_steps_per_epoch, + epochs=epochs, callbacks=callbacks, + use_multiprocessing=gen_use_multiprocessing, initial_epoch=initial_epoch) + + return history