Skip to content

Commit aa0ea7d

Browse files
committed
Added save_features
1 parent 96b4b5e commit aa0ea7d

File tree

1 file changed

+10
-0
lines changed

1 file changed

+10
-0
lines changed

wildlife_tools/similarity/wildfusion.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def __init__(
3434
extractor: Callable | None = None,
3535
calibration: Callable | None = None,
3636
transform: Callable | None = None,
37+
save_features: bool = False,
3738
):
3839
"""
3940
Args:
@@ -43,20 +44,29 @@ def __init__(
4344
calibration (callable, optional): A calibration model to refine similarity scores.
4445
transform (callable, optional): Image transformation function applied before feature
4546
extraction.
47+
save_features (bool, optional): Could prevent multiple computation of the same features.
48+
Increases memory usage, especially when using big datasets.
4649
"""
4750

4851
self.matcher = matcher
4952
self.calibration = calibration
5053
self.calibration_done = False
5154
self.extractor = extractor
5255
self.transform = transform
56+
self.save_features = save_features
57+
self.features = {}
5358

5459
def get_feature_dataset(self, dataset: ImageDataset) -> FeatureDataset:
5560
"""Apply transformations and extract features from the image dataset."""
5661

5762
if self.transform is not None:
5863
dataset.transform = self.transform
5964
if self.extractor is not None:
65+
if self.save_features:
66+
key = dataset.__repr__()
67+
if key not in self.features:
68+
self.features[key] = self.extractor(dataset)
69+
return self.features[key]
6070
return self.extractor(dataset)
6171
else:
6272
return dataset

0 commit comments

Comments
 (0)