Skip to content

Commit 93952cd

Browse files
committed
optimize test_transform
1 parent 55694fe commit 93952cd

File tree

3 files changed

+7
-10
lines changed

3 files changed

+7
-10
lines changed

change_detection_pytorch/datasets/LEVIR_CD.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22

33
import albumentations as A
44
from albumentations.pytorch import ToTensorV2
5-
from change_detection_pytorch.datasets.custom import CustomDataset
5+
6+
from .custom import CustomDataset
7+
from .transforms.albu import ChunkImage, ToTensorTest
68

79

810
class LEVIR_CD_Dataset(CustomDataset):
@@ -27,12 +29,9 @@ def get_default_transform(self):
2729
def get_test_transform(self):
2830
"""Set the test transformation."""
2931

30-
from change_detection_pytorch.datasets.transforms.albu import (
31-
ChunkImage, ToTensorTest)
3232
test_transform = A.Compose([
3333
A.Normalize(),
34-
ChunkImage(self.size),
35-
ToTensorTest(),
34+
ToTensorV2()
3635
], additional_targets={'image_2': 'image'})
3736
return test_transform
3837

change_detection_pytorch/datasets/SVCD.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ def get_test_transform(self):
2929
"""Set the test transformation."""
3030

3131
test_transform = A.Compose([
32-
A.Resize(self.size, self.size),
3332
A.Normalize(),
3433
ToTensorV2()
3534
], additional_targets={'image_2': 'image'})

local_test.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
encoder_weights="imagenet", # use `imagenet` pre-trained weights for encoder initialization
1313
in_channels=3, # model input channels (1 for gray-scale images, 3 for RGB, etc.)
1414
classes=2, # model output channels (number of classes in your datasets)
15-
siam_encoder=True,
16-
fusion_form='concat',
15+
siam_encoder=True, # whether to use a siamese encoder
16+
fusion_form='concat', # the form of fusing features from two branches. e.g. concat, sum, diff, or abs_diff.
1717
)
1818

1919
train_dataset = LEVIR_CD_Dataset('../LEVIR-CD/train',
@@ -86,5 +86,4 @@
8686
print('Model saved!')
8787

8888
# save results (change maps)
89-
valid_epoch.infer_vis(valid_loader, slide=True, image_size=1024, window_size=256,
90-
save_dir='./res')
89+
valid_epoch.infer_vis(valid_loader, save=True, slide=False, save_dir='./res')

0 commit comments

Comments
 (0)