Skip to content

Commit be3fff4

Browse files
committed
eval: Implement parameter for controling fps evaluation
1 parent e881a2a commit be3fff4

File tree

2 files changed

+19
-11
lines changed

2 files changed

+19
-11
lines changed

eval.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def get_scores(self, ignore_first_label=False, ignore_last_label=False):
4141
return iou, miou
4242

4343

44-
def evaluate(model, testloader, criterion, num_classes: int, device, amp_enabled: bool):
44+
def evaluate(model, testloader, criterion, num_classes: int, amp_enabled: bool, device, eval_fps=True):
4545
model.eval()
4646

4747
# Evaluate
@@ -55,12 +55,16 @@ def evaluate(model, testloader, criterion, num_classes: int, device, amp_enabled
5555

5656
# 예측
5757
with torch.cuda.amp.autocast(enabled=amp_enabled):
58-
torch.cuda.synchronize()
59-
start_time = time.time()
60-
with torch.no_grad():
61-
output = model(image)
62-
torch.cuda.synchronize()
63-
inference_time += time.time() - start_time
58+
if eval_fps:
59+
torch.cuda.synchronize()
60+
start_time = time.time()
61+
with torch.no_grad():
62+
output = model(image)
63+
torch.cuda.synchronize()
64+
inference_time += time.time() - start_time
65+
else:
66+
with torch.no_grad():
67+
output = model(image)
6468

6569
# validation loss를 모두 합침
6670
val_loss += criterion(output, target).item()
@@ -79,8 +83,11 @@ def evaluate(model, testloader, criterion, num_classes: int, device, amp_enabled
7983
val_loss /= len(testloader)
8084

8185
# 추론 시간과 fps를 계산 (추론 시간 단위: sec)
82-
inference_time /= len(testloader.dataset)
83-
fps = 1 / inference_time
86+
if eval_fps:
87+
inference_time /= len(testloader.dataset)
88+
fps = 1 / inference_time
89+
else:
90+
fps = 0
8491

8592
return val_loss, iou, miou, fps
8693

@@ -103,7 +110,8 @@ def evaluate(model, testloader, criterion, num_classes: int, device, amp_enabled
103110
criterion = nn.CrossEntropyLoss()
104111

105112
# 모델 평가
106-
val_loss, iou, miou, fps = evaluate(model, testloader, criterion, config[config['model']]['num_classes'], device)
113+
val_loss, iou, miou, fps = evaluate(model, testloader, criterion, config['dataset']['num_classes'],
114+
config['amp_enabled'], device)
107115

108116
# 평가 결과를 csv 파일로 저장
109117
os.makedirs('result', exist_ok=True)

train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060

6161
# 모델 평가
6262
val_loss, _, miou, _ = eval.evaluate(model, testloader, criterion, config['dataset']['num_classes'],
63-
device, config['amp_enabled'])
63+
config['amp_enabled'], device, False)
6464
writer.add_scalar('Validation Loss', val_loss, epoch)
6565
writer.add_scalar('mIoU', miou, epoch)
6666

0 commit comments

Comments
 (0)