Skip to content

Commit 7f501b8

Browse files
committed
add STANet
1 parent a42de66 commit 7f501b8

File tree

4 files changed

+61
-14
lines changed

4 files changed

+61
-14
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ Please refer to local_test.py temporarily.
4949

5050
- [x] UPerNet [[paper](https://arxiv.org/abs/1807.10221)]
5151

52+
- [x] STANet [[paper](https://www.mdpi.com/2072-4292/12/10/1662)]
53+
5254
#### Encoders <a name="encoders"></a>
5355

5456
The following is a list of supported encoders in the CDP. Select the appropriate family of encoders and click to expand the table and select a specific encoder and its pre-trained weights (`encoder_name` and `encoder_weights` parameters).

change_detection_pytorch/stanet/decoder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,9 @@ def __init__(
5050
sa_mode='PAM'
5151
):
5252
super(STANetDecoder, self).__init__()
53+
self.out_channel = f_c
5354
self.backbone_decoder = BackboneDecoder(f_c, nn.BatchNorm2d, encoder_out_channels)
54-
self.netA = CDSA(in_c=64, ds=1, mode=sa_mode)
55+
self.netA = CDSA(in_c=f_c, ds=1, mode=sa_mode)
5556

5657
def forward(self, *features):
5758
# fetch feature maps
@@ -149,7 +150,6 @@ def forward(self, input):
149150
x = self.relu(x)
150151
return x
151152

152-
153153
# if __name__ == '__main__':
154154
# from change_detection_pytorch.encoders import get_encoder
155155
#

change_detection_pytorch/stanet/model.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,42 @@
33
from torch.nn import functional as F
44
from ..encoders import get_encoder
55
from .decoder import STANetDecoder
6+
from ..base import SegmentationHead
67

78

89
class STANet(torch.nn.Module):
10+
"""
11+
Args:
12+
encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone)
13+
to extract features of different spatial resolution
14+
encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
15+
other pretrained weights (see table with available weights for each encoder_name)
16+
in_channels: A number of input channels for the model, default is 3 (RGB images)
17+
classes: Number of classes for the output mask (equivalently, the number of channels of the output mask)
18+
activation: An activation function to apply after the final convolution layer.
19+
Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**, **callable** and **None**.
20+
Default is **None**
21+
return_distance_map: If True, return the distance map between the feature maps of the images from the two periods, whose shape is (BatchSize, Height, Width). Default is False.
22+
23+
Returns:
24+
``torch.nn.Module``: STANet
25+
26+
.. _STANet:
27+
https://www.mdpi.com/2072-4292/12/10/1662
28+
29+
"""
930
def __init__(
1031
self,
1132
encoder_name: str = "resnet",
1233
encoder_weights: Optional[str] = "imagenet",
1334
sa_mode: str = "PAM",
1435
in_channels: int = 3,
36+
classes=2,
37+
activation=None,
38+
return_distance_map=False
1539
):
1640
super(STANet, self).__init__()
41+
self.return_distance_map = return_distance_map
1742
self.encoder = get_encoder(
1843
encoder_name,
1944
in_channels=in_channels,
@@ -24,11 +49,23 @@ def __init__(
2449
encoder_out_channels=self.encoder.out_channels,
2550
sa_mode=sa_mode
2651
)
52+
self.segmentation_head = SegmentationHead(
53+
in_channels=self.decoder.out_channel * 2,
54+
out_channels=classes,
55+
activation=activation,
56+
kernel_size=3,
57+
)
2758

2859
def forward(self, x1, x2):
2960
# only support siam encoder
3061
features = self.encoder(x1), self.encoder(x2)
3162
features = self.decoder(*features)
32-
dist = F.pairwise_distance(features[0], features[1],keepdim=True)
33-
dist = F.interpolate(dist, x1.shape[2:], mode='bilinear', align_corners=True)
34-
return dist
63+
if self.return_distance_map:
64+
dist = F.pairwise_distance(features[0], features[1], keepdim=True)
65+
dist = F.interpolate(dist, x1.shape[2:], mode='bilinear', align_corners=True)
66+
return dist
67+
else:
68+
decoder_output = torch.cat([features[0], features[1]], dim=1)
69+
decoder_output = F.interpolate(decoder_output, x1.shape[2:], mode='bilinear', align_corners=True)
70+
masks = self.segmentation_head(decoder_output)
71+
return masks

lino_test.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,20 @@
11
import torch
2-
from change_detection_pytorch.stanet import STANet
2+
from torch.utils.data import DataLoader, Dataset
33

4-
if __name__ == '__main__':
4+
import change_detection_pytorch as cdp
5+
from change_detection_pytorch.datasets import LEVIR_CD_Dataset, SVCD_Dataset
6+
from change_detection_pytorch.utils.lr_scheduler import GradualWarmupScheduler
57

6-
samples = torch.ones([1, 3, 256, 256])
7-
model = STANet(
8-
encoder_name='vgg16',
9-
in_channels=3
10-
)
11-
dist = model(samples, samples)
12-
print(dist.size())
8+
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
9+
10+
model = cdp.STANet(
11+
encoder_name="resnet34", # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
12+
encoder_weights="imagenet", # use `imagenet` pre-trained weights for encoder initialization
13+
in_channels=3, # model input channels (1 for gray-scale images, 3 for RGB, etc.)
14+
return_distance_map=False
15+
).to(DEVICE)
16+
17+
sampel1 = torch.ones([1, 3, 256, 256]).to(DEVICE)
18+
sampel2 = torch.ones([1, 3, 256, 256]).to(DEVICE)
19+
preds = model(sampel1, sampel2)
20+
print(preds.size())

0 commit comments

Comments
 (0)