Skip to content

Commit 35f8134

Browse files
committed
fix(losses): Fix sce bug
1 parent 30de27a commit 35f8134

File tree

1 file changed

+9
-2
lines changed
  • cellseg_models_pytorch/losses/criterions

1 file changed

+9
-2
lines changed

cellseg_models_pytorch/losses/criterions/sce.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,15 @@ def forward(
7373
yhat_soft = F.softmax(yhat, dim=1) + self.eps
7474
assert target_one_hot.shape == yhat.shape
7575

76-
yhat = torch.clamp(yhat_soft, min=1e-7, max=1.0)
77-
target_one_hot = torch.clamp(target_one_hot, min=1e-4, max=1.0)
76+
if self.apply_svls:
77+
target_one_hot = self.apply_svls_to_target(
78+
target_one_hot, num_classes, **kwargs
79+
)
80+
81+
if self.apply_ls:
82+
target_one_hot = self.apply_ls_to_target(
83+
target_one_hot, num_classes, **kwargs
84+
)
7885

7986
forward = target_one_hot * torch.log(yhat_soft)
8087
reverse = yhat_soft * torch.log(target_one_hot)

0 commit comments

Comments
 (0)