@@ -41,7 +41,7 @@ def get_scores(self, ignore_first_label=False, ignore_last_label=False):
41
41
return iou , miou
42
42
43
43
44
- def evaluate (model , testloader , criterion , num_classes : int , device , amp_enabled : bool ):
44
+ def evaluate (model , testloader , criterion , num_classes : int , amp_enabled : bool , device , eval_fps = True ):
45
45
model .eval ()
46
46
47
47
# Evaluate
@@ -55,12 +55,16 @@ def evaluate(model, testloader, criterion, num_classes: int, device, amp_enabled
55
55
56
56
# 예측
57
57
with torch .cuda .amp .autocast (enabled = amp_enabled ):
58
- torch .cuda .synchronize ()
59
- start_time = time .time ()
60
- with torch .no_grad ():
61
- output = model (image )
62
- torch .cuda .synchronize ()
63
- inference_time += time .time () - start_time
58
+ if eval_fps :
59
+ torch .cuda .synchronize ()
60
+ start_time = time .time ()
61
+ with torch .no_grad ():
62
+ output = model (image )
63
+ torch .cuda .synchronize ()
64
+ inference_time += time .time () - start_time
65
+ else :
66
+ with torch .no_grad ():
67
+ output = model (image )
64
68
65
69
# validation loss를 모두 합침
66
70
val_loss += criterion (output , target ).item ()
@@ -79,8 +83,11 @@ def evaluate(model, testloader, criterion, num_classes: int, device, amp_enabled
79
83
val_loss /= len (testloader )
80
84
81
85
# 추론 시간과 fps를 계산 (추론 시간 단위: sec)
82
- inference_time /= len (testloader .dataset )
83
- fps = 1 / inference_time
86
+ if eval_fps :
87
+ inference_time /= len (testloader .dataset )
88
+ fps = 1 / inference_time
89
+ else :
90
+ fps = 0
84
91
85
92
return val_loss , iou , miou , fps
86
93
@@ -103,7 +110,8 @@ def evaluate(model, testloader, criterion, num_classes: int, device, amp_enabled
103
110
criterion = nn .CrossEntropyLoss ()
104
111
105
112
# 모델 평가
106
- val_loss , iou , miou , fps = evaluate (model , testloader , criterion , config [config ['model' ]]['num_classes' ], device )
113
+ val_loss , iou , miou , fps = evaluate (model , testloader , criterion , config ['dataset' ]['num_classes' ],
114
+ config ['amp_enabled' ], device )
107
115
108
116
# 평가 결과를 csv 파일로 저장
109
117
os .makedirs ('result' , exist_ok = True )
0 commit comments