
Commit 1fb2453

feat: add ssd detector (#704)
1 parent 94ef850 commit 1fb2453

15 files changed (+2580, -133 lines)

examples/det/ssd/README.md

Lines changed: 143 additions & 0 deletions
@@ -0,0 +1,143 @@
# SSD Based on MindCV Backbones

> [SSD: Single Shot MultiBox Detector](https://arxiv.org/abs/1512.02325)

## Introduction

SSD is a single-stage object detector. It discretizes the output space of bounding boxes into a set of default boxes over different aspect ratios and scales per feature map location, and combines predictions from multi-scale feature maps to detect objects of various sizes. At prediction time, SSD generates scores for the presence of each object category in each default box and produces adjustments to the box to better match the object shape.
<p align="center">
  <img src="https://github.com/DexterJZ/mindcv/assets/16130861/50bc9627-c71c-4b1a-9de4-9e6040a43279" width=800 />
</p>
<p align="center">
  <em>Figure 1. Architecture of SSD [<a href="#references">1</a>]</em>
</p>
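To make the default-box formulation above concrete, here is a small, self-contained sketch (normalized coordinates, a single square feature map). It follows the recipe from the paper; the actual anchor settings used by this example are defined in its config files and may differ.
```
import itertools

import numpy as np


def default_boxes(fmap_size, scale, next_scale, aspect_ratios=(1.0, 2.0, 0.5)):
    """Generate (cx, cy, w, h) default boxes for one square feature map."""
    boxes = []
    for i, j in itertools.product(range(fmap_size), repeat=2):
        # box center, normalized to [0, 1]
        cx, cy = (j + 0.5) / fmap_size, (i + 0.5) / fmap_size
        # one box per aspect ratio at the current scale
        for ar in aspect_ratios:
            boxes.append([cx, cy, scale * np.sqrt(ar), scale / np.sqrt(ar)])
        # extra square box with an intermediate scale, as in the paper
        s_extra = np.sqrt(scale * next_scale)
        boxes.append([cx, cy, s_extra, s_extra])
    return np.clip(np.array(boxes), 0.0, 1.0)


# e.g. a 3x3 feature map with 4 default boxes per location -> 36 boxes
print(default_boxes(fmap_size=3, scale=0.2, next_scale=0.37).shape)  # (36, 4)
```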
In this example, by leveraging [the multi-scale feature extraction of MindCV](https://github.com/mindspore-lab/mindcv/blob/main/docs/en/how_to_guides/feature_extraction.md), we demonstrate that using backbones from MindCV greatly simplifies the implementation of SSD.
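As a rough illustration of that mechanism, the sketch below builds a MindCV backbone that returns intermediate feature maps instead of classification logits. It assumes the `features_only` and `out_indices` options described in the linked guide; the backbone name and stage indices are placeholders rather than the exact settings of the SSD configs here.
```
import numpy as np
import mindspore as ms

from mindcv.models import create_model

# Backbone that outputs a pyramid of feature maps (features_only=True).
# "resnet50" and out_indices=[2, 3, 4] are illustrative choices only.
backbone = create_model("resnet50", pretrained=False, features_only=True, out_indices=[2, 3, 4])

x = ms.Tensor(np.random.rand(1, 3, 640, 640), ms.float32)
for feat in backbone(x):
    print(feat.shape)  # multi-scale feature maps with decreasing spatial size
```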
## Configurations

Here, we provide three configurations of SSD:

* Using [MobileNetV2](https://github.com/mindspore-lab/mindcv/tree/main/configs/mobilenetv2) as the backbone with the original detector described in the paper.
* Using [ResNet50](https://github.com/mindspore-lab/mindcv/tree/main/configs/resnet) as the backbone with an FPN and a shared-weight-based detector.
* Using [MobileNetV3](https://github.com/mindspore-lab/mindcv/tree/main/configs/mobilenetv3) as the backbone with the original detector described in the paper.
## Dataset

We train and test SSD on the [COCO 2017 Dataset](https://cocodataset.org/#download). The dataset contains
* 118,000 images (about 18 GB) for training, and
* 5,000 images (about 1 GB) for testing.
## Quick Start

### Preparation

1. Clone the MindCV repository by running
   ```
   git clone https://github.com/mindspore-lab/mindcv.git
   ```

2. Install dependencies as shown [here](https://mindspore-lab.github.io/mindcv/installation/).
3. Download the [COCO 2017 Dataset](https://cocodataset.org/#download) and organize it as follows.
   ```
   .
   └─cocodataset
     ├─annotations
       ├─instances_train2017.json
       └─instances_val2017.json
     ├─val2017
     └─train2017
   ```
   Run the following commands to preprocess the dataset and convert it to [MindRecord format](https://www.mindspore.cn/docs/zh-CN/master/api_python/mindspore.mindrecord.html), which reduces preprocessing time during training and testing.
   ```
   cd mindcv # change directory to the root of MindCV repository
   python examples/det/ssd/create_data.py coco --data_path [root of COCO 2017 Dataset] --out_path [directory for storing MindRecord files]
   ```
   Then specify the path to the preprocessed dataset at keyword `data_dir` in the config file. A short sketch for inspecting the generated MindRecord files follows these preparation steps.
4. Download the pretrained backbone weights from the table below, and specify the path to the backbone weights at keyword `backbone_ckpt_path` in the config file.

   <div align="center">

   | MobileNetV2 | ResNet50 | MobileNetV3 |
   |:----------------:|:----------------:|:----------------:|
   | [backbone weights](https://download.mindspore.cn/toolkits/mindcv/mobilenet/mobilenetv2/mobilenet_v2_100-d5532038.ckpt) | [backbone weights](https://download.mindspore.cn/toolkits/mindcv/resnet/resnet50-e0733ab8.ckpt) | [backbone weights](https://download.mindspore.cn/toolkits/mindcv/mobilenet/mobilenetv3/mobilenet_v3_large_100-1279ad5f.ckpt) |

   </div>
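Before training, it can be helpful to sanity-check the MindRecord files produced in step 3. The sketch below is illustrative only: the file name is a placeholder for whatever `create_data.py` wrote to your `--out_path`, and the record columns are defined by that script.
```
import mindspore.dataset as ds

# Placeholder path; point this at one of the generated MindRecord files.
mindrecord_file = "/path/to/mindrecord/ssd.mindrecord0"

dataset = ds.MindDataset(mindrecord_file, shuffle=False)
print("number of samples:", dataset.get_dataset_size())

# Print the columns of one record to see what create_data.py stored.
for record in dataset.create_dict_iterator(output_numpy=True, num_epochs=1):
    print({name: getattr(value, "shape", value) for name, value in record.items()})
    break
```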
### Train

It is highly recommended to use **distributed training** for this SSD implementation.

For distributed training using **OpenMPI's `mpirun`**, simply run
```
cd mindcv # change directory to the root of MindCV repository
mpirun -n [# of devices] python examples/det/ssd/train.py --config [the path to the config file]
```
For example, to train SSD in distributed mode with the `MobileNetV2` configuration on 8 devices, run
```
cd mindcv # change directory to the root of MindCV repository
mpirun -n 8 python examples/det/ssd/train.py --config examples/det/ssd/ssd_mobilenetv2.yaml
```
For distributed training with [Ascend rank table](https://github.com/mindspore-lab/mindocr/blob/main/docs/en/tutorials/distribute_train.md#12-configure-rank_table_file-for-training), configure `ascend8p.sh` as follows
```
#!/bin/bash
export DEVICE_NUM=8
export RANK_SIZE=8
export RANK_TABLE_FILE="./hccl_8p_01234567_127.0.0.1.json"

for ((i = 0; i < ${DEVICE_NUM}; i++)); do
    export DEVICE_ID=$i
    export RANK_ID=$i
    echo "Launching rank: ${RANK_ID}, device: ${DEVICE_ID}"
    if [ $i -eq 0 ]; then
        echo 'i am 0'
        python examples/det/ssd/train.py --config [the path to the config file] &> ./train.log &
    else
        echo 'not 0'
        python -u examples/det/ssd/train.py --config [the path to the config file] &> /dev/null &
    fi
done
```
and start training by running
```
cd mindcv # change directory to the root of MindCV repository
bash ascend8p.sh
```
For single-device training, please run
```
cd mindcv # change directory to the root of MindCV repository
python examples/det/ssd/train.py --config [the path to the config file]
```
### Test

To test the trained model, first specify the path to the model checkpoint at keyword `ckpt_path` in the config file, then run
```
cd mindcv # change directory to the root of MindCV repository
python examples/det/ssd/eval.py --config [the path to the config file]
```
For example, to test SSD with the `MobileNetV2` configuration, run
```
cd mindcv # change directory to the root of MindCV repository
python examples/det/ssd/eval.py --config examples/det/ssd/ssd_mobilenetv2.yaml
```
## Performance

Here are the performance results and the pretrained model weights for each configuration.

<div align="center">

| Configuration | Mixed Precision | mAP | Config | Download |
|:-----------------:|:---------------:|:----:|:------:|:--------:|
| MobileNetV2 | O2 | 23.2 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/examples/det/ssd/ssd_mobilenetv2.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/ssd/ssd_mobilenetv2-5bbd7411.ckpt) |
| ResNet50 with FPN | O3 | 38.3 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/examples/det/ssd/ssd_resnet50_fpn.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/ssd/ssd_resnet50_fpn-ac87ddac.ckpt) |
| MobileNetV3 | O2 | 23.8 | [yaml](https://github.com/mindspore-lab/mindcv/blob/main/examples/det/ssd/ssd_mobilenetv3.yaml) | [weights](https://download.mindspore.cn/toolkits/mindcv/ssd/ssd_mobilenetv3-53d9f6e9.ckpt) |

</div>
## References

[1] Liu, W., Anguelov, D., Erhan, D., Szegedy, C., Reed, S., Fu, C. Y., & Berg, A. C. (2016). SSD: Single Shot MultiBox Detector. In Computer Vision–ECCV 2016: 14th European Conference, Amsterdam, The Netherlands, October 11–14, 2016, Proceedings, Part I 14 (pp. 21-37). Springer International Publishing.

examples/det/ssd/callbacks.py

Lines changed: 130 additions & 0 deletions
@@ -0,0 +1,130 @@
import os
import stat

from utils import apply_eval

from mindspore import log as logger
from mindspore import save_checkpoint
from mindspore.train.callback import Callback, CheckpointConfig, LossMonitor, ModelCheckpoint, TimeMonitor


class EvalCallBack(Callback):
    """
    Evaluation callback used during training.

    Args:
        eval_function (function): evaluation function.
        eval_param_dict (dict): configuration dict of evaluation parameters.
        interval (int): run evaluation interval, default is 1.
        eval_start_epoch (int): evaluation start epoch, default is 1.
        save_best_ckpt (bool): whether to save the best checkpoint, default is True.
        ckpt_directory (str): directory to save the best checkpoint, default is `./`.
        best_ckpt_name (str): best checkpoint name, default is `best.ckpt`.
        metrics_name (str): evaluation metrics name, default is `acc`.

    Returns:
        None

    Examples:
        >>> EvalCallBack(eval_function, eval_param_dict)
    """

    def __init__(
        self,
        eval_function,
        eval_param_dict,
        interval=1,
        eval_start_epoch=1,
        save_best_ckpt=True,
        ckpt_directory="./",
        best_ckpt_name="best.ckpt",
        metrics_name="acc",
    ):
        super(EvalCallBack, self).__init__()
        self.eval_function = eval_function
        self.eval_param_dict = eval_param_dict
        self.eval_start_epoch = eval_start_epoch

        if interval < 1:
            raise ValueError("interval should be >= 1.")

        self.interval = interval
        self.save_best_ckpt = save_best_ckpt
        self.best_res = 0
        self.best_epoch = 0

        if not os.path.isdir(ckpt_directory):
            os.makedirs(ckpt_directory)

        self.best_ckpt_path = os.path.join(ckpt_directory, best_ckpt_name)
        self.metrics_name = metrics_name

    def remove_ckpoint_file(self, file_name):
        """Remove the specified checkpoint file from this checkpoint manager and also from the directory."""
        try:
            os.chmod(file_name, stat.S_IWRITE)
            os.remove(file_name)
        except OSError:
            logger.warning("OSError, failed to remove the older ckpt file %s.", file_name)
        except ValueError:
            logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name)

    def on_train_epoch_end(self, run_context):
        """Evaluate at the end of a training epoch and keep the best checkpoint."""
        cb_params = run_context.original_args()
        cur_epoch = cb_params.cur_epoch_num

        if cur_epoch >= self.eval_start_epoch and (cur_epoch - self.eval_start_epoch) % self.interval == 0:
            res = self.eval_function(self.eval_param_dict)
            print("epoch: {}, {}: {}".format(cur_epoch, self.metrics_name, res), flush=True)

            if res >= self.best_res:
                self.best_res = res
                self.best_epoch = cur_epoch
                print("update best result: {}".format(res), flush=True)

                if self.save_best_ckpt:
                    if os.path.exists(self.best_ckpt_path):
                        self.remove_ckpoint_file(self.best_ckpt_path)

                    save_checkpoint(cb_params.train_network, self.best_ckpt_path)
                    print("update best checkpoint at: {}".format(self.best_ckpt_path), flush=True)

    def on_train_end(self, run_context):
        print(
            "End training, the best {0} is: {1}, the best {0} epoch is {2}".format(
                self.metrics_name, self.best_res, self.best_epoch
            ),
            flush=True,
        )


def get_ssd_callbacks(args, steps_per_epoch, rank_id):
    ckpt_config = CheckpointConfig(keep_checkpoint_max=args.keep_checkpoint_max)
    ckpt_cb = ModelCheckpoint(prefix="ssd", directory=args.ckpt_save_dir, config=ckpt_config)

    if rank_id == 0:
        return [TimeMonitor(data_size=steps_per_epoch), LossMonitor(), ckpt_cb]

    return [TimeMonitor(data_size=steps_per_epoch), LossMonitor()]


def get_ssd_eval_callback(eval_net, eval_dataset, args):
    if args.dataset == "coco":
        anno_json = os.path.join(args.data_dir, "annotations/instances_val2017.json")
    else:
        raise NotImplementedError

    eval_param_dict = {"net": eval_net, "dataset": eval_dataset, "anno_json": anno_json, "args": args}

    eval_cb = EvalCallBack(
        apply_eval,
        eval_param_dict,
        interval=args.eval_interval,
        eval_start_epoch=args.eval_start_epoch,
        save_best_ckpt=True,
        ckpt_directory=args.ckpt_save_dir,
        best_ckpt_name="best.ckpt",
        metrics_name="mAP",
    )

    return eval_cb
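For context, here is a hypothetical usage sketch (not taken from `train.py`) showing how the two helpers above could be wired into MindSpore's high-level `Model.train` loop; `train_net`, `eval_net`, the dataset objects, and the `args` fields are assumptions of this sketch.
```
from mindspore import Model

from callbacks import get_ssd_callbacks, get_ssd_eval_callback


def fit(args, train_net, eval_net, train_dataset, eval_dataset, rank_id=0):
    # Number of batches per epoch; passed to TimeMonitor via get_ssd_callbacks.
    steps_per_epoch = train_dataset.get_dataset_size()

    # TimeMonitor/LossMonitor on every rank, periodic checkpointing on rank 0 only.
    callbacks = get_ssd_callbacks(args, steps_per_epoch, rank_id)

    # Periodic COCO evaluation that keeps the best checkpoint (metrics_name="mAP");
    # run it on a single rank in this sketch.
    if rank_id == 0:
        callbacks.append(get_ssd_eval_callback(eval_net, eval_dataset, args))

    model = Model(train_net)
    model.train(args.epochs, train_dataset, callbacks=callbacks, dataset_sink_mode=False)
```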
