
Commit 783a3bb

Author: Bartosz Sułek (committed)
Add Pandaset class and config
1 parent 4d6e9dc commit 783a3bb

File tree

2 files changed (+319, -0 lines)

ml3d/configs/randlanet_pandaset.yml
ml3d/datasets/pandaset.py


ml3d/configs/randlanet_pandaset.yml

Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
dataset:
  name: Pandaset
  dataset_path: path_to_dataset
  cache_dir: ./logs/cache
  test_result_folder: './logs/test'
  training_split: [ '001', '002', '003', '005', '011', '013', '015', '016',
                    '017', '019', '021', '023', '024', '027', '028', '029',
                    '030', '032', '033', '034', '035', '037', '038', '039',
                    '040', '041', '042', '043', '044', '046', '052', '053',
                    '054', '056', '057', '058', '064', '065', '066', '067',
                    '070', '071', '072', '073', '077', '078', '080', '084',
                    '088', '089', '090', '094', '095', '097', '098', '101',
                    '102', '103', '105', '106', '109', '110', '112', '113'
                  ]
  test_split: ['115', '116', '117', '119', '120', '124', '139', '149', '158']
  validation_split: ['122', '123']
  use_cache: true
  sampler:
    name: 'SemSegRandomSampler'
model:
  name: RandLANet
  batcher: DefaultBatcher
  num_classes: 39
  num_points: 81920
  num_neighbors: 16
  framework: torch
  num_layers: 4
  ignored_label_inds: [0]
  sub_sampling_ratio: [4, 4, 4, 4]
  in_channels: 3
  dim_features: 8
  dim_output: [16, 64, 128, 256]
  grid_size: 0.06
pipeline:
  name: SemanticSegmentation
  max_epoch: 50
  save_ckpt_freq: 5
  device: gpu
  optimizer:
    lr: 0.001
  batch_size: 4
  main_log_dir: './logs'
  logs_dir: './logs'
  scheduler_gamma: 0.9886
  test_batch_size: 2
  train_sum_dir: './logs/training_log'
  val_batch_size: 2
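
For context, a config like this is normally consumed through Open3D-ML's Config loader and the SemanticSegmentation pipeline. The following is a minimal sketch, assuming the standard training entry points from the Open3D-ML README and a working directory at the repo root; the dataset path is a placeholder and is not part of this commit.

# Minimal sketch (assumes the standard Open3D-ML torch workflow; the dataset
# path below is a placeholder, not part of this commit).
import open3d.ml as _ml3d
import open3d.ml.torch as ml3d              # torch backend, per framework: torch

from ml3d.datasets.pandaset import Pandaset  # class added in this commit

cfg = _ml3d.utils.Config.load_from_file("ml3d/configs/randlanet_pandaset.yml")
cfg.dataset['dataset_path'] = "/data/pandaset"   # replace path_to_dataset

dataset = Pandaset(**cfg.dataset)
model = ml3d.models.RandLANet(**cfg.model)
pipeline = ml3d.pipelines.SemanticSegmentation(model=model,
                                               dataset=dataset,
                                               **cfg.pipeline)
pipeline.run_train()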

ml3d/datasets/pandaset.py

Lines changed: 272 additions & 0 deletions
@@ -0,0 +1,272 @@
import os
from os.path import join
import numpy as np
import pandas as pd
from pathlib import Path
import logging

from .base_dataset import BaseDataset, BaseDatasetSplit
from ..utils import make_dir, DATASET

log = logging.getLogger(__name__)


class Pandaset(BaseDataset):
    """This class is used to create a dataset based on the Pandaset autonomous
    driving dataset.

    https://pandaset.org/

    The dataset includes 42 semantic classes and covers more than 100 scenes,
    each of which is 8 seconds long.
    """

    def __init__(self,
                 dataset_path,
                 name="Pandaset",
                 cache_dir="./logs/cache",
                 use_cache=False,
                 ignored_label_inds=[],
                 test_result_folder='./logs/test_log',
                 test_split=[
                     '115', '116', '117', '119', '120', '124', '139', '149',
                     '158'
                 ],
                 training_split=[
                     '001', '002', '003', '005', '011', '013', '015', '016',
                     '017', '019', '021', '023', '024', '027', '028', '029',
                     '030', '032', '033', '034', '035', '037', '038', '039',
                     '040', '041', '042', '043', '044', '046', '052', '053',
                     '054', '056', '057', '058', '064', '065', '066', '067',
                     '070', '071', '072', '073', '077', '078', '080', '084',
                     '088', '089', '090', '094', '095', '097', '098', '101',
                     '102', '103', '105', '106', '109', '110', '112', '113'
                 ],
                 validation_split=['122', '123'],
                 all_split=[
                     '001', '002', '003', '005', '011', '013', '015', '016',
                     '017', '019', '021', '023', '024', '027', '028', '029',
                     '030', '032', '033', '034', '035', '037', '038', '039',
                     '040', '041', '042', '043', '044', '046', '052', '053',
                     '054', '056', '057', '058', '064', '065', '066', '067',
                     '069', '070', '071', '072', '073', '077', '078', '080',
                     '084', '088', '089', '090', '094', '095', '097', '098',
                     '101', '102', '103', '105', '106', '109', '110', '112',
                     '113', '115', '116', '117', '119', '120', '122', '123',
                     '124', '139', '149', '158'
                 ],
                 **kwargs):
        """Initialize the function by passing the dataset and other details.

        Args:
            dataset_path: The path to the dataset to use.
            name: The name of the dataset.
            cache_dir: The directory where the cache is stored.
            use_cache: Indicates if the dataset should be cached.
            ignored_label_inds: A list of labels that should be ignored in the
                dataset.

        Returns:
            class: The corresponding class.
        """
        super().__init__(dataset_path=dataset_path,
                         name=name,
                         cache_dir=cache_dir,
                         use_cache=use_cache,
                         ignored_label_inds=ignored_label_inds,
                         test_result_folder=test_result_folder,
                         test_split=test_split,
                         training_split=training_split,
                         validation_split=validation_split,
                         all_split=all_split,
                         **kwargs)

        cfg = self.cfg

        self.label_to_names = self.get_label_to_names()
        self.num_classes = len(self.label_to_names)
        self.label_values = np.sort([k for k, v in self.label_to_names.items()])

    @staticmethod
    def get_label_to_names():
        """Returns a label to names dictionary object.

        Returns:
            A dict where keys are label numbers and
            values are the corresponding names.
        """
        label_to_names = {
            1: "Reflection",
            2: "Vegetation",
            3: "Ground",
            4: "Road",
            5: "Lane Line Marking",
            6: "Stop Line Marking",
            7: "Other Road Marking",
            8: "Sidewalk",
            9: "Driveway",
            10: "Car",
            11: "Pickup Truck",
            12: "Medium-sized Truck",
            13: "Semi-truck",
            14: "Towed Object",
            15: "Motorcycle",
            16: "Other Vehicle - Construction Vehicle",
            17: "Other Vehicle - Uncommon",
            18: "Other Vehicle - Pedicab",
            19: "Emergency Vehicle",
            20: "Bus",
            21: "Personal Mobility Device",
            22: "Motorized Scooter",
            23: "Bicycle",
            24: "Train",
            25: "Trolley",
            26: "Tram / Subway",
            27: "Pedestrian",
            28: "Pedestrian with Object",
            29: "Animals - Bird",
            30: "Animals - Other",
            31: "Pylons",
            32: "Road Barriers",
            33: "Signs",
            34: "Cones",
            35: "Construction Signs",
            36: "Temporary Construction Barriers",
            37: "Rolling Containers",
            38: "Building",
            39: "Other Static Object"
        }
        return label_to_names

    def get_split(self, split):
        """Returns a dataset split.

        Args:
            split: A string identifying the dataset split that is usually one
                of 'training', 'test', 'validation', or 'all'.

        Returns:
            A dataset split object providing the requested subset of the data.
        """
        return PandasetSplit(self, split=split)

    def get_split_list(self, split):
        """Returns the list of data paths for the given split.

        Args:
            split: A string identifying the dataset split that is usually one
                of 'training', 'test', 'validation', or 'all'.

        Returns:
            A list of file paths belonging to the requested split.

        Raises:
            ValueError: Indicates that the split name passed is incorrect. The
                split name should be one of 'training', 'test', 'validation',
                or 'all'.
        """
        cfg = self.cfg
        dataset_path = cfg.dataset_path
        file_list = []

        if split in ['train', 'training']:
            seq_list = cfg.training_split
        elif split in ['test', 'testing']:
            seq_list = cfg.test_split
        elif split in ['val', 'validation']:
            seq_list = cfg.validation_split
        elif split in ['all']:
            seq_list = cfg.all_split
        else:
            raise ValueError("Invalid split {}".format(split))

        for seq_id in seq_list:
            pc_path = join(dataset_path, seq_id, 'lidar')
            for f in np.sort(os.listdir(pc_path)):
                if f.split('.')[-1] == 'gz':
                    file_list.append(join(pc_path, f))

        return file_list

    def is_tested(self, attr):
        """Checks if a datum in the dataset has been tested.

        Args:
            attr: The attribute that needs to be checked.

        Returns:
            If the datum attribute is tested, then return the path where the
            attribute is stored; else, returns false.
        """
        pass

    def save_test_result(self, results, attr):
        """Saves the output of a model.

        Args:
            results: The output of a model for the datum associated with the
                attribute passed.
            attr: The attributes that correspond to the outputs passed in
                results.
        """
        cfg = self.cfg
        name = attr['name']

        test_path = join(cfg.test_result_folder, 'sequences')
        make_dir(test_path)
        save_path = join(test_path, name, 'predictions')
        make_dir(save_path)
        pred = results['predict_labels']

        # Shift predictions up to restore the ignored label indices.
        for ign in cfg.ignored_label_inds:
            pred[pred >= ign] += 1

        store_path = join(save_path, name + '.label')

        pred = pred.astype(np.uint32)
        pred.tofile(store_path)


class PandasetSplit(BaseDatasetSplit):
    """This class is used to create a split for the Pandaset dataset.

    Args:
        dataset: The dataset to split.
        split: A string identifying the dataset split that is usually one of
            'training', 'test', 'validation', or 'all'.
        **kwargs: The configuration of the model as keyword arguments.

    Returns:
        A dataset split object providing the requested subset of the data.
    """

    def __init__(self, dataset, split='train'):
        super().__init__(dataset, split=split)
        log.info("Found {} pointclouds for {}".format(len(self.path_list),
                                                      split))

    def __len__(self):
        return len(self.path_list)

    def get_data(self, idx):
        pc_path = self.path_list[idx]
        label_path = pc_path.replace('lidar', 'annotations/semseg')

        # Each frame is a pickled pandas DataFrame; the lidar frame carries
        # x, y, z, i (intensity), t (timestamp) and d (device) columns.
        points = pd.read_pickle(pc_path)
        labels = pd.read_pickle(label_path)

        intensity = points['i'].to_numpy().astype(np.float32)
        points = points.drop(columns=['i', 't', 'd']).to_numpy().astype(
            np.float32)
        labels = labels.to_numpy().astype(np.int32)

        data = {
            'point': points,
            'intensity': intensity,
            'label': labels
        }

        return data

    def get_attr(self, idx):
        pc_path = self.path_list[idx]
        # Sequence ID is read from a fixed position in the absolute path.
        value = pc_path.split('/')[9]
        name = Path(pc_path).name.split('.')[0]
        name = value + '_' + name

        attr = {'name': name, 'path': pc_path, 'split': self.split}
        return attr


DATASET._register_module(Pandaset)
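
A hypothetical smoke test for the new loader (not part of the commit; the dataset path is a placeholder): instantiating the class directly exercises get_split_list, get_data, and get_attr in one pass.

# Hypothetical smoke test (placeholder path; run from the repo root).
from ml3d.datasets.pandaset import Pandaset

dataset = Pandaset(dataset_path="/data/pandaset")
train = dataset.get_split('training')
print(len(train))               # number of .pkl.gz lidar frames found

data = train.get_data(0)
print(data['point'].shape)      # (N, 3) float32 xyz after dropping 'i'/'t'/'d'
print(data['intensity'].shape)  # (N,)   float32 intensity
print(data['label'].shape)      # (N, 1) int32 semantic class ids

attr = train.get_attr(0)        # caution: expects the sequence ID at path index 9
print(attr['name'], attr['split'])

One caveat worth noting: get_attr derives the sequence ID from a fixed component of the absolute path (pc_path.split('/')[9]), so it only resolves correctly when dataset_path sits at the matching directory depth. A depth-independent alternative such as Path(pc_path).parts[-3] would pick out the same sequence component for any mount point, since frames live under <sequence>/lidar/<frame>.pkl.gz.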
