#!/usr/bin/env python
import os
import numpy as np
import torch
from torch.utils.data import DataLoader
from argparse import ArgumentParser
import datasets
from tqdm import tqdm
import torchmetrics
import matplotlib.pyplot as plt
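
# Label value marking void / ignored pixels in the ground-truth masks; such pixels are
# excluded from the IoU score via ignore_index below.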
VOID_VALUE = 255


def main():
    parser = ArgumentParser()
    parser.add_argument('--dataset', type=str, default='Rellis3DClouds')
    parser.add_argument('--weights', type=str, default='deeplabv3_resnet101.pth')
    parser.add_argument('--output', type=str, default=None)
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--n_workers', type=int, default=os.cpu_count())
    parser.add_argument('--n_samples', type=int, default=None)
    parser.add_argument('--vis_preds', action='store_true')
    args = parser.parse_args()
    print(args)
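
    # The checkpoint filename is expected to encode the input point-cloud fields the model
    # was trained on (e.g. a name containing '_depth_'); only those fields are used here.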
    data_fields = [f[1:-1] for f in ['_x_', '_y_', '_z_', '_intensity_', '_depth_'] if f in args.weights]
    print('Model takes as input: %s' % ','.join(data_fields))

    Dataset = getattr(datasets, args.dataset)
    # lidar_beams_step=2 gives a horizontal resolution of 1024 (instead of 2048 as in the Rellis-3D data).
    valid_dataset = Dataset(split='val',
                            lidar_beams_step=2 if 'Rellis' in args.dataset else None,
                            output=args.output,
                            fields=data_fields,
                            num_samples=args.n_samples,
                            labels_mode='labels')
    valid_loader = DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.n_workers)

    # -------------- Load the trained model ----------------------------------------------
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    if not os.path.exists(args.weights):
        model = torch.load(os.path.join(os.path.dirname(__file__), '../../config/weights/depth_cloud/', args.weights),
                           map_location=device)
    else:
        model = torch.load(args.weights, map_location=device)
    model = model.to(device)
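
    # Assumption: the checkpoint stores a fully pickled model object (not a bare
    # state_dict), since the loaded object is moved to the device and called directly.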

    # ---------------- Evaluation ---------------------------------------------------------
    # IoU / Jaccard score - https://en.wikipedia.org/wiki/Jaccard_index
    metric_fn = torchmetrics.JaccardIndex(num_classes=len(valid_dataset.non_bg_classes),
                                          ignore_index=VOID_VALUE,
                                          multilabel=False).to(device)
    # Validation epoch.
    metrics = []
    model = model.eval()
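
    # Note: the reported score is the mean of per-batch IoU values rather than a single
    # IoU accumulated over the whole validation set.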
    for itr, sample in tqdm(enumerate(valid_loader)):
        inpt, labels = sample
        inpt, labels = inpt.to(device), labels.to(device)

        with torch.no_grad():
            pred = model(inpt)['out']  # make prediction

        N, C, H, W = pred.shape
        assert labels.shape == (N, H, W)
        mask = labels != VOID_VALUE
        pred = torch.softmax(pred, dim=1)
        pred = pred * mask.unsqueeze(1)
        labels = labels * mask
        metric_sample = metric_fn(pred, labels.long())
        metrics.append(metric_sample.cpu().numpy())

        if itr % 100 == 0:
            print('mIoU so far (iter=%d): %.3f' % (itr, np.mean(metrics)))

    metric = np.mean(metrics)
    print('Validation metric: %f' % metric)
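
    # Optionally show a few random predictions next to the ground truth and the input depth
    # channel (gamma-compressed with x ** (1/16) and min-max normalized for display).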
    if args.vis_preds:
        for _ in range(5):
            # Use the current trained model and visualize a prediction.
            model = model.eval()
            inpt, label = valid_dataset[np.random.choice(range(len(valid_dataset)))]
            inpt = torch.from_numpy(inpt[None]).to(device)
            label = torch.from_numpy(label[None]).to(device)
            with torch.no_grad():
                pred = model(inpt)['out']
            pred = pred.squeeze(0).cpu().numpy()
            label = label.squeeze(0).cpu().numpy()
            color_pred = valid_dataset.label_to_color(pred)
            color_gt = valid_dataset.label_to_color(label)

            power = 16
            depth_img = np.copy(inpt[-1].squeeze(0).cpu().numpy())  # depth
            depth_img[depth_img > 0] = depth_img[depth_img > 0] ** (1 / power)
            depth_img[depth_img > 0] = (depth_img[depth_img > 0] - depth_img[depth_img > 0].min()) / \
                                       (depth_img[depth_img > 0].max() - depth_img[depth_img > 0].min())

            plt.figure(figsize=(20, 10))
            plt.subplot(3, 1, 1)
            plt.imshow(color_pred)
            plt.title('Prediction')
            plt.subplot(3, 1, 2)
            plt.imshow(color_gt)
            plt.title('Ground truth')
            plt.subplot(3, 1, 3)
            plt.imshow(depth_img)
            plt.title('Depth image')
            plt.show()


if __name__ == '__main__':
    main()