diff --git a/README.md b/README.md
index c31213855..172979aad 100644
--- a/README.md
+++ b/README.md
@@ -6,14 +6,14 @@
[](http://perso.crans.org/besson/LICENSE.html)
[](https://twitter.com/divamgupta)
-
-
Implementation of various Deep Image Segmentation models in keras.
-Link to the full blog post with tutorial : https://divamgupta.com/image-segmentation/2019/06/06/deep-learning-semantic-segmentation-keras.html
+Link to the full blog post with tutorial: https://divamgupta.com/image-segmentation/2019/06/06/deep-learning-semantic-segmentation-keras.html
-
+
## Working Google Colab Examples:
@@ -62,19 +62,19 @@ Following models are supported:
Example results for the pre-trained models provided :
-Input Image | Output Segmentation Image
-:-------------------------:|:-------------------------:
- | 
- | 
+Input Image | Output Segmentation Image
+:--------------------------------------------------:|:---------------------------------------------------------:
+ | 
+ | 
## Getting Started
### Prerequisites
-* Keras ( recommended version : 2.4.3 )
+* Keras (recommended version: 2.4.3)
* OpenCV for Python
-* Tensorflow ( recommended version : 2.4.1 )
+* TensorFlow (recommended version: 2.4.1)
```shell
apt-get install -y libsm6 libxext6 libxrender-dev
@@ -90,13 +90,13 @@ Recommended way:
pip install --upgrade git+https://github.com/divamgupta/image-segmentation-keras
```
-### or
+#### or
```shell
pip install keras-segmentation
```
-### or
+#### or
```shell
git clone https://github.com/divamgupta/image-segmentation-keras
@@ -145,7 +145,7 @@ import cv2
import numpy as np
ann_img = np.zeros((30,30,3)).astype('uint8')
-ann_img[ 3 , 4 ] = 1 # this would set the label of pixel 3,4 as 1
+ann_img[3, 4] = 1  # this sets the label of pixel (3, 4) to 1
cv2.imwrite( "ann_1.png" ,ann_img )
```
@@ -171,9 +171,9 @@ from keras_segmentation.models.unet import vgg_unet
model = vgg_unet(n_classes=51 , input_height=416, input_width=608 )
model.train(
- train_images = "dataset1/images_prepped_train/",
- train_annotations = "dataset1/annotations_prepped_train/",
- checkpoints_path = "/tmp/vgg_unet_1" , epochs=5
+ train_images="dataset1/images_prepped_train/",
+ train_annotations="dataset1/annotations_prepped_train/",
+ checkpoints_path="/tmp/vgg_unet_1", epochs=5
)
out = model.predict_segmentation(
@@ -185,7 +185,8 @@ import matplotlib.pyplot as plt
plt.imshow(out)
# evaluating the model
-print(model.evaluate_segmentation( inp_images_dir="dataset1/images_prepped_test/" , annotations_dir="dataset1/annotations_prepped_test/" ) )
+print(model.evaluate_segmentation(inp_images_dir="dataset1/images_prepped_test/",
+                                  annotations_dir="dataset1/annotations_prepped_test/"))
```
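+
+The value returned by `evaluate_segmentation` summarizes the IoU scores on the test set. A minimal sketch of inspecting it, assuming the result is a dict with keys such as `'mean_IU'` and `'class_wise_IU'` (check the return value of your installed version):
+
+```python
+results = model.evaluate_segmentation(inp_images_dir="dataset1/images_prepped_test/",
+                                      annotations_dir="dataset1/annotations_prepped_test/")
+
+# overall mean IoU and per-class IoU scores
+print("mean IoU:", results['mean_IU'])
+for class_id, iou in enumerate(results['class_wise_IU']):
+    print("class", class_id, "IoU:", iou)
+```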
@@ -292,14 +293,35 @@ new_model = pspnet_50( n_classes=51 )
transfer_weights( pretrained_model , new_model ) # transfer weights from pre-trained model to your model
-new_model.train(
- train_images = "dataset1/images_prepped_train/",
- train_annotations = "dataset1/annotations_prepped_train/",
- checkpoints_path = "/tmp/vgg_unet_1" , epochs=5
+history = new_model.train(
+ train_images="dataset1/images_prepped_train/",
+ train_annotations="dataset1/annotations_prepped_train/",
+ checkpoints_path="/tmp/vgg_unet_1", epochs=5
)
+```
+Note that `history` is a [History](https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/History) object which
+contains the training logs, so once training finishes you can plot the loss and accuracy as usual:
-```
+```python
+import matplotlib.pyplot as plt
+
+fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(12, 4))
+
+ax0.plot(history.history['loss'], label='train loss')
+ax0.set_title('Loss function')
+ax0.set_xlabel('epochs')
+ax0.set_ylabel('loss')
+ax0.legend()
+
+ax1.plot(history.history['accuracy'], label='train acc')
+ax1.set_title('Accuracy')
+ax1.set_xlabel('epochs')
+ax1.set_ylabel('accuracy')
+ax1.legend()
+
+plt.show()
+```
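+
+If the model is trained with a validation set, the same `history.history` dict should also contain the validation metrics (with the default Keras key names such as `val_loss`), which can be plotted the same way. A minimal sketch, assuming those default key names:
+
+```python
+# plot validation loss alongside training loss, if validation was used
+if 'val_loss' in history.history:
+    plt.plot(history.history['loss'], label='train loss')
+    plt.plot(history.history['val_loss'], label='val loss')
+    plt.legend()
+    plt.show()
+```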
@@ -312,11 +334,11 @@ from keras_segmentation.predict import model_from_checkpoint_path
from keras_segmentation.models.unet import unet_mini
from keras_segmentation.model_compression import perform_distilation
-model_large = model_from_checkpoint_path( "/checkpoints/path/of/trained/model" )
-model_small = unet_mini( n_classes=51, input_height=300, input_width=400 )
+model_large = model_from_checkpoint_path("/checkpoints/path/of/trained/model")
+model_small = unet_mini(n_classes=51, input_height=300, input_width=400)
-perform_distilation ( data_path="/path/to/large_image_set/" , checkpoints_path="path/to/save/checkpoints" ,
- teacher_model=model_large , student_model=model_small , distilation_loss='kl' , feats_distilation_loss='pa' )
+perform_distilation(data_path="/path/to/large_image_set/", checkpoints_path="path/to/save/checkpoints",
+                    teacher_model=model_large, student_model=model_small, distilation_loss='kl', feats_distilation_loss='pa')
```
@@ -338,10 +360,10 @@ def custom_augmentation():
[
# apply the following augmenters to most images
iaa.Fliplr(0.5), # horizontally flip 50% of all images
- iaa.Flipud(0.5), # horizontally flip 50% of all images
+ iaa.Flipud(0.5) # vertically flip 50% of all images
])
-model = vgg_unet(n_classes=51 , input_height=416, input_width=608)
+model = vgg_unet(n_classes=51, input_height=416, input_width=608)
model.train(
train_images = "dataset1/images_prepped_train/",
@@ -364,9 +386,9 @@ model = vgg_unet(n_classes=51 , input_height=416, input_width=608,
)
model.train(
- train_images = "dataset1/images_prepped_train/",
- train_annotations = "dataset1/annotations_prepped_train/",
- checkpoints_path = "/tmp/vgg_unet_1" , epochs=5,
+ train_images="dataset1/images_prepped_train/",
+ train_annotations="dataset1/annotations_prepped_train/",
+    checkpoints_path="/tmp/vgg_unet_1", epochs=5,
read_image_type=0 # Sets how opencv will read the images
# cv2.IMREAD_COLOR = 1 (rgb),
# cv2.IMREAD_GRAYSCALE = 0,
@@ -379,18 +401,17 @@ model.train(
The following example shows how to set a custom image preprocessing function.
```python
-
from keras_segmentation.models.unet import vgg_unet
def image_preprocessing(image):
return image + 1
-model = vgg_unet(n_classes=51 , input_height=416, input_width=608)
+model = vgg_unet(n_classes=51, input_height=416, input_width=608)
model.train(
- train_images = "dataset1/images_prepped_train/",
- train_annotations = "dataset1/annotations_prepped_train/",
- checkpoints_path = "/tmp/vgg_unet_1" , epochs=5,
+ train_images="dataset1/images_prepped_train/",
+ train_annotations="dataset1/annotations_prepped_train/",
+ checkpoints_path="/tmp/vgg_unet_1", epochs=5,
preprocessing=image_preprocessing # Sets the preprocessing function
)
```
@@ -400,14 +421,13 @@ model.train(
The following example shows how to set custom callbacks for the model training.
```python
-
from keras_segmentation.models.unet import vgg_unet
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
-model = vgg_unet(n_classes=51 , input_height=416, input_width=608 )
+model = vgg_unet(n_classes=51, input_height=416, input_width=608)
# When using custom callbacks, the default checkpoint saver is removed
callbacks = [
ModelCheckpoint(
filepath="checkpoints/" + model.name + ".{epoch:05d}",
save_weights_only=True,
@@ -417,9 +437,9 @@ callbacks = [
]
model.train(
- train_images = "dataset1/images_prepped_train/",
- train_annotations = "dataset1/annotations_prepped_train/",
- checkpoints_path = "/tmp/vgg_unet_1" , epochs=5,
+ train_images="dataset1/images_prepped_train/",
+ train_annotations="dataset1/annotations_prepped_train/",
+ checkpoints_path="/tmp/vgg_unet_1", epochs=5,
callbacks=callbacks
)
```
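+
+Since the default checkpoint saver is removed when custom callbacks are passed, the weights files written by the `ModelCheckpoint` callback above can be restored later with the standard Keras API. A minimal sketch, assuming training ran for at least 5 epochs with the filepath pattern used above:
+
+```python
+# load the weights saved by the ModelCheckpoint callback at epoch 5
+model.load_weights("checkpoints/" + model.name + ".00005")
+```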
@@ -432,17 +452,15 @@ The following example shows how to add additional image inputs for models.
from keras_segmentation.models.unet import vgg_unet
-model = vgg_unet(n_classes=51 , input_height=416, input_width=608)
+model = vgg_unet(n_classes=51, input_height=416, input_width=608)
model.train(
- train_images = "dataset1/images_prepped_train/",
- train_annotations = "dataset1/annotations_prepped_train/",
- checkpoints_path = "/tmp/vgg_unet_1" , epochs=5,
+ train_images="dataset1/images_prepped_train/",
+ train_annotations="dataset1/annotations_prepped_train/",
+ checkpoints_path="/tmp/vgg_unet_1", epochs=5,
other_inputs_paths=[
"/path/to/other/directory"
],
-
-
# Ability to add preprocessing
preprocessing=[lambda x: x+1, lambda x: x+2, lambda x: x+3], # Different prepocessing for each input
# OR
@@ -452,7 +470,7 @@ model.train(
## Projects using keras-segmentation
-Here are a few projects which are using our library :
+Here are a few projects which are using our library:
* https://github.com/SteliosTsop/QF-image-segmentation-keras [paper](https://arxiv.org/pdf/1908.02242.pdf)
* https://github.com/willembressers/bouquet_quality
* https://github.com/jqueguiner/image-segmentation
@@ -487,5 +505,4 @@ Here are a few projects which are using our library :
* https://github.com/rusito-23/mobile_unet_segmentation
* https://github.com/Philliec459/ThinSection-image-segmentation-keras
-If you use our code in a publicly available project, please add the link here ( by posting an issue or creating a PR )
-
+If you use our code in a publicly available project, please add the link here (by posting an issue or creating a PR).
diff --git a/keras_segmentation/train.py b/keras_segmentation/train.py
index 6306f719d..a4853613e 100755
--- a/keras_segmentation/train.py
+++ b/keras_segmentation/train.py
@@ -6,10 +6,10 @@
import six
from keras.callbacks import Callback
from tensorflow.keras.callbacks import ModelCheckpoint
-import tensorflow as tf
import glob
import sys
+
def find_latest_checkpoint(checkpoints_path, fail_safe=True):
# This is legacy code, there should always be a "checkpoint" file in your directory
@@ -41,6 +41,7 @@ def get_epoch_number_from_path(path):
return latest_epoch_checkpoint
+
def masked_categorical_crossentropy(gt, pr):
from keras.losses import categorical_crossentropy
mask = 1 - gt[:, :, 0]
@@ -48,6 +49,7 @@ def masked_categorical_crossentropy(gt, pr):
class CheckpointsCallback(Callback):
+
def __init__(self, checkpoints_path):
self.checkpoints_path = checkpoints_path
@@ -124,7 +126,7 @@ def train(model,
config_file = checkpoints_path + "_config.json"
dir_name = os.path.dirname(config_file)
- if ( not os.path.exists(dir_name) ) and len( dir_name ) > 0 :
+ if (not os.path.exists(dir_name)) and len(dir_name) > 0:
os.makedirs(dir_name)
with open(config_file, "w") as f:
@@ -179,14 +181,14 @@ def train(model,
other_inputs_paths=other_inputs_paths,
preprocessing=preprocessing, read_image_type=read_image_type)
- if callbacks is None and (not checkpoints_path is None) :
+ if callbacks is None and (checkpoints_path is not None):
default_callback = ModelCheckpoint(
filepath=checkpoints_path + ".{epoch:05d}",
save_weights_only=True,
verbose=True
)
- if sys.version_info[0] < 3: # for pyhton 2
+ if sys.version_info[0] < 3: # for python 2
default_callback = CheckpointsCallback(checkpoints_path)
callbacks = [
@@ -197,12 +199,14 @@ def train(model,
callbacks = []
if not validate:
- model.fit(train_gen, steps_per_epoch=steps_per_epoch,
- epochs=epochs, callbacks=callbacks, initial_epoch=initial_epoch)
+ history = model.fit(train_gen, steps_per_epoch=steps_per_epoch,
+ epochs=epochs, callbacks=callbacks, initial_epoch=initial_epoch)
else:
- model.fit(train_gen,
- steps_per_epoch=steps_per_epoch,
- validation_data=val_gen,
- validation_steps=val_steps_per_epoch,
- epochs=epochs, callbacks=callbacks,
- use_multiprocessing=gen_use_multiprocessing, initial_epoch=initial_epoch)
+ history = model.fit(train_gen,
+ steps_per_epoch=steps_per_epoch,
+ validation_data=val_gen,
+ validation_steps=val_steps_per_epoch,
+ epochs=epochs, callbacks=callbacks,
+ use_multiprocessing=gen_use_multiprocessing, initial_epoch=initial_epoch)
+
+ return history