Skip to content

Commit 1b6abaa

Browse files
committed
dataset: Implement color jitter augmentation
1 parent 60f17a2 commit 1b6abaa

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

utils/datasets.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,16 @@ def __getitem__(self, index):
8787

8888
img_path = self.img_files[index % len(self.img_files)].rstrip()
8989

90+
if self.augment:
91+
transforms = torchvision.transforms.Compose([
92+
torchvision.transforms.ColorJitter(brightness=1.5, saturation=1.5, hue=0.1),
93+
torchvision.transforms.ToTensor()
94+
])
95+
else:
96+
transforms = torchvision.transforms.ToTensor()
97+
9098
# Extract image as PyTorch tensor
91-
img = torchvision.transforms.ToTensor()(Image.open(img_path).convert('RGB'))
99+
img = transforms(Image.open(img_path).convert('RGB'))
92100

93101
_, h, w = img.shape
94102
h_factor, w_factor = (h, w) if self.normalized_labels else (1, 1)

0 commit comments

Comments
 (0)