
Commit 1c45bfe

Authored by sanskar107, naruarjun, and Narayanan E.R
Distributed training with PyTorch: PointPillars on Waymo (#353)
Adds support for distributed training to the PyTorch ObjectDetection pipeline: the launch script now spawns multiple processes for multi-GPU training, the ObjectDetection pipeline is updated to support multiple GPUs, and PointPillars can be trained on large-scale datasets such as Waymo.

Co-authored-by: naruarjun <naruarjun@gmail.com>
Co-authored-by: Narayanan E.R <narayananr@iiitd.ac.in>
1 parent c581efe commit 1c45bfe

16 files changed: +397 −145 lines
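As context for the launch-script change, here is a minimal hedged sketch of the standard PyTorch recipe for spawning one worker process per GPU; train_worker and the commented-out run_training hook are illustrative names, not code from this commit:

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def train_worker(rank, world_size):
    # Rendezvous settings for single-node training (placeholder values).
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group('nccl', rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    # run_training(rank)  # the pipeline's training loop would run here
    dist.destroy_process_group()


if __name__ == '__main__':
    num_gpu = torch.cuda.device_count()
    # Spawn one process per visible GPU; each worker receives its rank.
    mp.spawn(train_worker, args=(num_gpu,), nprocs=num_gpu)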

ml3d/configs/pointpillars_waymo.yml

Lines changed: 10 additions & 9 deletions
@@ -2,7 +2,7 @@ dataset:
   name: Waymo
   dataset_path: # path/to/your/dataset
   cache_dir: ./logs/cache
-  steps_per_epoch_train: 5000
+  steps_per_epoch_train: 4000
 
 model:
   name: PointPillars
@@ -31,7 +31,7 @@ model:
   max_voxels: [32000, 32000]
 
   voxel_encoder:
-    in_channels: 5
+    in_channels: 4
     feat_channels: [64]
     voxel_size: *vsize
 
@@ -43,7 +43,7 @@ model:
     in_channels: 64
     out_channels: [64, 128, 256]
     layer_nums: [3, 5, 5]
-    layer_strides: [2, 2, 2]
+    layer_strides: [1, 2, 2]
 
   neck:
     in_channels: [64, 128, 256]
@@ -62,17 +62,18 @@ model:
       [-74.88, -74.88, 0, 74.88, 74.88, 0],
     ]
     sizes: [
-      [2.08, 4.73, 1.77], # car
-      [0.84, 1.81, 1.77], # cyclist
-      [0.84, 0.91, 1.74] # pedestrian
+      [2.08, 4.73, 1.77], # VEHICLE
+      [0.84, 1.81, 1.77], # CYCLIST
+      [0.84, 0.91, 1.74] # PEDESTRIAN
     ]
     dir_offset: 0.7854
     rotations: [0, 1.57]
     iou_thr: [[0.4, 0.55], [0.3, 0.5], [0.3, 0.5]]
 
   augment:
     PointShuffle: True
-    ObjectRangeFilter: True
+    ObjectRangeFilter:
+      point_cloud_range: [-74.88, -74.88, -2, 74.88, 74.88, 4]
     ObjectSample:
       min_points_dict:
         VEHICLE: 5
@@ -88,7 +89,7 @@ pipeline:
   name: ObjectDetection
   test_compute_metric: true
   batch_size: 6
-  val_batch_size: 1
+  val_batch_size: 6
   test_batch_size: 1
   save_ckpt_freq: 5
   max_epoch: 200
@@ -102,7 +103,7 @@ pipeline:
   weight_decay: 0.01
 
   # evaluation properties
-  overlaps: [0.5, 0.5, 0.7]
+  overlaps: [0.5, 0.5, 0.5]
   difficulties: [0, 1, 2]
   summary:
     record_for: []
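For reference, a hedged sketch of consuming this config via Open3D-ML's Config helper, following the pattern shown in the project README (the dataset path is a placeholder):

import open3d.ml as _ml3d
import open3d.ml.torch as ml3d

cfg = _ml3d.utils.Config.load_from_file("ml3d/configs/pointpillars_waymo.yml")
cfg.dataset['dataset_path'] = "/path/to/waymo"  # placeholder

model = ml3d.models.PointPillars(**cfg.model)  # voxel_encoder now expects 4 input channels
dataset = ml3d.datasets.Waymo(cfg.dataset.pop('dataset_path', None), **cfg.dataset)
pipeline = ml3d.pipelines.ObjectDetection(model, dataset=dataset, **cfg.pipeline)
pipeline.run_train()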

ml3d/datasets/augment/augmentation.py

Lines changed: 1 addition & 1 deletion
@@ -493,7 +493,7 @@ def ObjectSample(self, data, db_boxes_dict, sample_dict):
         sampled_points = np.concatenate(
             [box.points_inside_box for box in sampled], axis=0)
         points = remove_points_in_boxes(points, sampled)
-        points = np.concatenate([sampled_points, points], axis=0)
+        points = np.concatenate([sampled_points[:, :4], points], axis=0)
 
         return {
             'point': points,
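The new slice matters because np.concatenate requires identical trailing dimensions: database points sampled for augmentation can carry an extra per-point feature that the live scan lacks. A toy illustration with hypothetical shapes:

import numpy as np

sampled_points = np.zeros((100, 5), dtype=np.float32)  # db points, 5 features
points = np.zeros((2000, 4), dtype=np.float32)         # scan points, 4 features

# Concatenating directly would raise ValueError (5 != 4 along axis 1);
# keeping only the first four columns makes the feature dims agree.
merged = np.concatenate([sampled_points[:, :4], points], axis=0)
print(merged.shape)  # (2100, 4)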

ml3d/datasets/utils/operations.py

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
 import math
 from scipy.spatial import ConvexHull
 
-from ...metrics import iou_bev
+from open3d.ml.contrib import iou_bev_cpu as iou_bev
 
 
 def create_3D_rotations(axis, angle):

ml3d/datasets/waymo.py

Lines changed: 30 additions & 29 deletions
@@ -25,7 +25,6 @@ def __init__(self,
                  name='Waymo',
                  cache_dir='./logs/cache',
                  use_cache=False,
-                 val_split=3,
                  **kwargs):
         """Initialize the function by passing the dataset and other details.
 
@@ -34,7 +33,6 @@ def __init__(self,
             name: The name of the dataset (Waymo in this case).
             cache_dir: The directory where the cache is stored.
             use_cache: Indicates if the dataset should be cached.
-            val_split: The split value to get a set of images for training, validation, for testing.
 
         Returns:
             class: The corresponding class.
@@ -43,7 +41,6 @@ def __init__(self,
                          name=name,
                          cache_dir=cache_dir,
                          use_cache=use_cache,
-                         val_split=val_split,
                          **kwargs)
 
         cfg = self.cfg
@@ -52,22 +49,27 @@ def __init__(self,
         self.dataset_path = cfg.dataset_path
         self.num_classes = 4
         self.label_to_names = self.get_label_to_names()
+        self.shuffle = kwargs.get('shuffle', False)
 
         self.all_files = sorted(
             glob(join(cfg.dataset_path, 'velodyne', '*.bin')))
         self.train_files = []
         self.val_files = []
+        self.test_files = []
 
         for f in self.all_files:
-            idx = Path(f).name.replace('.bin', '')[:3]
-            idx = int(idx)
-            if idx < cfg.val_split:
+            if 'train' in f:
                 self.train_files.append(f)
-            else:
+            elif 'val' in f:
                 self.val_files.append(f)
-
-        self.test_files = glob(
-            join(cfg.dataset_path, 'testing', 'velodyne', '*.bin'))
+            elif 'test' in f:
+                self.test_files.append(f)
+            else:
+                log.warning(
+                    f"Skipping {f}, prefix must be one of train, test or val.")
+        if self.shuffle:
+            log.info("Shuffling training files...")
+            self.rng.shuffle(self.train_files)
 
     @staticmethod
     def get_label_to_names():
@@ -90,18 +92,21 @@ def read_lidar(path):
         """Reads lidar data from the path provided.
 
         Returns:
-            A data object with lidar information.
+            pc: pointcloud data with shape [N, 6], where
+                the format is xyzRGB.
         """
-        assert Path(path).exists()
-
         return np.fromfile(path, dtype=np.float32).reshape(-1, 6)
 
     @staticmethod
     def read_label(path, calib):
-        """Reads labels of bound boxes.
+        """Reads labels of bounding boxes.
+
+        Args:
+            path: The path to the label file.
+            calib: Calibration as returned by read_calib().
 
         Returns:
-            The data objects with bound boxes information.
+            The data objects with bounding boxes information.
         """
         if not Path(path).exists():
             return None
@@ -131,24 +136,22 @@ def read_calib(path):
         Returns:
             The camera and the camera image used in calibration.
         """
-        assert Path(path).exists()
-
         with open(path, 'r') as f:
             lines = f.readlines()
         obj = lines[0].strip().split(' ')[1:]
-        P0 = np.array(obj, dtype=np.float32)
+        unused_P0 = np.array(obj, dtype=np.float32)
 
         obj = lines[1].strip().split(' ')[1:]
-        P1 = np.array(obj, dtype=np.float32)
+        unused_P1 = np.array(obj, dtype=np.float32)
 
         obj = lines[2].strip().split(' ')[1:]
         P2 = np.array(obj, dtype=np.float32)
 
         obj = lines[3].strip().split(' ')[1:]
-        P3 = np.array(obj, dtype=np.float32)
+        unused_P3 = np.array(obj, dtype=np.float32)
 
         obj = lines[4].strip().split(' ')[1:]
-        P4 = np.array(obj, dtype=np.float32)
+        unused_P4 = np.array(obj, dtype=np.float32)
 
         obj = lines[5].strip().split(' ')[1:]
         R0 = np.array(obj, dtype=np.float32).reshape(3, 3)
@@ -162,7 +165,7 @@ def read_calib(path):
         Tr_velo_to_cam = Waymo._extend_matrix(Tr_velo_to_cam)
 
         world_cam = np.transpose(rect_4x4 @ Tr_velo_to_cam)
-        cam_img = np.transpose(P2)
+        cam_img = np.transpose(np.vstack((P2.reshape(3, 4), [0, 0, 0, 1])))
 
         return {'world_cam': world_cam, 'cam_img': cam_img}
 
@@ -209,7 +212,7 @@ def get_split_list(self, split):
         else:
             raise ValueError("Invalid split {}".format(split))
 
-    def is_tested():
+    def is_tested(attr):
         """Checks if a datum in the dataset has been tested.
 
         Args:
@@ -219,16 +222,16 @@ def is_tested():
            If the datum attribute is tested, then return the path where the
           attribute is stored; else, returns false.
        """
-        pass
+        raise NotImplementedError()
 
-    def save_test_result():
+    def save_test_result(results, attr):
         """Saves the output of a model.
 
         Args:
             results: The output of a model for the datum associated with the attribute passed.
             attr: The attributes that correspond to the outputs passed in results.
         """
-        pass
+        raise NotImplementedError()
 
 
 class WaymoSplit():
@@ -273,11 +276,9 @@ def get_attr(self, idx):
 
 
 class Object3d(BEVBox3D):
-    """The class stores details that are object-specific, such as bounding box
-    coordinates, occlusion and so on.
-    """
 
     def __init__(self, center, size, label, calib):
+        # ground truth files doesn't have confidence value.
         confidence = float(label[15]) if label.__len__() == 16 else -1.0
 
         world_cam = calib['world_cam']
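With this change the split is derived from the file path instead of a numeric val_split: each .bin under velodyne/ is routed by whether its path contains train, val, or test, and anything else is skipped with a warning. A hypothetical layout (filenames illustrative, not mandated by the commit):

dataset_path/
└── velodyne/
    ├── train_00000.bin
    ├── train_00001.bin
    ├── val_00000.bin
    └── test_00000.bin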

ml3d/torch/dataloaders/concat_batcher.py

Lines changed: 34 additions & 1 deletion
@@ -4,6 +4,7 @@
 import pickle
 import torch
 import yaml
+import math
 from os import listdir
 from os.path import exists, join, isdir
 
@@ -434,6 +435,22 @@ def to(self, device):
         self.feat = [feat.to(device) for feat in self.feat]
         self.label = [label.to(device) for label in self.label]
 
+    @staticmethod
+    def scatter(batch, num_gpu):
+        batch_size = len(batch.batch_lengths)
+
+        new_batch_size = math.ceil(batch_size / num_gpu)
+        batches = [SparseConvUnetBatch([]) for _ in range(num_gpu)]
+        for i in range(num_gpu):
+            start = new_batch_size * i
+            end = min(new_batch_size * (i + 1), batch_size)
+            batches[i].point = batch.point[start:end]
+            batches[i].feat = batch.feat[start:end]
+            batches[i].label = batch.label[start:end]
+            batches[i].batch_lengths = batch.batch_lengths[start:end]
+
+        return [b for b in batches if len(b.point)]  # filter empty batch
 
 
 class PointTransformerBatch:
 
@@ -486,7 +503,6 @@ def __init__(self, batches):
         self.attr = []
 
         for batch in batches:
-            self.attr.append(batch['attr'])
             data = batch['data']
             self.point.append(torch.tensor(data['point'], dtype=torch.float32))
             self.labels.append(
@@ -519,6 +535,23 @@ def to(self, device):
             if self.bboxes[i] is not None:
                 self.bboxes[i] = self.bboxes[i].to(device)
 
+    @staticmethod
+    def scatter(batch, num_gpu):
+        batch_size = len(batch.point)
+
+        new_batch_size = math.ceil(batch_size / num_gpu)
+        batches = [ObjectDetectBatch([]) for _ in range(num_gpu)]
+        for i in range(num_gpu):
+            start = new_batch_size * i
+            end = min(new_batch_size * (i + 1), batch_size)
+            batches[i].point = batch.point[start:end]
+            batches[i].labels = batch.labels[start:end]
+            batches[i].bboxes = batch.bboxes[start:end]
+            batches[i].bbox_objs = batch.bbox_objs[start:end]
+            batches[i].attr = batch.attr[start:end]
+
+        return [b for b in batches if len(b.point)]  # filter empty batch
 
 
 class ConcatBatcher(object):
     """ConcatBatcher for KPConv."""

ml3d/torch/models/base_model_objdet.py

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ def __init__(self, **kwargs):
         self.rng = np.random.default_rng(kwargs.get('seed', None))
 
     @abstractmethod
-    def loss(self, results, inputs):
+    def get_loss(self, results, inputs):
         """Computes the loss given the network input and outputs.
 
         Args:
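Because the abstract method is renamed, any detector subclassing the base model must now override get_loss instead of loss. A minimal hedged sketch of the new contract (MyDetector and its placeholder loss are hypothetical):

# Hypothetical subclass; BaseModel is the abstract class from
# ml3d/torch/models/base_model_objdet.py shown above.
class MyDetector(BaseModel):

    def get_loss(self, results, inputs):
        # Turn network outputs (results) and batched ground truth (inputs)
        # into loss terms; a real model returns its weighted loss dict.
        scores, bboxes, dirs = results
        return {'loss_cls': scores.sum() * 0.0}  # placeholder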

ml3d/torch/models/point_pillars.py

Lines changed: 1 addition & 1 deletion
@@ -137,7 +137,7 @@ def get_optimizer(self, cfg):
         optimizer = torch.optim.AdamW(self.parameters(), **cfg)
         return optimizer, None
 
-    def loss(self, results, inputs):
+    def get_loss(self, results, inputs):
         scores, bboxes, dirs = results
         gt_labels = inputs.labels
         gt_bboxes = inputs.bboxes

ml3d/torch/models/point_rcnn.py

Lines changed: 1 addition & 1 deletion
@@ -183,7 +183,7 @@ def step(self):
 
         return optimizer, scheduler
 
-    def loss(self, results, inputs):
+    def get_loss(self, results, inputs):
         if self.mode == "RPN":
             return self.rpn.loss(results, inputs)
         else:
