
Commit 9f763ed

Add ResNet-18 and ResNet-50 backbones, and add VOC 2007 Cat Dog dataset
1 parent eb11069 commit 9f763ed

File tree

5 files changed: +274 -29 lines changed


README.md

Lines changed: 175 additions & 27 deletions
@@ -22,8 +22,8 @@ An easy implementation of Faster R-CNN in PyTorch.
 
 * PASCAL VOC 2007
 
-    * Train: 2007 trainval (5011 samples)
-    * Eval: 2007 test (4952 samples)
+    * Train: 2007 trainval (5011 images)
+    * Eval: 2007 test (4952 images)
 
 <table>
 <tr>
@@ -168,6 +168,52 @@ An easy implementation of Faster R-CNN in PyTorch.
 <td>0.1</td>
 <td>70000</td>
 </tr>
+<tr>
+<td>Ours</td>
+<td>ResNet-18</td>
+<td>GTX 1080 Ti</td>
+<td>~ 19.4</td>
+<td>~ 38.7</td>
+<td>0.6783</td>
+<td>600</td>
+<td>1000</td>
+<td>[(1, 2), (1, 1), (2, 1)]</td>
+<td>[128, 256, 512]</td>
+<td>align</td>
+<td>12000</td>
+<td>2000</td>
+<td>6000</td>
+<td>300</td>
+<td>0.001</td>
+<td>0.9</td>
+<td>0.0005</td>
+<td>50000</td>
+<td>0.1</td>
+<td>70000</td>
+</tr>
+<tr>
+<td>Ours</td>
+<td>ResNet-50</td>
+<td>GTX 1080 Ti</td>
+<td>~ 8.7</td>
+<td>~ 22.4</td>
+<td>0.7402</td>
+<td>600</td>
+<td>1000</td>
+<td>[(1, 2), (1, 1), (2, 1)]</td>
+<td>[128, 256, 512]</td>
+<td>align</td>
+<td>12000</td>
+<td>2000</td>
+<td>6000</td>
+<td>300</td>
+<td>0.001</td>
+<td>0.9</td>
+<td>0.0005</td>
+<td>50000</td>
+<td>0.1</td>
+<td>70000</td>
+</tr>
 <tr>
 <td>ruotianluo/pytorch-faster-rcnn</td>
 <td>ResNet-101</td>
@@ -222,7 +268,7 @@ An easy implementation of Faster R-CNN in PyTorch.
 </td>
 <td>ResNet-101</td>
 <td>GTX 1080 Ti</td>
-<td>~ 6.3</td>
+<td>5 ~ 6</td>
 <td>~ 11.8</td>
 <td>0.7538</td>
 <td>600</td>
@@ -247,8 +293,8 @@ An easy implementation of Faster R-CNN in PyTorch.
 
 * MS COCO 2017
 
-    * Train: 2017 Train = 2015 Train + 2015 Val - 2015 Val Sample 5k (117266 samples)
-    * Eval: 2017 Val = 2015 Val Sample 5k (formerly known as `minival`) (4952 samples)
+    * Train: 2017 Train = 2015 Train + 2015 Val - 2015 Val Sample 5k (117266 images)
+    * Eval: 2017 Val = 2015 Val Sample 5k (formerly known as `minival`) (4952 images)
 
 <table>
 <tr>
@@ -331,21 +377,21 @@ An easy implementation of Faster R-CNN in PyTorch.
 <td>~ 5.1</td>
 <td>~ 8.9</td>
 <td>0.287</td>
-<td>800</td>
-<td>1333</td>
+<td><b>800</b></td>
+<td><b>1333</b></td>
 <td>[(1, 2), (1, 1), (2, 1)]</td>
-<td>[64, 128, 256, 512]</td>
+<td><b>[64, 128, 256, 512]</b></td>
 <td>align</td>
 <td>12000</td>
 <td>2000</td>
 <td>6000</td>
-<td>1000</td>
+<td><b>1000</b></td>
 <td>0.001</td>
 <td>0.9</td>
-<td>0.0001</td>
-<td>900000</td>
+<td><b>0.0001</b></td>
+<td><b>900000</b></td>
 <td>0.1</td>
-<td>1200000</td>
+<td><b>1200000</b></td>
 </tr>
 <tr>
 <td>ruotianluo/pytorch-faster-rcnn</td>
@@ -404,21 +450,21 @@ An easy implementation of Faster R-CNN in PyTorch.
 <td>~ 4.7</td>
 <td>~ 7.8</td>
 <td>0.352</td>
-<td>800</td>
-<td>1333</td>
+<td><b>800</b></td>
+<td><b>1333</b></td>
 <td>[(1, 2), (1, 1), (2, 1)]</td>
-<td>[64, 128, 256, 512]</td>
+<td><b>[64, 128, 256, 512]</b></td>
 <td>align</td>
 <td>12000</td>
 <td>2000</td>
 <td>6000</td>
-<td>1000</td>
+<td><b>1000</b></td>
 <td>0.001</td>
 <td>0.9</td>
-<td>0.0001</td>
-<td>900000</td>
+<td><b>0.0001</b></td>
+<td><b>900000</b></td>
 <td>0.1</td>
-<td>1200000</td>
+<td><b>1200000</b></td>
 </tr>
 <tr>
 <td>
@@ -431,26 +477,128 @@ An easy implementation of Faster R-CNN in PyTorch.
 <td>~ 4.5</td>
 <td>~ 7.5</td>
 <td>0.358</td>
-<td>800</td>
-<td>1333</td>
+<td><b>800</b></td>
+<td><b>1333</b></td>
 <td>[(1, 2), (1, 1), (2, 1)]</td>
-<td>[32, 64, 128, 256, 512]</td>
+<td><b>[32, 64, 128, 256, 512]</b></td>
 <td>align</td>
 <td>12000</td>
 <td>2000</td>
 <td>6000</td>
-<td>1000</td>
+<td><b>1000</b></td>
 <td>0.001</td>
 <td>0.9</td>
-<td>0.0001</td>
-<td>900000</td>
+<td><b>0.0001</b></td>
+<td><b>900000</b></td>
 <td>0.1</td>
-<td>1200000</td>
+<td><b>1200000</b></td>
 </tr>
 </table>
 
 > Scroll to right for more configurations
-
+
+
+* PASCAL VOC 2007 Cat Dog
+
+    * Train: 2007 trainval drops categories other than cat and dog (750 images)
+    * Eval: 2007 test drops categories other than cat and dog (728 images)
+
+<table>
+<tr>
+<th>Implementation</th>
+<th>Backbone</th>
+<th>GPU</th>
+<th>Training Speed (FPS)</th>
+<th>Inference Speed (FPS)</th>
+<th>mAP</th>
+<th>image_min_side</th>
+<th>image_max_side</th>
+<th>anchor_ratios</th>
+<th>anchor_sizes</th>
+<th>pooling_mode</th>
+<th>train_pre_rpn_nms_top_n</th>
+<th>train_post_rpn_nms_top_n</th>
+<th>eval_pre_rpn_nms_top_n</th>
+<th>eval_post_rpn_nms_top_n</th>
+<th>learning_rate</th>
+<th>momentum</th>
+<th>weight_decay</th>
+<th>step_lr_size</th>
+<th>step_lr_gamma</th>
+<th>num_steps_to_finish</th>
+</tr>
+<tr>
+<td>Ours</td>
+<td>ResNet-18</td>
+<td>GTX 1080 Ti</td>
+<td>~ 19.4</td>
+<td>~ 56.2</td>
+<td>0.3776</td>
+<td>600</td>
+<td>1000</td>
+<td>[(1, 2), (1, 1), (2, 1)]</td>
+<td>[128, 256, 512]</td>
+<td>align</td>
+<td>12000</td>
+<td>2000</td>
+<td>6000</td>
+<td>300</td>
+<td>0.001</td>
+<td>0.9</td>
+<td>0.0005</td>
+<td><b>700</b></td>
+<td>0.1</td>
+<td><b>1000</b></td>
+</tr>
+<tr>
+<td>Ours</td>
+<td>ResNet-18</td>
+<td>GTX 1080 Ti</td>
+<td>~ 19.4</td>
+<td>~ 56.2</td>
+<td>0.6175</td>
+<td>600</td>
+<td>1000</td>
+<td>[(1, 2), (1, 1), (2, 1)]</td>
+<td>[128, 256, 512]</td>
+<td>align</td>
+<td>12000</td>
+<td>2000</td>
+<td>6000</td>
+<td>300</td>
+<td>0.001</td>
+<td>0.9</td>
+<td>0.0005</td>
+<td><b>2000</b></td>
+<td>0.1</td>
+<td><b>3000</b></td>
+</tr>
+<tr>
+<td>Ours</td>
+<td>ResNet-18</td>
+<td>GTX 1080 Ti</td>
+<td>~ 19.4</td>
+<td>~ 56.2</td>
+<td>0.7639</td>
+<td>600</td>
+<td>1000</td>
+<td>[(1, 2), (1, 1), (2, 1)]</td>
+<td>[128, 256, 512]</td>
+<td>align</td>
+<td>12000</td>
+<td>2000</td>
+<td>6000</td>
+<td>300</td>
+<td>0.001</td>
+<td>0.9</td>
+<td>0.0005</td>
+<td><b>7000</b></td>
+<td>0.1</td>
+<td><b>10000</b></td>
+</tr>
+</table>
+
+> Scroll to right for more configurations
+
 
 ## Requirements
 
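The PASCAL VOC 2007 Cat Dog split above keeps only the images whose annotations contain at least one cat or dog; the dataset code itself is among this commit's changed files but is not shown in this excerpt. Purely as an illustration (assuming the standard VOCdevkit layout; filter_cat_dog is a hypothetical helper, not part of the commit), such a split can be derived from the VOC annotation files like this:

import os
import xml.etree.ElementTree as ET


def filter_cat_dog(voc_root: str, split: str = 'trainval'):
    # Collect the image ids in the split whose annotation lists a cat or dog object.
    with open(os.path.join(voc_root, 'ImageSets', 'Main', split + '.txt')) as f:
        image_ids = [line.strip() for line in f if line.strip()]

    kept = []
    for image_id in image_ids:
        tree = ET.parse(os.path.join(voc_root, 'Annotations', image_id + '.xml'))
        names = {obj.findtext('name') for obj in tree.findall('object')}
        if names & {'cat', 'dog'}:
            kept.append(image_id)
    return kept

Applied to VOCdevkit/VOC2007, the trainval and test splits reduce to the 750 and 728 images reported in the table above.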
backbone/base.py

Lines changed: 7 additions & 1 deletion
@@ -5,13 +5,19 @@
 
 class Base(object):
 
-    OPTIONS = ['vgg16', 'resnet101']
+    OPTIONS = ['vgg16', 'resnet18', 'resnet50', 'resnet101']
 
     @staticmethod
     def from_name(name: str) -> Type['Base']:
         if name == 'vgg16':
             from backbone.vgg16 import Vgg16
             return Vgg16
+        elif name == 'resnet18':
+            from backbone.resnet18 import ResNet18
+            return ResNet18
+        elif name == 'resnet50':
+            from backbone.resnet50 import ResNet50
+            return ResNet50
         elif name == 'resnet101':
             from backbone.resnet101 import ResNet101
             return ResNet101
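With the two new branches, Base.from_name covers every entry in OPTIONS. A minimal usage sketch (illustrative only, not taken from the commit; it relies only on the constructor and features() signatures shown in this diff):

from backbone.base import Base

backbone_class = Base.from_name('resnet18')    # any of 'vgg16', 'resnet18', 'resnet50', 'resnet101'
backbone = backbone_class(pretrained=True)     # the flag is forwarded to torchvision's model constructor
features, pool_handler, hidden, hidden_handler, num_features_out, num_hidden_out = backbone.features()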

backbone/resnet18.py

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
from typing import Tuple, Callable

import torchvision
from torch import nn, Tensor
from torch.nn import functional as F

import backbone.base


class ResNet18(backbone.base.Base):

    def __init__(self, pretrained: bool):
        super().__init__(pretrained)

    def features(self) -> Tuple[nn.Module, Callable[[Tensor], Tensor], nn.Module, Callable[[Tensor], Tensor], int, int]:
        resnet18 = torchvision.models.resnet18(pretrained=self._pretrained)

        # list(resnet18.children()) consists of the following modules:
        # [0] = Conv2d, [1] = BatchNorm2d, [2] = ReLU, [3] = MaxPool2d,
        # [4] = Sequential(BasicBlock...), [5] = Sequential(BasicBlock...),
        # [6] = Sequential(BasicBlock...), [7] = Sequential(BasicBlock...),
        # [8] = AvgPool2d, [9] = Linear
        children = list(resnet18.children())
        features = children[:-3]
        num_features_out = 256

        hidden = children[-3]
        num_hidden_out = 512

        for parameters in [feature.parameters() for i, feature in enumerate(features) if i <= 4]:
            for parameter in parameters:
                parameter.requires_grad = False

        features = nn.Sequential(*features)

        return features, self.pool_handler, hidden, self.hidden_handler, num_features_out, num_hidden_out

    def pool_handler(self, pool: Tensor) -> Tensor:
        return pool

    def hidden_handler(self, hidden: Tensor) -> Tensor:
        hidden = F.adaptive_max_pool2d(input=hidden, output_size=1)
        hidden = hidden.view(hidden.shape[0], -1)
        return hidden
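To make the split concrete: children[:-3] keeps everything through layer3 (the stride-16, 256-channel trunk used as the shared feature map), while children[-3] is layer4, applied per RoI and reduced to a 512-d vector by hidden_handler. A quick shape check (illustrative only; it assumes Base.__init__ simply stores the flag as self._pretrained, which is what features() reads):

import torch
from backbone.resnet18 import ResNet18

backbone = ResNet18(pretrained=False)
features, pool_handler, hidden, hidden_handler, num_features_out, num_hidden_out = backbone.features()

image = torch.randn(1, 3, 600, 1000)
feature_map = features(image)                         # (1, 256, 38, 63): roughly stride-16, 256 channels
rois = torch.randn(8, 256, 7, 7)                      # stand-in for 8 RoI-aligned crops
vectors = hidden_handler(hidden(pool_handler(rois)))  # (8, 512): one 512-d vector per RoI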

backbone/resnet50.py

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
from typing import Tuple, Callable

import torchvision
from torch import nn, Tensor
from torch.nn import functional as F

import backbone.base


class ResNet50(backbone.base.Base):

    def __init__(self, pretrained: bool):
        super().__init__(pretrained)

    def features(self) -> Tuple[nn.Module, Callable[[Tensor], Tensor], nn.Module, Callable[[Tensor], Tensor], int, int]:
        resnet50 = torchvision.models.resnet50(pretrained=self._pretrained)

        # list(resnet50.children()) consists of the following modules:
        # [0] = Conv2d, [1] = BatchNorm2d, [2] = ReLU, [3] = MaxPool2d,
        # [4] = Sequential(Bottleneck...), [5] = Sequential(Bottleneck...),
        # [6] = Sequential(Bottleneck...), [7] = Sequential(Bottleneck...),
        # [8] = AvgPool2d, [9] = Linear
        children = list(resnet50.children())
        features = children[:-3]
        num_features_out = 1024

        hidden = children[-3]
        num_hidden_out = 2048

        for parameters in [feature.parameters() for i, feature in enumerate(features) if i <= 4]:
            for parameter in parameters:
                parameter.requires_grad = False

        features = nn.Sequential(*features)

        return features, self.pool_handler, hidden, self.hidden_handler, num_features_out, num_hidden_out

    def pool_handler(self, pool: Tensor) -> Tensor:
        return pool

    def hidden_handler(self, hidden: Tensor) -> Tensor:
        hidden = F.adaptive_max_pool2d(input=hidden, output_size=1)
        hidden = hidden.view(hidden.shape[0], -1)
        return hidden
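ResNet50 follows the same recipe as ResNet18; only the block type (Bottleneck instead of BasicBlock) and the channel widths differ, which is why num_features_out rises to 1024 and num_hidden_out to 2048. A short sketch (illustrative only) of the children split it relies on:

import torchvision

resnet50 = torchvision.models.resnet50(pretrained=False)
children = list(resnet50.children())

print([type(child).__name__ for child in children])
# conv/bn/relu/maxpool, four Sequential stages, then the average pool
# (AvgPool2d in older torchvision releases, AdaptiveAvgPool2d in newer ones) and Linear

trunk, layer4 = children[:-3], children[-3]   # trunk ends at layer3 (1024 ch); layer4 outputs 2048 ch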
