Commit 66131df

PointPillars Waymo distributed training model weights and docs (PR #585)
- Training speedup with GPUs.
- Add model weight file + metrics.
- Script for running training with SLURM.
1 parent ee6cb1a commit 66131df

21 files changed

+115
-42
lines changed

README.md

Lines changed: 12 additions & 7 deletions
@@ -46,7 +46,9 @@ respective requirements files:
4646
```bash
4747
# To install a compatible version of TensorFlow
4848
pip install -r requirements-tensorflow.txt
49-
# To install a compatible version of PyTorch with CUDA
49+
# To install a compatible version of PyTorch
50+
pip install -r requirements-torch.txt
51+
# To install a compatible version of PyTorch with CUDA on Linux
5052
pip install -r requirements-torch-cuda.txt
5153
```
5254

@@ -338,15 +340,17 @@ The table shows the available models and datasets for the segmentation task and
338340
For the task of object detection, we measure the performance of different methods using the mean average precision (mAP) for bird's eye view (BEV) and 3D.
339341
The table shows the available models and datasets for the object detection task and the respective scores. Each score links to the respective weight file.
340342
For the evaluation, the models were evaluated using the validation subset, according to KITTI's validation criteria. The models were trained for three classes (car, pedestrian and cyclist). The calculated values are the mean value over the mAP of all classes for all difficulty levels.
343+
For the Waymo dataset, the models were trained on three classes (pedestrian, vehicle, cyclist).
341344

342345

343-
| Model / Dataset | KITTI [BEV / 3D] @ 0.70|
344-
|--------------------|---------------|
345-
| PointPillars (tf) | [61.6 / 55.2](https://storage.googleapis.com/open3d-releases/model-zoo/pointpillars_kitti_202012221652utc.zip) |
346-
| PointPillars (torch) | [61.2 / 52.8](https://storage.googleapis.com/open3d-releases/model-zoo/pointpillars_kitti_202012221652utc.pth) |
347-
| PointRCNN (tf) | [78.2 / 65.9](https://storage.googleapis.com/open3d-releases/model-zoo/pointrcnn_kitti_202105071146utc.zip) |
348-
| PointRCNN (torch) | [78.2 / 65.9](https://storage.googleapis.com/open3d-releases/model-zoo/pointrcnn_kitti_202105071146utc.pth) |
346+
| Model / Dataset | KITTI [BEV / 3D] @ 0.70| Waymo (BEV / 3D) @ 0.50 |
347+
|--------------------|------------------------|------------------|
348+
| PointPillars (tf) | [61.6 / 55.2](https://storage.googleapis.com/open3d-releases/model-zoo/pointpillars_kitti_202012221652utc.zip) | - |
349+
| PointPillars (torch) | [61.2 / 52.8](https://storage.googleapis.com/open3d-releases/model-zoo/pointpillars_kitti_202012221652utc.pth) | avg: 61.01 / 48.30 \| [best: 61.47 / 57.55](https://storage.googleapis.com/open3d-releases/model-zoo/pointpillars_waymo_202211200158utc_seed2_gpu16.pth) [^wpp-train] |
350+
| PointRCNN (tf) | [78.2 / 65.9](https://storage.googleapis.com/open3d-releases/model-zoo/pointrcnn_kitti_202105071146utc.zip) | - |
351+
| PointRCNN (torch) | [78.2 / 65.9](https://storage.googleapis.com/open3d-releases/model-zoo/pointrcnn_kitti_202105071146utc.pth) | - |
349352

353+
[^wpp-train]: The avg. metrics are the average of three sets of training runs with 4, 8, 16 and 32 GPUs. Training was halted after 30 epochs. The model checkpoint is available for the best training run.
350354

351355
#### Training PointRCNN
352356

@@ -402,6 +406,7 @@ For downloading these datasets visit the respective webpages and have a look at
402406
* [Visualize custom data](docs/howtos.md#visualize-custom-data)
403407
* [Adding a new model](docs/howtos.md#adding-a-new-model)
404408
* [Adding a new dataset](docs/howtos.md#adding-a-new-dataset)
409+
* [Distributed training](docs/howtos.md#distributed-training)
405410
* [Visualize and compare input data, ground truth and results in TensorBoard](docs/tensorboard.md)
406411
* [Inference with Intel OpenVINO](docs/openvino.md)
407412

ci/run_ci.sh

Lines changed: 1 addition & 2 deletions
@@ -50,8 +50,7 @@ cmake -DBUNDLE_OPEN3D_ML=ON \
5050
-DGLIBCXX_USE_CXX11_ABI=OFF \
5151
-DBUILD_TENSORFLOW_OPS=ON \
5252
-DBUILD_PYTORCH_OPS=ON \
53-
-DBUILD_GUI=OFF \
54-
-DBUILD_RPC_INTERFACE=OFF \
53+
-DBUILD_GUI=ON \
5554
-DBUILD_UNIT_TESTS=OFF \
5655
-DBUILD_BENCHMARKS=OFF \
5756
-DBUILD_EXAMPLES=OFF \

docs/howtos.md

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,3 +243,18 @@ import open3d.ml.torch as ml3d
243243
model = ml3d.models.MyModel()
244244
dataset = ml3d.datasets.MyDataset()
245245
```
246+
247+
## Distributed training (preview)
248+
249+
Open3D-ML currently supports distributed training with PyTorch for object detection on Waymo with the PointPillars model. More comprehensive support for semantic segmentation models will follow shortly.
250+
251+
Distributed training uses the PyTorch Distributed Data Parallel (DDP) module and can be used to distribute training across multiple compute nodes, each with multiple GPUs. Here is a chart of per-epoch runtime showing the speedup of sample runs with an increasing number of GPUs. The training was run on a cluster containing 4 nodes with 8 RTX 3090 GPUs each.
252+
253+
- Dataset: Waymo v1.3
254+
- Model: PointPillars
255+
- GPU: RTX 3090
256+
- Batch size: 4 per GPU
257+
258+
![PointPillars training on Waymo per epoch training time with number of GPUs](https://user-images.githubusercontent.com/41028320/220750523-57075575-8cc7-4e40-99b0-a4e79995f1ec.png)
259+
260+
See [`scripts/train_scripts/pointpillars_waymo.sh`](../scripts/train_scripts/pointpillars_waymo.sh) for an example SLURM training script for distributed training on two nodes, using four GPUs on each node. The remaining configuration is read from the config file [`pointpillars_waymo.yml`](../ml3d/configs/pointpillars_waymo.yml).
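
As a rough illustration of the sample-run settings above (not code from this commit): with DDP, each process consumes its own batch per step and gradients are averaged across processes, so the effective global batch size grows with the number of GPUs. The helper name `effective_batch_size` and the way GPUs are split across nodes for the 4-, 8- and 16-GPU runs are assumptions; only the per-GPU batch size of 4 and the 4 x 8 cluster layout come from the text.

```python
PER_GPU_BATCH_SIZE = 4  # "Batch size: 4 per GPU" from the settings above


def effective_batch_size(nodes, gpus_per_node, per_gpu_batch=PER_GPU_BATCH_SIZE):
    """Samples processed per optimizer step across all DDP processes."""
    return nodes * gpus_per_node * per_gpu_batch


# Hypothetical node/GPU splits for the 4, 8, 16 and 32 GPU runs.
for nodes, gpus_per_node in [(1, 4), (1, 8), (2, 8), (4, 8)]:
    total_gpus = nodes * gpus_per_node
    print(f"{total_gpus:2d} GPUs -> effective batch size "
          f"{effective_batch_size(nodes, gpus_per_node)}")
```

Under these assumptions the 32-GPU run processes 128 samples per step, which is worth keeping in mind when comparing metrics across GPU counts.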

ml3d/tf/modules/pointnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def __init__(self):
9393
def call(self, xyz, features=None, new_xyz=None, training=True):
9494
r"""
9595
:param xyz: (B, N, 3) tensor of the xyz coordinates of the features
96-
:param features: (B, N, C) tensor of the descriptors of the the features
96+
:param features: (B, N, C) tensor of the descriptors of the features
9797
:param new_xyz:
9898
:return:
9999
new_xyz: (B, npoint, 3) tensor of the new features' xyz

ml3d/tf/utils/pointnet/pointnet2_modules.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def __init__(self):
1717
def call(self, xyz, features=None, new_xyz=None, training=True):
1818
r"""
1919
:param xyz: (B, N, 3) tensor of the xyz coordinates of the features
20-
:param features: (B, N, C) tensor of the descriptors of the the features
20+
:param features: (B, N, C) tensor of the descriptors of the features
2121
:param new_xyz:
2222
:return:
2323
new_xyz: (B, npoint, 3) tensor of the new features' xyz

ml3d/torch/modules/pointnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def forward(self,
122122
new_xyz=None) -> (torch.Tensor, torch.Tensor):
123123
r"""
124124
:param xyz: (B, N, 3) tensor of the xyz coordinates of the features
125-
:param features: (B, N, C) tensor of the descriptors of the the features
125+
:param features: (B, N, C) tensor of the descriptors of the features
126126
:param new_xyz:
127127
:return:
128128
new_xyz: (B, npoint, 3) tensor of the new features' xyz

ml3d/torch/utils/pointnet/pointnet2_modules.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def forward(self,
5050
r"""Forward.
5151
5252
:param xyz: (B, N, 3) tensor of the xyz coordinates of the features
53-
:param features: (B, N, C) tensor of the descriptors of the the features
53+
:param features: (B, N, C) tensor of the descriptors of the features
5454
:param new_xyz:
5555
:return:
5656
new_xyz: (B, npoint, 3) tensor of the new features' xyz

ml3d/utils/builder.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,19 @@ def build_network(cfg):
1414
return build(cfg, NETWORK)
1515

1616

17-
def convert_device_name(framework, device_ids):
17+
def convert_device_name(device_type, device_ids):
1818
"""Convert device to either cpu or cuda."""
1919
gpu_names = ["gpu", "cuda"]
2020
cpu_names = ["cpu"]
21-
if framework not in cpu_names + gpu_names:
21+
if device_type not in cpu_names + gpu_names:
2222
raise KeyError("the device should either "
23-
"be cuda or cpu but got {}".format(framework))
23+
"be cuda or cpu but got {}".format(device_type))
2424
assert type(device_ids) is list
2525
device_ids_new = []
2626
for device in device_ids:
2727
device_ids_new.append(int(device))
2828

29-
if framework in gpu_names:
29+
if device_type in gpu_names:
3030
return "cuda", device_ids_new
3131
else:
3232
return "cpu", device_ids_new
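
For reference, a minimal sketch of how the function reads after the rename, with indentation restored and a usage example appended. The function body follows the diff above; the example calls and their arguments are illustrative assumptions, not taken from the repository.

```python
def convert_device_name(device_type, device_ids):
    """Convert device to either cpu or cuda."""
    gpu_names = ["gpu", "cuda"]
    cpu_names = ["cpu"]
    if device_type not in cpu_names + gpu_names:
        raise KeyError("the device should either "
                       "be cuda or cpu but got {}".format(device_type))
    assert type(device_ids) is list
    device_ids_new = []
    for device in device_ids:
        device_ids_new.append(int(device))

    if device_type in gpu_names:
        return "cuda", device_ids_new
    else:
        return "cpu", device_ids_new


# Illustrative calls: "gpu" is normalized to "cuda" and string ids become ints.
print(convert_device_name("gpu", ["0", "1"]))  # ('cuda', [0, 1])
print(convert_device_name("cpu", []))          # ('cpu', [])
```

The first argument names a device type rather than a framework, which is what the rename from `framework` to `device_type` makes explicit.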

scripts/run_pipeline.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
1-
import numpy as np
1+
import os
22
import argparse
33
import logging
44
import sys
5-
import yaml
5+
from pathlib import Path
66
import pprint
7-
import os
7+
import yaml
8+
import numpy as np
89
import torch.distributed as dist
910
from torch import multiprocessing
1011

11-
from pathlib import Path
12+
import open3d.ml as _ml3d
1213

1314

1415
def parse_args():
@@ -42,7 +43,14 @@ def parse_args():
4243
parser.add_argument('--batch_size', help='batch size', default=None)
4344
parser.add_argument('--main_log_dir',
4445
help='the dir to save logs and models')
45-
parser.add_argument('--seed', help='random seed', default=0)
46+
parser.add_argument('--seed', help='random seed', default=0, type=int)
47+
parser.add_argument('--nodes', help='number of nodes', default=1, type=int)
48+
parser.add_argument('--node_rank',
49+
help='ranking within the nodes, default: 0. To get from'
50+
' the environment, enter the name of an env var eg: '
51+
'"SLURM_NODEID".',
52+
default="0",
53+
type=str)
4654
parser.add_argument(
4755
'--host',
4856
help='Host for distributed training, default: localhost',
@@ -57,6 +65,10 @@ def parse_args():
5765
default='gloo')
5866

5967
args, unknown = parser.parse_known_args()
68+
try:
69+
args.node_rank = int(args.node_rank)
70+
except ValueError: # str => get from environment
71+
args.node_rank = int(os.environ[args.node_rank])
6072

6173
parser_extra = argparse.ArgumentParser(description='Extra arguments')
6274
for arg in unknown:
@@ -73,9 +85,6 @@ def parse_args():
7385
return args, vars(args_extra)
7486

7587

76-
import open3d.ml as _ml3d
77-
78-
7988
def main():
8089
cmd_line = ' '.join(sys.argv[:])
8190
args, extra_dict = parse_args()
@@ -103,6 +112,10 @@ def main():
103112
if device == 'cpu':
104113
tf.config.set_visible_devices([], 'GPU')
105114
elif device == 'cuda':
115+
if len(args.device_ids) > 1:
116+
raise NotImplementedError(
117+
"Multi-GPU training with TensorFlow is not yet implemented."
118+
)
106119
tf.config.set_visible_devices(gpus[0], 'GPU')
107120
else:
108121
idx = device.split(':')[1]
@@ -152,8 +165,8 @@ def main():
152165
cfg_dict_model['seed'] = rng
153166
cfg_dict_pipeline['seed'] = rng
154167

155-
with open(Path(__file__).parent / 'README.md', 'r') as f:
156-
readme = f.read()
168+
with open(Path(__file__).parent / 'README.md', 'r') as freadme:
169+
readme = freadme.read()
157170

158171
cfg_tb = {
159172
'readme': readme,
@@ -197,9 +210,10 @@ def cleanup():
197210
dist.destroy_process_group()
198211

199212

200-
def main_worker(rank, Dataset, Model, Pipeline, cfg_dict_dataset,
213+
def main_worker(local_rank, Dataset, Model, Pipeline, cfg_dict_dataset,
201214
cfg_dict_model, cfg_dict_pipeline, args):
202-
world_size = len(args.device_ids)
215+
rank = args.node_rank * len(args.device_ids) + local_rank
216+
world_size = args.nodes * len(args.device_ids)
203217
setup(rank, world_size, args)
204218

205219
cfg_dict_dataset['rank'] = rank
@@ -211,8 +225,10 @@ def main_worker(rank, Dataset, Model, Pipeline, cfg_dict_dataset,
211225
cfg_dict_model['seed'] = rng
212226
cfg_dict_pipeline['seed'] = rng
213227

214-
device = f"cuda:{args.device_ids[rank]}"
215-
print(f"rank = {rank}, world_size = {world_size}, gpu = {device}")
228+
device = f"cuda:{args.device_ids[local_rank]}"
229+
print(
230+
f"local_rank = {local_rank}, rank = {rank}, world_size = {world_size},"
231+
f" gpu = {device}")
216232

217233
cfg_dict_model['device'] = device
218234
cfg_dict_pipeline['device'] = device
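
The multi-node rank arithmetic introduced in this file can be sketched in isolation. The helper names `resolve_node_rank`, `global_rank` and `world_size` are hypothetical and do not exist in `run_pipeline.py`; they only mirror the `--node_rank` parsing and the `rank` / `world_size` expressions in the diff above. The environment is faked with a plain dict so the example runs without SLURM.

```python
import os


def resolve_node_rank(value, environ=None):
    """Return an integer node rank from a literal or an env var name."""
    environ = os.environ if environ is None else environ
    try:
        return int(value)
    except ValueError:  # not a number, so treat it as an env var name
        return int(environ[value])


def global_rank(node_rank, gpus_per_node, local_rank):
    """Mirror of: rank = args.node_rank * len(args.device_ids) + local_rank."""
    return node_rank * gpus_per_node + local_rank


def world_size(nodes, gpus_per_node):
    """Mirror of: world_size = args.nodes * len(args.device_ids)."""
    return nodes * gpus_per_node


# Two nodes with four GPUs each; this process group runs on the second node.
fake_env = {"SLURM_NODEID": "1"}  # stand-in for the real environment
node_rank = resolve_node_rank("SLURM_NODEID", fake_env)
for local_rank in range(4):
    print(f"local_rank = {local_rank}, "
          f"rank = {global_rank(node_rank, 4, local_rank)}, "
          f"world_size = {world_size(2, 4)}")
```

On the second node this yields global ranks 4 to 7 out of a world size of 8, while `local_rank` still indexes into that node's own `device_ids` list. Note that the formula assumes every node uses the same number of GPUs.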

scripts/train_scripts/kpconv_kitti.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
#SBATCH --gres=gpu:1
55

66
if [ "$#" -ne 2 ]; then
7-
echo "Please, provide the the training framework: torch/tf and dataset path"
7+
echo "Please, provide the training framework: torch/tf and dataset path"
88
exit 1
99
fi
1010
