Skip to content

Commit 501f4df

Browse files
committed
eval: Refactor to simplify EvaluationMetrics class
1 parent e12a363 commit 501f4df

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

eval.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import csv
22
import os
33
import time
4-
from typing import List
54

65
import numpy as np
76
import sklearn.metrics
@@ -14,9 +13,9 @@
1413

1514

1615
class EvaluationMetrics:
17-
def __init__(self, labels: List[int]):
18-
self.labels = labels
19-
self.confusion_matrix = np.zeros((len(labels), len(labels)))
16+
def __init__(self, num_classes: int):
17+
self.labels = list(range(num_classes))
18+
self.confusion_matrix = np.zeros((num_classes, num_classes))
2019

2120
def update_matrix(self, gt_batch: torch.Tensor, pred_batch: torch.Tensor):
2221
assert gt_batch.shape[0] == pred_batch.shape[0]
@@ -46,7 +45,7 @@ def evaluate(model, testloader, criterion, num_classes: int, device):
4645
model.eval()
4746

4847
# Evaluate
49-
metrics = EvaluationMetrics(list(range(num_classes)))
48+
metrics = EvaluationMetrics(num_classes)
5049
val_loss = 0
5150
inference_time = 0
5251
for images, masks in tqdm.tqdm(testloader, desc='Eval', leave=False):

0 commit comments

Comments
 (0)