Skip to content

Commit 2eb584a

Browse files
authored
Add beta to dice; Imporve losses docs (#103)
1 parent 2ee8324 commit 2eb584a

File tree

1 file changed

+73
-6
lines changed

1 file changed

+73
-6
lines changed

segmentation_models/losses.py

Lines changed: 73 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,44 @@ def jaccard_loss(gt, pr, class_weights=1., smooth=SMOOTH, per_image=True):
3636

3737

3838
def bce_jaccard_loss(gt, pr, bce_weight=1., smooth=SMOOTH, per_image=True):
39+
r"""Sum of binary crossentropy and jaccard losses:
40+
41+
.. math:: L(A, B) = bce_weight * binary_crossentropy(A, B) + jaccard_loss(A, B)
42+
43+
Args:
44+
gt: ground truth 4D keras tensor (B, H, W, C)
45+
pr: prediction 4D keras tensor (B, H, W, C)
46+
class_weights: 1. or list of class weights for jaccard loss, len(weights) = C
47+
smooth: value to avoid division by zero
48+
per_image: if ``True``, jaccard loss is calculated as mean over images in batch (B),
49+
else over whole batch (only for jaccard loss)
50+
51+
Returns:
52+
loss
53+
54+
"""
3955
bce = K.mean(binary_crossentropy(gt, pr))
4056
loss = bce_weight * bce + jaccard_loss(gt, pr, smooth=smooth, per_image=per_image)
4157
return loss
4258

4359

4460
def cce_jaccard_loss(gt, pr, cce_weight=1., class_weights=1., smooth=SMOOTH, per_image=True):
61+
r"""Sum of categorical crossentropy and jaccard losses:
62+
63+
.. math:: L(A, B) = cce_weight * categorical_crossentropy(A, B) + jaccard_loss(A, B)
64+
65+
Args:
66+
gt: ground truth 4D keras tensor (B, H, W, C)
67+
pr: prediction 4D keras tensor (B, H, W, C)
68+
class_weights: 1. or list of class weights for jaccard loss, len(weights) = C
69+
smooth: value to avoid division by zero
70+
per_image: if ``True``, jaccard loss is calculated as mean over images in batch (B),
71+
else over whole batch
72+
73+
Returns:
74+
loss
75+
76+
"""
4577
cce = categorical_crossentropy(gt, pr) * class_weights
4678
cce = K.mean(cce)
4779
return cce_weight * cce + jaccard_loss(gt, pr, smooth=smooth, class_weights=class_weights, per_image=per_image)
@@ -57,7 +89,7 @@ def cce_jaccard_loss(gt, pr, cce_weight=1., class_weights=1., smooth=SMOOTH, per
5789

5890
# ============================== Dice Losses ================================
5991

60-
def dice_loss(gt, pr, class_weights=1., smooth=SMOOTH, per_image=True):
92+
def dice_loss(gt, pr, class_weights=1., smooth=SMOOTH, per_image=True, beta=1.):
6193
r"""Dice loss function for imbalanced datasets:
6294
6395
.. math:: L(precision, recall) = 1 - (1 + \beta^2) \frac{precision \cdot recall}
@@ -70,24 +102,59 @@ def dice_loss(gt, pr, class_weights=1., smooth=SMOOTH, per_image=True):
70102
smooth: value to avoid division by zero
71103
per_image: if ``True``, metric is calculated as mean over images in batch (B),
72104
else over whole batch
105+
beta: coefficient for precision recall balance
73106
74107
Returns:
75108
Dice loss in range [0, 1]
76109
77110
"""
78-
return 1 - f_score(gt, pr, class_weights=class_weights, smooth=smooth, per_image=per_image, beta=1.)
111+
return 1 - f_score(gt, pr, class_weights=class_weights, smooth=smooth, per_image=per_image, beta=beta)
79112

80113

81-
def bce_dice_loss(gt, pr, bce_weight=1., smooth=SMOOTH, per_image=True):
114+
def bce_dice_loss(gt, pr, bce_weight=1., smooth=SMOOTH, per_image=True, beta=1.):
115+
r"""Sum of binary crossentropy and dice losses:
116+
117+
.. math:: L(A, B) = bce_weight * binary_crossentropy(A, B) + dice_loss(A, B)
118+
119+
Args:
120+
gt: ground truth 4D keras tensor (B, H, W, C)
121+
pr: prediction 4D keras tensor (B, H, W, C)
122+
class_weights: 1. or list of class weights for dice loss, len(weights) = C
123+
smooth: value to avoid division by zero
124+
per_image: if ``True``, dice loss is calculated as mean over images in batch (B),
125+
else over whole batch
126+
beta: coefficient for precision recall balance
127+
128+
Returns:
129+
loss
130+
131+
"""
82132
bce = K.mean(binary_crossentropy(gt, pr))
83-
loss = bce_weight * bce + dice_loss(gt, pr, smooth=smooth, per_image=per_image)
133+
loss = bce_weight * bce + dice_loss(gt, pr, smooth=smooth, per_image=per_image, beta=beta)
84134
return loss
85135

86136

87-
def cce_dice_loss(gt, pr, cce_weight=1., class_weights=1., smooth=SMOOTH, per_image=True):
137+
def cce_dice_loss(gt, pr, cce_weight=1., class_weights=1., smooth=SMOOTH, per_image=True, beta=1.):
138+
r"""Sum of categorical crossentropy and dice losses:
139+
140+
.. math:: L(A, B) = cce_weight * categorical_crossentropy(A, B) + dice_loss(A, B)
141+
142+
Args:
143+
gt: ground truth 4D keras tensor (B, H, W, C)
144+
pr: prediction 4D keras tensor (B, H, W, C)
145+
class_weights: 1. or list of class weights for dice loss, len(weights) = C
146+
smooth: value to avoid division by zero
147+
per_image: if ``True``, dice loss is calculated as mean over images in batch (B),
148+
else over whole batch
149+
beta: coefficient for precision recall balance
150+
151+
Returns:
152+
loss
153+
154+
"""
88155
cce = categorical_crossentropy(gt, pr) * class_weights
89156
cce = K.mean(cce)
90-
return cce_weight * cce + dice_loss(gt, pr, smooth=smooth, class_weights=class_weights, per_image=per_image)
157+
return cce_weight * cce + dice_loss(gt, pr, smooth=smooth, class_weights=class_weights, per_image=per_image, beta=beta)
91158

92159

93160
# Update custom objects

0 commit comments

Comments
 (0)