Skip to content

Commit e228a5f

Browse files
committed
feat(metrics): opt to ret absent cls metrix as -1
1 parent 144a045 commit e228a5f

File tree

1 file changed

+99
-12
lines changed

1 file changed

+99
-12
lines changed

cellseg_models_pytorch/metrics/functional.py

Lines changed: 99 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -311,8 +311,21 @@ def aggregated_jaccard_index(
311311
return aji
312312

313313

314+
def _absent_inds(true: np.ndarray, pred: np.ndarray, num_classes: int) -> np.ndarray:
315+
"""Get the class indices that are not present in either `true` or `pred`."""
316+
t = np.unique(true)
317+
p = np.unique(pred)
318+
not_pres = np.setdiff1d(np.arange(num_classes), np.union1d(t, p))
319+
320+
return not_pres
321+
322+
314323
def iou_multiclass(
315-
true: np.ndarray, pred: np.ndarray, num_classes: int, eps: float = 1e-8
324+
true: np.ndarray,
325+
pred: np.ndarray,
326+
num_classes: int,
327+
eps: float = 1e-8,
328+
clamp_absent: bool = True,
316329
) -> np.ndarray:
317330
"""Compute multi-class intersection over union for semantic segmentation masks.
318331
@@ -326,6 +339,9 @@ def iou_multiclass(
326339
Number of classes in the training dataset.
327340
eps : float, default=1e-8:
328341
Epsilon to avoid zero div errors.
342+
clamp_absent : bool, default=True
343+
If a class is not present in either true or pred, the value of that ix
344+
in the result array will be clamped to -1.0.
329345
330346
Returns
331347
-------
@@ -337,11 +353,21 @@ def iou_multiclass(
337353
fp = fp.diagonal()
338354
fn = fn.diagonal()
339355

340-
return tp / (tp + fp + fn + eps)
356+
iou = tp / (tp + fp + fn + eps)
357+
358+
if clamp_absent:
359+
not_pres = _absent_inds(true, pred, num_classes)
360+
iou[not_pres] = -1.0
361+
362+
return iou
341363

342364

343365
def accuracy_multiclass(
344-
true: np.ndarray, pred: np.ndarray, num_classes: int, eps: float = 1e-8
366+
true: np.ndarray,
367+
pred: np.ndarray,
368+
num_classes: int,
369+
eps: float = 1e-8,
370+
clamp_absent: bool = True,
345371
) -> np.ndarray:
346372
"""Compute multi-class accuracy for semantic segmentation masks.
347373
@@ -355,6 +381,9 @@ def accuracy_multiclass(
355381
Number of classes in the training dataset.
356382
eps : float, default=1e-8:
357383
Epsilon to avoid zero div errors.
384+
clamp_absent: bool = True
385+
If a class is not present in either true or pred, the value of that ix
386+
in the result array will be clamped to -1.0.
358387
359388
Returns
360389
-------
@@ -367,11 +396,21 @@ def accuracy_multiclass(
367396
fn = fn.diagonal()
368397
tn = np.prod(true.shape) - (tp + fn + fp)
369398

370-
return (tp + tn) / (tp + fp + fn + tn + eps)
399+
accuracy = (tp + tn) / (tp + fp + fn + tn + eps)
400+
401+
if clamp_absent:
402+
not_pres = _absent_inds(true, pred, num_classes)
403+
accuracy[not_pres] = -1.0
404+
405+
return accuracy
371406

372407

373408
def f1score_multiclass(
374-
true: np.ndarray, pred: np.ndarray, num_classes: int, eps: float = 1e-8
409+
true: np.ndarray,
410+
pred: np.ndarray,
411+
num_classes: int,
412+
eps: float = 1e-8,
413+
clamp_absent: bool = True,
375414
) -> np.ndarray:
376415
"""Compute multi-class f1-score for semantic segmentation masks.
377416
@@ -385,6 +424,9 @@ def f1score_multiclass(
385424
Number of classes in the training dataset.
386425
eps : float, default=1e-8:
387426
Epsilon to avoid zero div errors.
427+
clamp_absent: bool = True
428+
If a class is not present in either true or pred, the value of that ix
429+
in the result array will be clamped to -1.0.
388430
389431
Returns
390432
-------
@@ -396,11 +438,21 @@ def f1score_multiclass(
396438
fp = fp.diagonal()
397439
fn = fn.diagonal()
398440

399-
return tp / (0.5 * fp + 0.5 * fn + tp + eps)
441+
f1 = tp / (0.5 * fp + 0.5 * fn + tp + eps)
442+
443+
if clamp_absent:
444+
not_pres = _absent_inds(true, pred, num_classes)
445+
f1[not_pres] = -1.0
446+
447+
return f1
400448

401449

402450
def dice_multiclass(
403-
true: np.ndarray, pred: np.ndarray, num_classes: int, eps: float = 1e-8
451+
true: np.ndarray,
452+
pred: np.ndarray,
453+
num_classes: int,
454+
eps: float = 1e-8,
455+
clamp_absent: bool = True,
404456
) -> np.ndarray:
405457
"""Compute multi-class dice for semantic segmentation masks.
406458
@@ -414,6 +466,9 @@ def dice_multiclass(
414466
Number of classes in the training dataset.
415467
eps : float, default=1e-8:
416468
Epsilon to avoid zero div errors.
469+
clamp_absent: bool = True
470+
If a class is not present in either true or pred, the value of that ix
471+
in the result array will be clamped to -1.0.
417472
418473
Returns
419474
-------
@@ -425,11 +480,21 @@ def dice_multiclass(
425480
fp = fp.diagonal()
426481
fn = fn.diagonal()
427482

428-
return 2 * tp / (2 * tp + fp + fn + eps)
483+
dice = 2 * tp / (2 * tp + fp + fn + eps)
484+
485+
if clamp_absent:
486+
not_pres = _absent_inds(true, pred, num_classes)
487+
dice[not_pres] = -1.0
488+
489+
return dice
429490

430491

431492
def sensitivity_multiclass(
432-
true: np.ndarray, pred: np.ndarray, num_classes: int, eps: float = 1e-8
493+
true: np.ndarray,
494+
pred: np.ndarray,
495+
num_classes: int,
496+
eps: float = 1e-8,
497+
clamp_absent: bool = True,
433498
) -> np.ndarray:
434499
"""Compute multi-class sensitivity for semantic segmentation masks.
435500
@@ -443,6 +508,9 @@ def sensitivity_multiclass(
443508
Number of classes in the training dataset.
444509
eps : float, default=1e-8:
445510
Epsilon to avoid zero div errors.
511+
clamp_absent: bool = True
512+
If a class is not present in either true or pred, the value of that ix
513+
in the result array will be clamped to -1.0.
446514
447515
Returns
448516
-------
@@ -454,11 +522,21 @@ def sensitivity_multiclass(
454522
fp = fp.diagonal()
455523
fn = fn.diagonal()
456524

457-
return tp / (tp + fn + eps)
525+
sensitivity = tp / (tp + fn + eps)
526+
527+
if clamp_absent:
528+
not_pres = _absent_inds(true, pred, num_classes)
529+
sensitivity[not_pres] = -1.0
530+
531+
return sensitivity
458532

459533

460534
def specificity_multiclass(
461-
true: np.ndarray, pred: np.ndarray, num_classes: int, eps: float = 1e-8
535+
true: np.ndarray,
536+
pred: np.ndarray,
537+
num_classes: int,
538+
eps: float = 1e-8,
539+
clamp_absent: bool = True,
462540
) -> np.ndarray:
463541
"""Compute multi-class specificity for semantic segmentation masks.
464542
@@ -472,6 +550,9 @@ def specificity_multiclass(
472550
Number of classes in the training dataset.
473551
eps : float, default=1e-8:
474552
Epsilon to avoid zero div errors.
553+
clamp_absent: bool = True
554+
If a class is not present in either true or pred, the value of that ix
555+
in the result array will be clamped to -1.0.
475556
476557
Returns
477558
-------
@@ -483,4 +564,10 @@ def specificity_multiclass(
483564
fp = fp.diagonal()
484565
fn = fn.diagonal()
485566

486-
return tp / (tp + fp + eps)
567+
specificity = tp / (tp + fp + eps)
568+
569+
if clamp_absent:
570+
not_pres = _absent_inds(true, pred, num_classes)
571+
specificity[not_pres] = -1.0
572+
573+
return specificity

0 commit comments

Comments
 (0)