@@ -34,6 +34,7 @@ def __init__(
3434 extractor : Callable | None = None ,
3535 calibration : Callable | None = None ,
3636 transform : Callable | None = None ,
37+ save_features : bool = False ,
3738 ):
3839 """
3940 Args:
@@ -43,20 +44,29 @@ def __init__(
4344 calibration (callable, optional): A calibration model to refine similarity scores.
4445 transform (callable, optional): Image transformation function applied before feature
4546 extraction.
47+ save_features (bool, optional): Could prevent multiple computation of the same features.
48+ Increases memory usage, especially when using big datasets.
4649 """
4750
4851 self .matcher = matcher
4952 self .calibration = calibration
5053 self .calibration_done = False
5154 self .extractor = extractor
5255 self .transform = transform
56+ self .save_features = save_features
57+ self .features = {}
5358
5459 def get_feature_dataset (self , dataset : ImageDataset ) -> FeatureDataset :
5560 """Apply transformations and extract features from the image dataset."""
5661
5762 if self .transform is not None :
5863 dataset .transform = self .transform
5964 if self .extractor is not None :
65+ if self .save_features :
66+ key = dataset .__repr__ ()
67+ if key not in self .features :
68+ self .features [key ] = self .extractor (dataset )
69+ return self .features [key ]
6070 return self .extractor (dataset )
6171 else :
6272 return dataset
0 commit comments