
Commit a08ba93

Add threshold param for metrics (#68) (#106)
* Add `threshold` param for metrics (#68)
* Add test for `threshold`
* Smooth set to 1. (metrics and losses)
1 parent ce52a1a commit a08ba93

3 files changed: 39 additions & 18 deletions
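
In short, passing `threshold` binarizes predictions before scoring. A minimal sketch of the new parameter in use (the arrays here are hypothetical; numpy inputs are passed straight to the metric, exactly as the tests below do):

import numpy as np
from keras import backend as K
from segmentation_models.metrics import iou_score

# Hypothetical 4D inputs, shape (batch, height, width, channels).
gt = np.array([[[[1.], [1.]], [[0.], [0.]]]], dtype='float32')
pr = np.array([[[[0.7], [0.6]], [[0.3], [0.2]]]], dtype='float32')

soft = K.eval(iou_score(gt, pr))                 # soft probabilities scored directly (~0.66)
hard = K.eval(iou_score(gt, pr, threshold=0.5))  # predictions binarized via `pr > 0.5`, score 1.0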

segmentation_models/losses.py

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@

 from .metrics import jaccard_score, f_score

-SMOOTH = 1e-12
+SMOOTH = 1.

 __all__ = [
     'jaccard_loss', 'bce_jaccard_loss', 'cce_jaccard_loss',
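
A note on the `SMOOTH` change: the commit message gives no motivation, but the practical effect is easy to see on near-empty masks. A small numpy illustration of the same arithmetic the metrics use, `(intersection + smooth) / (union + smooth)`:

import numpy as np

def iou(gt, pr, smooth):
    # Same arithmetic as iou_score in metrics.py, without the Keras backend.
    intersection = np.sum(gt * pr)
    union = np.sum(gt + pr) - intersection
    return (intersection + smooth) / (union + smooth)

gt = np.zeros((2, 2))              # empty ground-truth mask
pr = np.full((2, 2), 0.01)         # near-empty prediction

print(iou(gt, pr, smooth=1e-12))   # ~2.5e-11: a near-perfect prediction scores ~0
print(iou(gt, pr, smooth=1.))      # ~0.96: with smooth=1. it scores near-perfect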

segmentation_models/metrics.py

Lines changed: 19 additions & 7 deletions
@@ -6,13 +6,13 @@
     'get_f_score', 'get_iou_score', 'get_jaccard_score',
 ]

-SMOOTH = 1e-12
+SMOOTH = 1.


 # ============================ Jaccard/IoU score ============================


-def iou_score(gt, pr, class_weights=1., smooth=SMOOTH, per_image=True):
+def iou_score(gt, pr, class_weights=1., smooth=SMOOTH, per_image=True, threshold=None):
     r""" The `Jaccard index`_, also known as Intersection over Union and the Jaccard similarity coefficient
     (originally coined coefficient de communauté by Paul Jaccard), is a statistic used for comparing the
     similarity and diversity of sample sets. The Jaccard coefficient measures similarity between finite sample sets,
@@ -27,6 +27,7 @@ def iou_score(gt, pr, class_weights=1., smooth=SMOOTH, per_image=True):
         smooth: value to avoid division by zero
         per_image: if ``True``, metric is calculated as mean over images in batch (B),
             else over whole batch
+        threshold: value to round predictions (use ``>`` comparison), if ``None`` prediction will not be round

     Returns:
         IoU/Jaccard score in range [0, 1]
@@ -38,6 +39,10 @@ def iou_score(gt, pr, class_weights=1., smooth=SMOOTH, per_image=True):
         axes = [1, 2]
     else:
         axes = [0, 1, 2]
+
+    if threshold is not None:
+        pr = K.greater(pr, threshold)
+        pr = K.cast(pr, K.floatx())

     intersection = K.sum(gt * pr, axis=axes)
     union = K.sum(gt + pr, axis=axes) - intersection
@@ -53,20 +58,21 @@ def iou_score(gt, pr, class_weights=1., smooth=SMOOTH, per_image=True):
     return iou


-def get_iou_score(class_weights=1., smooth=SMOOTH, per_image=True):
+def get_iou_score(class_weights=1., smooth=SMOOTH, per_image=True, threshold=None):
     """Change default parameters of IoU/Jaccard score

     Args:
         class_weights: 1. or list of class weights, len(weights) = C
         smooth: value to avoid division by zero
         per_image: if ``True``, metric is calculated as mean over images in batch (B),
             else over whole batch
+        threshold: value to round predictions (use ``>`` comparison), if ``None`` prediction will not be round

     Returns:
         ``callable``: IoU/Jaccard score
     """
     def score(gt, pr):
-        return iou_score(gt, pr, class_weights=class_weights, smooth=smooth, per_image=per_image)
+        return iou_score(gt, pr, class_weights=class_weights, smooth=smooth, per_image=per_image, threshold=threshold)

     return score

@@ -83,7 +89,7 @@ def score(gt, pr):

 # ============================== F/Dice - score ==============================

-def f_score(gt, pr, class_weights=1, beta=1, smooth=SMOOTH, per_image=True):
+def f_score(gt, pr, class_weights=1, beta=1, smooth=SMOOTH, per_image=True, threshold=None):
     r"""The F-score (Dice coefficient) can be interpreted as a weighted average of the precision and recall,
     where an F-score reaches its best value at 1 and worst score at 0.
     The relative contribution of ``precision`` and ``recall`` to the F1-score are equal.
@@ -110,6 +116,7 @@ def f_score(gt, pr, class_weights=1, beta=1, smooth=SMOOTH, per_image=True):
         smooth: value to avoid division by zero
         per_image: if ``True``, metric is calculated as mean over images in batch (B),
             else over whole batch
+        threshold: value to round predictions (use ``>`` comparison), if ``None`` prediction will not be round

     Returns:
         F-score in range [0, 1]
@@ -119,6 +126,10 @@ def f_score(gt, pr, class_weights=1, beta=1, smooth=SMOOTH, per_image=True):
         axes = [1, 2]
     else:
         axes = [0, 1, 2]
+
+    if threshold is not None:
+        pr = K.greater(pr, threshold)
+        pr = K.cast(pr, K.floatx())

     tp = K.sum(gt * pr, axis=axes)
     fp = K.sum(pr, axis=axes) - tp
@@ -137,7 +148,7 @@ def f_score(gt, pr, class_weights=1, beta=1, smooth=SMOOTH, per_image=True):
     return score


-def get_f_score(class_weights=1, beta=1, smooth=SMOOTH, per_image=True):
+def get_f_score(class_weights=1, beta=1, smooth=SMOOTH, per_image=True, threshold=None):
     """Change default parameters of F-score score

     Args:
@@ -146,12 +157,13 @@ def get_f_score(class_weights=1, beta=1, smooth=SMOOTH, per_image=True):
         beta: f-score coefficient
         per_image: if ``True``, metric is calculated as mean over images in batch (B),
             else over whole batch
+        threshold: value to round predictions (use ``>`` comparison), if ``None`` prediction will not be round

     Returns:
         ``callable``: F-score
     """
     def score(gt, pr):
-        return f_score(gt, pr, class_weights=class_weights, beta=beta, smooth=smooth, per_image=per_image)
+        return f_score(gt, pr, class_weights=class_weights, beta=beta, smooth=smooth, per_image=per_image, threshold=threshold)

     return score

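The thresholding added to both scoring functions is two backend ops: `K.greater(pr, threshold)` yields a boolean tensor, and `K.cast(pr, K.floatx())` converts it back to 0./1. floats. A numpy equivalent, for illustration only:

import numpy as np

pr = np.array([0.2, 0.5, 0.51, 0.9])
pr_bin = (pr > 0.5).astype(np.float32)   # strict `>`, as in K.greater
print(pr_bin)                            # [0. 0. 1. 1.]

The comparison is strict, so a prediction exactly equal to the threshold rounds down to 0; that is what the ``>`` note in the docstrings refers to.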

tests/test_metrics.py

Lines changed: 19 additions & 10 deletions
@@ -120,7 +120,7 @@ def test_iou_metric(case):
     gt, pr, res = case
     gt = _to_4d(gt)
     pr = _to_4d(pr)
-    score = K.eval(iou_score(gt, pr))
+    score = K.eval(iou_score(gt, pr, smooth=10e-12))
     assert np.allclose(score, res)


@@ -129,15 +129,15 @@ def test_jaccrad_loss(case):
     gt, pr, res = case
     gt = _to_4d(gt)
     pr = _to_4d(pr)
-    score = K.eval(jaccard_loss(gt, pr))
+    score = K.eval(jaccard_loss(gt, pr, smooth=10e-12))
     assert np.allclose(score, 1 - res)


 def _test_f_metric(case, beta=1):
     gt, pr, res = case
     gt = _to_4d(gt)
     pr = _to_4d(pr)
-    score = K.eval(f_score(gt, pr, beta=beta))
+    score = K.eval(f_score(gt, pr, beta=beta, smooth=10e-12))
     assert np.allclose(score, res)


@@ -156,7 +156,7 @@ def test_dice_loss(case):
     gt, pr, res = case
     gt = _to_4d(gt)
     pr = _to_4d(pr)
-    score = K.eval(dice_loss(gt, pr))
+    score = K.eval(dice_loss(gt, pr, smooth=10e-12))
     assert np.allclose(score, 1 - res)


@@ -169,10 +169,10 @@ def test_per_image(func):
     pr = _add_4d(pr)

     # calculate score per image
-    score_1 = K.eval(func(gt, pr, per_image=True))
+    score_1 = K.eval(func(gt, pr, per_image=True, smooth=10e-12))
     score_2 = np.mean([
-        K.eval(func(_to_4d(GT0), _to_4d(PR1))),
-        K.eval(func(_to_4d(GT1), _to_4d(PR2))),
+        K.eval(func(_to_4d(GT0), _to_4d(PR1), smooth=10e-12)),
+        K.eval(func(_to_4d(GT1), _to_4d(PR2), smooth=10e-12)),
     ])
     assert np.allclose(score_1, score_2)

@@ -186,14 +186,23 @@ def test_per_batch(func):
     pr = _add_4d(pr)

     # calculate score per batch
-    score_1 = K.eval(func(gt, pr, per_image=False))
+    score_1 = K.eval(func(gt, pr, per_image=False, smooth=10e-12))

     gt1 = np.concatenate([GT0, GT1], axis=0)
     pr1 = np.concatenate([PR1, PR2], axis=0)
-    score_2 = K.eval(func(_to_4d(gt1), _to_4d(pr1), per_image=True))
+    score_2 = K.eval(func(_to_4d(gt1), _to_4d(pr1), per_image=True, smooth=10e-12))

     assert np.allclose(score_1, score_2)
-
+
+
+@pytest.mark.parametrize('case', IOU_CASES)
+def test_threshold_iou(case):
+    gt, pr, res = case
+    gt = _to_4d(gt)
+    pr = _to_4d(pr) * 0.51
+    score = K.eval(iou_score(gt, pr, smooth=10e-12, threshold=0.5))
+    assert np.allclose(score, res)
+

 if __name__ == '__main__':
     pytest.main([__file__])
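
The new `test_threshold_iou` reuses the binary `IOU_CASES` fixtures: scaling a {0, 1} mask by 0.51 puts every positive pixel just above the 0.5 threshold, so thresholding must recover the original mask and reproduce the reference scores. In numpy terms (illustration only):

import numpy as np

pr = np.array([0., 1., 1., 0.]) * 0.51        # binary mask scaled to {0, 0.51}
recovered = (pr > 0.5).astype(np.float32)
assert np.array_equal(recovered, [0., 1., 1., 0.])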
