Skip to content

Commit 57816ff

Browse files
committed
Added check_dataset_output
1 parent 6077c7b commit 57816ff

File tree

5 files changed

+28
-3
lines changed

5 files changed

+28
-3
lines changed

wildlife_tools/features/deep.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from transformers import CLIPModel, CLIPProcessor
44
from typing import Optional
55
from ..data import FeatureDataset, ImageDataset
6+
from ..tools import check_dataset_output
67

78

89
class DeepFeatures:
@@ -43,7 +44,8 @@ def __call__(self, dataset: ImageDataset) -> FeatureDataset:
4344

4445
self.model = self.model.to(self.device)
4546
self.model = self.model.eval()
46-
47+
48+
check_dataset_output(dataset, check_label=False)
4749
loader = torch.utils.data.DataLoader(
4850
dataset,
4951
num_workers=self.num_workers,
@@ -119,6 +121,7 @@ def __call__(self, dataset: ImageDataset) -> FeatureDataset:
119121
# TODO: this is hacky
120122
dataset.transforms = None # Reset transforms.
121123

124+
check_dataset_output(dataset, check_label=False)
122125
loader = torch.utils.data.DataLoader(
123126
dataset,
124127
num_workers=self.num_workers,
@@ -127,7 +130,7 @@ def __call__(self, dataset: ImageDataset) -> FeatureDataset:
127130
collate_fn=lambda x: x,
128131
)
129132
outputs = []
130-
for image in tqdm(loader, mininterval=1, ncols=100):
133+
for image, _ in tqdm(loader, mininterval=1, ncols=100):
131134
with torch.no_grad():
132135
output = self.model(self.transform(image).to(self.device)).pooler_output
133136
outputs.append(output.cpu())

wildlife_tools/features/local.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from tqdm import tqdm
77

88
from ..data import FeatureDataset, ImageDataset
9+
from ..tools import check_dataset_output
910
from .gluefactory_fix import extract_single_image_fix # https://github.com/cvg/glue-factory/pull/50
1011

1112

@@ -53,6 +54,7 @@ def __call__(self, dataset: ImageDataset) -> FeatureDataset:
5354
"""
5455

5556
features = []
57+
check_dataset_output(dataset, check_label=False)
5658
loader = torch.utils.data.DataLoader(
5759
dataset,
5860
num_workers=self.num_workers,

wildlife_tools/features/memory.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from tqdm import tqdm
2-
32
from ..data import FeatureDataset, ImageDataset
3+
from ..tools import check_dataset_output
44

55

66
class DataToMemory:
@@ -17,6 +17,7 @@ def __call__(self, dataset: ImageDataset):
1717
"""Loads data from input dataset into array and returns them as a new FeatureDataset."""
1818

1919
features = []
20+
check_dataset_output(dataset, check_label=False)
2021
for x, _ in tqdm(dataset, mininterval=1, ncols=100):
2122
features.append(x)
2223
return FeatureDataset(

wildlife_tools/tools.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,23 @@
88
from PIL import Image
99

1010

11+
def check_dataset_output(dataset, check_label=False):
    """Validate the sample format of a dataset by inspecting ``dataset[0]``.

    Only the first sample is examined; the remaining samples are assumed to
    follow the same format.

    Args:
        dataset: Indexable object; ``dataset[0]`` is expected to be a tuple
            whose second item is the label.
        check_label: If True, additionally reject string labels
            (``str`` / ``np.str_``).

    Raises:
        ValueError: If ``dataset[0]`` is not a tuple with at least two items,
            or if ``check_label`` is True and the label is a string.
    """
    output = dataset[0]
    # A tuple that is too short would otherwise surface as a bare IndexError
    # on `output[1]`, so it is reported with the same explanatory error.
    if not isinstance(output, tuple) or len(output) < 2:
        raise ValueError('''
        Calling `dataset[0]` must return a tuple (data, label).
        Try to use `load_label=True` when creating the dataset.
        ''')
    if not check_label:
        return
    label = output[1]
    if isinstance(label, (str, np.str_)):
        raise ValueError('''
        Calling `dataset[0]` must return a tuple,
        where the second part (label) is an integer.
        If you used the WildlifeDataset from wildlife-datasets,
        try to use `factorize_label=True` when creating the dataset.
        ''')
26+
27+
1128
def frame_image(img, frame_width, color=(255, 0, 0)):
1229
b = frame_width
1330
ny, nx = img.shape[0], img.shape[1]

wildlife_tools/train/trainer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from tqdm import tqdm
77
from typing import Optional, Callable
88
from ..data import ImageDataset
9+
from ..tools import check_dataset_output
910

1011

1112
def set_seed(seed=0, device="cuda"):
@@ -94,6 +95,7 @@ def __init__(
9495
accumulation_steps: int = 1,
9596
epoch_callback: Optional[Callable] = None,
9697
):
98+
check_dataset_output(dataset, check_label=True)
9799
self.dataset = dataset
98100
self.model = model.to(device)
99101
self.objective = objective.to(device)

0 commit comments

Comments
 (0)