Skip to content

Commit dd356cc

Browse files
committed
fix(losses): fix sce loss buggyness.
1 parent 545959e commit dd356cc

File tree

2 files changed

+10
-16
lines changed

2 files changed

+10
-16
lines changed

cellseg_models_pytorch/losses/criterions/sce.py

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -81,24 +81,15 @@ def forward(
8181

8282
cross_entropy = -torch.sum(forward, dim=1) # to (B, H, W)
8383
reverse_cross_entropy = -torch.sum(reverse, dim=1) # to (B, H, W)
84+
loss = self.alpha * cross_entropy + self.beta * reverse_cross_entropy
85+
86+
if self.apply_sd:
87+
loss = self.apply_spectral_decouple(loss, yhat)
8488

8589
if self.class_weights is not None:
86-
cross_entropy = self.apply_class_weights(cross_entropy, target)
87-
reverse_cross_entropy = self.apply_class_weights(
88-
reverse_cross_entropy, target
89-
)
90+
loss = self.apply_class_weights(loss, target)
9091

9192
if self.edge_weight is not None:
92-
cross_entropy = self.apply_edge_weights(cross_entropy, target_weight)
93-
reverse_cross_entropy = self.apply_edge_weights(
94-
reverse_cross_entropy, target_weight
95-
)
96-
97-
loss = (
98-
self.alpha * cross_entropy.mean() + self.beta * reverse_cross_entropy.mean()
99-
)
100-
101-
if self.apply_sd:
102-
loss = self.apply_spectral_decouple(loss, yhat)
93+
loss = self.apply_edge_weights(loss, target_weight)
10394

104-
return loss
95+
return loss.mean()
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
## Fixes
2+
3+
- Symmetric CE loss fixed.

0 commit comments

Comments
 (0)