@@ -311,8 +311,21 @@ def aggregated_jaccard_index(
311
311
return aji
312
312
313
313
314
+ def _absent_inds (true : np .ndarray , pred : np .ndarray , num_classes : int ) -> np .ndarray :
315
+ """Get the class indices that are not present in either `true` or `pred`."""
316
+ t = np .unique (true )
317
+ p = np .unique (pred )
318
+ not_pres = np .setdiff1d (np .arange (num_classes ), np .union1d (t , p ))
319
+
320
+ return not_pres
321
+
322
+
314
323
def iou_multiclass (
315
- true : np .ndarray , pred : np .ndarray , num_classes : int , eps : float = 1e-8
324
+ true : np .ndarray ,
325
+ pred : np .ndarray ,
326
+ num_classes : int ,
327
+ eps : float = 1e-8 ,
328
+ clamp_absent : bool = True ,
316
329
) -> np .ndarray :
317
330
"""Compute multi-class intersection over union for semantic segmentation masks.
318
331
@@ -326,6 +339,9 @@ def iou_multiclass(
326
339
Number of classes in the training dataset.
327
340
eps : float, default=1e-8:
328
341
Epsilon to avoid zero div errors.
342
+ clamp_absent : bool, default=True
343
+ If a class is not present in either true or pred, the value of that ix
344
+ in the result array will be clamped to -1.0.
329
345
330
346
Returns
331
347
-------
@@ -337,11 +353,21 @@ def iou_multiclass(
337
353
fp = fp .diagonal ()
338
354
fn = fn .diagonal ()
339
355
340
- return tp / (tp + fp + fn + eps )
356
+ iou = tp / (tp + fp + fn + eps )
357
+
358
+ if clamp_absent :
359
+ not_pres = _absent_inds (true , pred , num_classes )
360
+ iou [not_pres ] = - 1.0
361
+
362
+ return iou
341
363
342
364
343
365
def accuracy_multiclass (
344
- true : np .ndarray , pred : np .ndarray , num_classes : int , eps : float = 1e-8
366
+ true : np .ndarray ,
367
+ pred : np .ndarray ,
368
+ num_classes : int ,
369
+ eps : float = 1e-8 ,
370
+ clamp_absent : bool = True ,
345
371
) -> np .ndarray :
346
372
"""Compute multi-class accuracy for semantic segmentation masks.
347
373
@@ -355,6 +381,9 @@ def accuracy_multiclass(
355
381
Number of classes in the training dataset.
356
382
eps : float, default=1e-8:
357
383
Epsilon to avoid zero div errors.
384
+ clamp_absent: bool = True
385
+ If a class is not present in either true or pred, the value of that ix
386
+ in the result array will be clamped to -1.0.
358
387
359
388
Returns
360
389
-------
@@ -367,11 +396,21 @@ def accuracy_multiclass(
367
396
fn = fn .diagonal ()
368
397
tn = np .prod (true .shape ) - (tp + fn + fp )
369
398
370
- return (tp + tn ) / (tp + fp + fn + tn + eps )
399
+ accuracy = (tp + tn ) / (tp + fp + fn + tn + eps )
400
+
401
+ if clamp_absent :
402
+ not_pres = _absent_inds (true , pred , num_classes )
403
+ accuracy [not_pres ] = - 1.0
404
+
405
+ return accuracy
371
406
372
407
373
408
def f1score_multiclass (
374
- true : np .ndarray , pred : np .ndarray , num_classes : int , eps : float = 1e-8
409
+ true : np .ndarray ,
410
+ pred : np .ndarray ,
411
+ num_classes : int ,
412
+ eps : float = 1e-8 ,
413
+ clamp_absent : bool = True ,
375
414
) -> np .ndarray :
376
415
"""Compute multi-class f1-score for semantic segmentation masks.
377
416
@@ -385,6 +424,9 @@ def f1score_multiclass(
385
424
Number of classes in the training dataset.
386
425
eps : float, default=1e-8:
387
426
Epsilon to avoid zero div errors.
427
+ clamp_absent: bool = True
428
+ If a class is not present in either true or pred, the value of that ix
429
+ in the result array will be clamped to -1.0.
388
430
389
431
Returns
390
432
-------
@@ -396,11 +438,21 @@ def f1score_multiclass(
396
438
fp = fp .diagonal ()
397
439
fn = fn .diagonal ()
398
440
399
- return tp / (0.5 * fp + 0.5 * fn + tp + eps )
441
+ f1 = tp / (0.5 * fp + 0.5 * fn + tp + eps )
442
+
443
+ if clamp_absent :
444
+ not_pres = _absent_inds (true , pred , num_classes )
445
+ f1 [not_pres ] = - 1.0
446
+
447
+ return f1
400
448
401
449
402
450
def dice_multiclass (
403
- true : np .ndarray , pred : np .ndarray , num_classes : int , eps : float = 1e-8
451
+ true : np .ndarray ,
452
+ pred : np .ndarray ,
453
+ num_classes : int ,
454
+ eps : float = 1e-8 ,
455
+ clamp_absent : bool = True ,
404
456
) -> np .ndarray :
405
457
"""Compute multi-class dice for semantic segmentation masks.
406
458
@@ -414,6 +466,9 @@ def dice_multiclass(
414
466
Number of classes in the training dataset.
415
467
eps : float, default=1e-8:
416
468
Epsilon to avoid zero div errors.
469
+ clamp_absent: bool = True
470
+ If a class is not present in either true or pred, the value of that ix
471
+ in the result array will be clamped to -1.0.
417
472
418
473
Returns
419
474
-------
@@ -425,11 +480,21 @@ def dice_multiclass(
425
480
fp = fp .diagonal ()
426
481
fn = fn .diagonal ()
427
482
428
- return 2 * tp / (2 * tp + fp + fn + eps )
483
+ dice = 2 * tp / (2 * tp + fp + fn + eps )
484
+
485
+ if clamp_absent :
486
+ not_pres = _absent_inds (true , pred , num_classes )
487
+ dice [not_pres ] = - 1.0
488
+
489
+ return dice
429
490
430
491
431
492
def sensitivity_multiclass (
432
- true : np .ndarray , pred : np .ndarray , num_classes : int , eps : float = 1e-8
493
+ true : np .ndarray ,
494
+ pred : np .ndarray ,
495
+ num_classes : int ,
496
+ eps : float = 1e-8 ,
497
+ clamp_absent : bool = True ,
433
498
) -> np .ndarray :
434
499
"""Compute multi-class sensitivity for semantic segmentation masks.
435
500
@@ -443,6 +508,9 @@ def sensitivity_multiclass(
443
508
Number of classes in the training dataset.
444
509
eps : float, default=1e-8:
445
510
Epsilon to avoid zero div errors.
511
+ clamp_absent: bool = True
512
+ If a class is not present in either true or pred, the value of that ix
513
+ in the result array will be clamped to -1.0.
446
514
447
515
Returns
448
516
-------
@@ -454,11 +522,21 @@ def sensitivity_multiclass(
454
522
fp = fp .diagonal ()
455
523
fn = fn .diagonal ()
456
524
457
- return tp / (tp + fn + eps )
525
+ sensitivity = tp / (tp + fn + eps )
526
+
527
+ if clamp_absent :
528
+ not_pres = _absent_inds (true , pred , num_classes )
529
+ sensitivity [not_pres ] = - 1.0
530
+
531
+ return sensitivity
458
532
459
533
460
534
def specificity_multiclass (
461
- true : np .ndarray , pred : np .ndarray , num_classes : int , eps : float = 1e-8
535
+ true : np .ndarray ,
536
+ pred : np .ndarray ,
537
+ num_classes : int ,
538
+ eps : float = 1e-8 ,
539
+ clamp_absent : bool = True ,
462
540
) -> np .ndarray :
463
541
"""Compute multi-class specificity for semantic segmentation masks.
464
542
@@ -472,6 +550,9 @@ def specificity_multiclass(
472
550
Number of classes in the training dataset.
473
551
eps : float, default=1e-8:
474
552
Epsilon to avoid zero div errors.
553
+ clamp_absent: bool = True
554
+ If a class is not present in either true or pred, the value of that ix
555
+ in the result array will be clamped to -1.0.
475
556
476
557
Returns
477
558
-------
@@ -483,4 +564,10 @@ def specificity_multiclass(
483
564
fp = fp .diagonal ()
484
565
fn = fn .diagonal ()
485
566
486
- return tp / (tp + fp + eps )
567
+ specificity = tp / (tp + fp + eps )
568
+
569
+ if clamp_absent :
570
+ not_pres = _absent_inds (true , pred , num_classes )
571
+ specificity [not_pres ] = - 1.0
572
+
573
+ return specificity
0 commit comments