You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
* Support for keras and tf.keras
* Focal loss; precision and recall metrics
* New losses and metrics functionality: aggregation and multiplication by factor
* NCHW and NHWC support
* Removed pure `tf` operations to work with other keras backends
* Reduced a number of custom objects for better models serialization and deserialization
# or keras.backend.set_image_data_format('channels_first')
78
+
79
+
Created segmentaion model is just an instance of Keras Model, which can be build as easy as:
80
+
81
+
.. code:: python
51
82
52
-
model = Unet()
83
+
model =sm.Unet()
53
84
54
85
Depending on the task, you can change the network architecture by choosing backbones with fewer or more parameters and use pretrainded weights to initialize it:
55
86
56
87
.. code:: python
57
88
58
-
model = Unet('resnet34', encoder_weights='imagenet')
89
+
model =sm.Unet('resnet34', encoder_weights='imagenet')
59
90
60
91
Change number of output classes in the model (choose your case):
61
92
62
93
.. code:: python
63
94
64
95
# binary segmentation (this parameters are default when you call Unet('resnet34')
65
-
model = Unet('resnet34', classes=1, activation='sigmoid')
96
+
model =sm.Unet('resnet34', classes=1, activation='sigmoid')
66
97
67
98
.. code:: python
68
99
69
100
# multiclass segmentation with non overlapping class masks (your classes + background)
70
-
model = Unet('resnet34', classes=3, activation='softmax')
101
+
model =sm.Unet('resnet34', classes=3, activation='softmax')
71
102
72
103
.. code:: python
73
104
74
105
# multiclass segmentation with independent overlapping/non-overlapping class masks
75
-
model = Unet('resnet34', classes=3, activation='sigmoid')
106
+
model =sm.Unet('resnet34', classes=3, activation='sigmoid')
76
107
77
108
78
109
Change input shape of the model:
@@ -88,39 +119,45 @@ Simple training pipeline
88
119
89
120
.. code:: python
90
121
91
-
from segmentation_models import Unet
92
-
from segmentation_models.backbones import get_preprocessing
93
-
from segmentation_models.losses import bce_jaccard_loss
94
-
from segmentation_models.metrics import iou_score
95
-
96
-
BACKBONE='resnet34'
97
-
preprocess_input = get_preprocessing(BACKBONE)
98
-
99
-
# load your data
100
-
x_train, y_train, x_val, y_val = load_data(...)
101
-
102
-
# preprocess input
103
-
x_train = preprocess_input(x_train)
104
-
x_val = preprocess_input(x_val)
105
-
106
-
# define model
107
-
model = Unet(BACKBONE, encoder_weights='imagenet')
# if you use data generator use model.fit_generator(...) instead of model.fit(...)
112
-
# more about `fit_generator` here: https://keras.io/models/sequential/#fit_generator
113
-
model.fit(
114
-
x=x_train,
115
-
y=y_train,
116
-
batch_size=16,
122
+
import segmentation_models as sm
123
+
124
+
BACKBONE='resnet34'
125
+
preprocess_input = sm.get_preprocessing(BACKBONE)
126
+
127
+
# load your data
128
+
x_train, y_train, x_val, y_val = load_data(...)
129
+
130
+
# preprocess input
131
+
x_train = preprocess_input(x_train)
132
+
x_val = preprocess_input(x_val)
133
+
134
+
# define model
135
+
model = sm.Unet(BACKBONE, encoder_weights='imagenet')
136
+
model.compile(
137
+
'Adam',
138
+
loss=sm.losses.bce_jaccard_loss,
139
+
metrics=[sm.metrics.iou_score],
140
+
)
141
+
142
+
# fit model
143
+
# if you use data generator use model.fit_generator(...) instead of model.fit(...)
144
+
# more about `fit_generator` here: https://keras.io/models/sequential/#fit_generator
145
+
model.fit(
146
+
x=x_train,
147
+
y=y_train,
148
+
batch_size=16,
117
149
epochs=100,
118
150
validation_data=(x_val, y_val),
119
-
)
120
-
151
+
)
121
152
122
153
Same manimulations can be done with ``Linknet``, ``PSPNet`` and ``FPN``. For more detailed information about models API and use cases `Read the Docs <https://segmentation-models.readthedocs.io/en/latest/>`__.
0 commit comments