Skip to content

Commit 943a9cd

Browse files
committed
feat(losses): Add regularizers to losses
1 parent eb52ee4 commit 943a9cd

File tree

11 files changed

+394
-47
lines changed

11 files changed

+394
-47
lines changed

cellseg_models_pytorch/losses/criterions/ce.py

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,47 @@
11
import torch
2-
import torch.nn as nn
2+
import torch.nn.functional as F
33

4+
from ...utils import tensor_one_hot
45
from ..weighted_base_loss import WeightedBaseLoss
56

7+
__all__ = ["CELoss"]
8+
69

710
class CELoss(WeightedBaseLoss):
811
def __init__(
9-
self, edge_weight: float = None, class_weights: torch.Tensor = None, **kwargs
12+
self,
13+
apply_sd: bool = False,
14+
apply_ls: bool = False,
15+
apply_svls: bool = False,
16+
edge_weight: float = None,
17+
class_weights: torch.Tensor = None,
18+
**kwargs,
1019
) -> None:
1120
"""Cross-Entropy loss with weighting.
1221
1322
Parameters
1423
----------
15-
edge_weight : float, default=none
24+
apply_sd : bool, default=False
25+
If True, Spectral decoupling regularization will be applied to the
26+
loss matrix.
27+
apply_ls : bool, default=False
28+
If True, Label smoothing will be applied to the target.
29+
apply_svls : bool, default=False
30+
If True, spatially varying label smoothing will be applied to the target
31+
edge_weight : float, default=None
1632
Weight that is added to object borders.
1733
class_weights : torch.Tensor, default=None
1834
Class weights. A tensor of shape (n_classes,).
1935
"""
20-
super().__init__(class_weights, edge_weight)
21-
self.loss = nn.CrossEntropyLoss(reduction="none", weight=class_weights)
36+
super().__init__(apply_sd, apply_ls, apply_svls, class_weights, edge_weight)
37+
self.eps = 1e-8
2238

2339
def forward(
2440
self,
2541
yhat: torch.Tensor,
2642
target: torch.Tensor,
2743
target_weight: torch.Tensor = None,
28-
**kwargs
44+
**kwargs,
2945
) -> torch.Tensor:
3046
"""Compute the cross entropy loss.
3147
@@ -43,7 +59,28 @@ def forward(
4359
torch.Tensor:
4460
Computed CE loss (scalar).
4561
"""
46-
loss = self.loss(yhat, target) # (B, H, W)
62+
input_soft = F.softmax(yhat, dim=1) + self.eps # (B, C, H, W)
63+
num_classes = yhat.shape[1]
64+
target_one_hot = tensor_one_hot(target, num_classes) # (B, C, H, W)
65+
assert target_one_hot.shape == yhat.shape
66+
67+
if self.apply_svls:
68+
target_one_hot = self.apply_svls_to_target(
69+
target_one_hot, num_classes, **kwargs
70+
)
71+
72+
if self.apply_ls:
73+
target_one_hot = self.apply_ls_to_target(
74+
target_one_hot, num_classes, **kwargs
75+
)
76+
77+
loss = -torch.sum(target_one_hot * torch.log(input_soft), dim=1) # (B, H, W)
78+
79+
if self.apply_sd:
80+
loss = self.apply_spectral_decouple(loss, yhat)
81+
82+
if self.class_weights is not None:
83+
loss = self.apply_class_weights(loss, target)
4784

4885
if self.edge_weight is not None:
4986
loss = self.apply_edge_weights(loss, target_weight)

cellseg_models_pytorch/losses/criterions/dice.py

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,31 +5,46 @@
55

66
from ..weighted_base_loss import WeightedBaseLoss
77

8+
__all__ = ["DiceLoss"]
9+
810

911
class DiceLoss(WeightedBaseLoss):
1012
def __init__(
11-
self, edge_weight: float = None, class_weights: torch.Tensor = None, **kwargs
13+
self,
14+
apply_sd: bool = False,
15+
apply_ls: bool = False,
16+
apply_svls: bool = False,
17+
edge_weight: float = None,
18+
class_weights: torch.Tensor = None,
19+
**kwargs,
1220
) -> None:
1321
"""Sørensen-Dice Coefficient Loss.
1422
1523
Optionally applies weights at the object edges and classes.
1624
1725
Parameters
1826
----------
27+
apply_sd : bool, default=False
28+
If True, Spectral decoupling regularization will be applied to the
29+
loss matrix.
30+
apply_ls : bool, default=False
31+
If True, Label smoothing will be applied to the target.
32+
apply_svls : bool, default=False
33+
If True, spatially varying label smoothing will be applied to the target
1934
edge_weight : float, default=none
2035
Weight that is added to object borders.
2136
class_weights : torch.Tensor, default=None
2237
Class weights. A tensor of shape (n_classes,).
2338
"""
24-
super().__init__(class_weights, edge_weight)
25-
self.eps = 1e-6
39+
super().__init__(apply_sd, apply_ls, apply_svls, class_weights, edge_weight)
40+
self.eps = 1e-8
2641

2742
def forward(
2843
self,
2944
yhat: torch.Tensor,
3045
target: torch.Tensor,
3146
target_weight: torch.Tensor = None,
32-
**kwargs
47+
**kwargs,
3348
) -> torch.Tensor:
3449
"""Compute the DICE coefficient.
3550
@@ -48,12 +63,26 @@ def forward(
4863
Computed DICE loss (scalar).
4964
"""
5065
yhat_soft = F.softmax(yhat, dim=1)
51-
target_one_hot = tensor_one_hot(target, n_classes=yhat.shape[1])
66+
num_classes = yhat.shape[1]
67+
target_one_hot = tensor_one_hot(target, n_classes=num_classes)
5268
assert target_one_hot.shape == yhat.shape
5369

70+
if self.apply_svls:
71+
target_one_hot = self.apply_svls_to_target(
72+
target_one_hot, num_classes, **kwargs
73+
)
74+
75+
if self.apply_ls:
76+
target_one_hot = self.apply_ls_to_target(
77+
target_one_hot, num_classes, **kwargs
78+
)
79+
5480
intersection = torch.sum(yhat_soft * target_one_hot, 1)
5581
union = torch.sum(yhat_soft + target_one_hot, 1)
56-
dice = 2.0 * intersection / union.clamp_min(self.eps)
82+
dice = 2.0 * intersection / union.clamp_min(self.eps) # (B, H, W)
83+
84+
if self.apply_sd:
85+
dice = self.apply_spectral_decouple(dice, yhat)
5786

5887
if self.class_weights is not None:
5988
dice = self.apply_class_weights(dice, target)

cellseg_models_pytorch/losses/criterions/focal.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
import torch
22
import torch.nn.functional as F
33

4-
from cellseg_models_pytorch.utils import tensor_one_hot
5-
4+
from ...utils import tensor_one_hot
65
from ..weighted_base_loss import WeightedBaseLoss
76

87

@@ -11,6 +10,9 @@ def __init__(
1110
self,
1211
alpha: float = 0.5,
1312
gamma: float = 2.0,
13+
apply_sd: bool = False,
14+
apply_ls: bool = False,
15+
apply_svls: bool = False,
1416
edge_weight: float = None,
1517
class_weights: torch.Tensor = None,
1618
**kwargs
@@ -19,23 +21,31 @@ def __init__(
1921
2022
https://arxiv.org/abs/1708.02002
2123
22-
Optionally applies weights at the object edges and classes.
24+
Optionally applies, label smoothing, spatially varying label smoothing or
25+
weights at the object edges or class weights to the loss.
2326
2427
Parameters
2528
----------
2629
alpha : float, default=0.5
2730
Weight factor b/w [0,1].
2831
gamma : float, default=2.0
2932
Focusing factor.
33+
apply_sd : bool, default=False
34+
If True, Spectral decoupling regularization will be applied to the
35+
loss matrix.
36+
apply_ls : bool, default=False
37+
If True, Label smoothing will be applied to the target.
38+
apply_svls : bool, default=False
39+
If True, spatially varying label smoothing will be applied to the target
3040
edge_weight : float, default=none
3141
Weight that is added to object borders.
3242
class_weights : torch.Tensor, default=None
3343
Class weights. A tensor of shape (n_classes,).
3444
"""
35-
super().__init__(class_weights, edge_weight)
45+
super().__init__(apply_sd, apply_ls, apply_svls, class_weights, edge_weight)
3646
self.alpha = alpha
3747
self.gamma = gamma
38-
self.eps = 1e-6
48+
self.eps = 1e-8
3949

4050
def forward(
4151
self,
@@ -65,12 +75,25 @@ def forward(
6575
target_one_hot = tensor_one_hot(target, num_classes) # (B, C, H, W)
6676
assert target_one_hot.shape == yhat.shape
6777

78+
if self.apply_svls:
79+
target_one_hot = self.apply_svls_to_target(
80+
target_one_hot, num_classes, **kwargs
81+
)
82+
83+
if self.apply_ls:
84+
target_one_hot = self.apply_ls_to_target(
85+
target_one_hot, num_classes, **kwargs
86+
)
87+
6888
weight = (1.0 - input_soft) ** self.gamma
6989
focal = self.alpha * weight * torch.log(input_soft)
7090
focal = target_one_hot * focal
7191

7292
loss = -torch.sum(focal, dim=1) # to (B, H, W)
7393

94+
if self.apply_sd:
95+
loss = self.apply_spectral_decouple(loss, yhat)
96+
7497
if self.class_weights is not None:
7598
loss = self.apply_class_weights(loss, target)
7699

cellseg_models_pytorch/losses/criterions/iou.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,28 +8,41 @@
88

99
class IoULoss(WeightedBaseLoss):
1010
def __init__(
11-
self, edge_weight: float = None, class_weights: torch.Tensor = None, **kwargs
11+
self,
12+
apply_sd: bool = False,
13+
apply_ls: bool = False,
14+
apply_svls: bool = False,
15+
edge_weight: float = None,
16+
class_weights: torch.Tensor = None,
17+
**kwargs,
1218
) -> None:
1319
"""Intersection over union loss.
1420
1521
Optionally applies weights at the object edges and classes.
1622
1723
Parameters
1824
----------
25+
apply_sd : bool, default=False
26+
If True, Spectral decoupling regularization will be applied to the
27+
loss matrix.
28+
apply_ls : bool, default=False
29+
If True, Label smoothing will be applied to the target.
30+
apply_svls : bool, default=False
31+
If True, spatially varying label smoothing will be applied to the target
1932
edge_weight : float, default=none
2033
Weight that is added to object borders.
2134
class_weights : torch.Tensor, default=None
2235
Class weights. A tensor of shape (n_classes,).
2336
"""
24-
super().__init__(class_weights, edge_weight)
25-
self.eps = 1e-6
37+
super().__init__(apply_sd, apply_ls, apply_svls, class_weights, edge_weight)
38+
self.eps = 1e-8
2639

2740
def forward(
2841
self,
2942
yhat: torch.Tensor,
3043
target: torch.Tensor,
3144
target_weight: torch.Tensor = None,
32-
**kwargs
45+
**kwargs,
3346
) -> torch.Tensor:
3447
"""Compute the IoU loss.
3548
@@ -48,13 +61,27 @@ def forward(
4861
Computed IoU loss (scalar).
4962
"""
5063
yhat_soft = F.softmax(yhat, dim=1)
51-
target_one_hot = tensor_one_hot(target, n_classes=yhat.shape[1])
64+
num_classes = yhat.shape[1]
65+
target_one_hot = tensor_one_hot(target, n_classes=num_classes)
5266
assert target_one_hot.shape == yhat.shape
5367

68+
if self.apply_svls:
69+
target_one_hot = self.apply_svls_to_target(
70+
target_one_hot, num_classes, **kwargs
71+
)
72+
73+
if self.apply_ls:
74+
target_one_hot = self.apply_ls_to_target(
75+
target_one_hot, num_classes, **kwargs
76+
)
77+
5478
intersection = torch.sum(yhat_soft * target_one_hot, 1) # to (B, H, W)
5579
union = torch.sum(yhat_soft + target_one_hot, 1) # to (B, H, W)
5680
iou = intersection / union.clamp_min(self.eps)
5781

82+
if self.apply_sd:
83+
iou = self.apply_spectral_decouple(iou, yhat)
84+
5885
if self.class_weights is not None:
5986
iou = self.apply_class_weights(iou, target)
6087

cellseg_models_pytorch/losses/criterions/mse.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,25 +8,38 @@
88

99
class MSE(WeightedBaseLoss):
1010
def __init__(
11-
self, edge_weight: float = None, class_weights: torch.Tensor = None, **kwargs
11+
self,
12+
apply_sd: bool = False,
13+
apply_ls: bool = False,
14+
apply_svls: bool = False,
15+
edge_weight: float = None,
16+
class_weights: torch.Tensor = None,
17+
**kwargs,
1218
) -> None:
1319
"""MSE-loss.
1420
1521
Parameters
1622
----------
23+
apply_sd : bool, default=False
24+
If True, Spectral decoupling regularization will be applied to the
25+
loss matrix.
26+
apply_ls : bool, default=False
27+
If True, Label smoothing will be applied to the target.
28+
apply_svls : bool, default=False
29+
If True, spatially varying label smoothing will be applied to the target
1730
edge_weight : float, default=none
1831
Weight that is added to object borders.
1932
class_weights : torch.Tensor, default=None
2033
Class weights. A tensor of shape (n_classes,).
2134
"""
22-
super().__init__(class_weights, edge_weight)
35+
super().__init__(apply_sd, apply_ls, apply_svls, class_weights, edge_weight)
2336

2437
def forward(
2538
self,
2639
yhat: torch.Tensor,
2740
target: torch.Tensor,
2841
target_weight: torch.Tensor = None,
29-
**kwargs
42+
**kwargs,
3043
) -> torch.Tensor:
3144
"""Compute the MSE-loss.
3245
@@ -45,15 +58,30 @@ def forward(
4558
Computed MSE loss (scalar).
4659
"""
4760
target_one_hot = target
61+
num_classes = yhat.shape[1]
62+
4863
if target.size() != yhat.size():
4964
if target.dtype == torch.float32:
5065
target_one_hot = target.unsqueeze(1)
5166
else:
52-
target_one_hot = tensor_one_hot(target, yhat.shape[1])
67+
target_one_hot = tensor_one_hot(target, num_classes)
68+
69+
if self.apply_svls:
70+
target_one_hot = self.apply_svls_to_target(
71+
target_one_hot, num_classes, **kwargs
72+
)
73+
74+
if self.apply_ls:
75+
target_one_hot = self.apply_ls_to_target(
76+
target_one_hot, num_classes, **kwargs
77+
)
5378

5479
mse = F.mse_loss(yhat, target_one_hot, reduction="none") # (B, C, H, W)
5580
mse = torch.mean(mse, dim=1) # to (B, H, W)
5681

82+
if self.apply_sd:
83+
mse = self.apply_spectral_decouple(mse, yhat)
84+
5785
if self.class_weights is not None:
5886
mse = self.apply_class_weights(mse, target)
5987

0 commit comments

Comments
 (0)